Add TinyDarkNet CPU scripts and CIFAR-10 dataset support
parent 9085be08b9
commit dd2ddc70b3
@@ -60,8 +60,8 @@ Dataset used can refer to [paper](<https://ieeexplore.ieee.org/abstract/document

 # [Environment Requirements](#contents)

-- Hardware(Ascend)
-    - Prepare hardware environment with Ascend or GPU processor.
+- Hardware(Ascend/CPU)
+    - Prepare hardware environment with Ascend/CPU processor.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:
@@ -158,7 +158,9 @@ For more details, please refer the specify script.
 ├── scripts
     ├── run_standalone_train.sh    // shell script for single card training on Ascend
     ├── run_distribute_train.sh    // shell script for distributed training on Ascend
+    ├── run_train_cpu.sh           // shell script for training on CPU
     ├── run_eval.sh                // shell script for evaluation on Ascend
+    ├── run_eval_cpu.sh            // shell script for evaluation on CPU
     ├── run_infer_310.sh           // shell script for inference on Ascend310
 ├── src
     ├── lr_scheduler               // learning rate scheduler
@@ -177,7 +179,8 @@ For more details, please refer the specify script.
 ├── train.py                  // training script
 ├── eval.py                   // evaluation script
 ├── export.py                 // export checkpoint file into air/onnx
-├── imagenet_config.yaml      // parameter configuration
+├── imagenet_config.yaml      // imagenet parameter configuration
+├── cifar10_config.yaml       // cifar10 parameter configuration
 ├── mindspore_hub_conf.py     // hub config
 ├── postprocess.py            // postprocess script

@@ -227,7 +230,7 @@ For more configuration details, please refer the script `imagenet_config.yaml`.
 - running on Ascend:

   ```bash
-  bash scripts/run_standalone_train.sh [DEVICE_ID]
+  bash ./scripts/run_standalone_train.sh [DEVICE_ID]
   ```

 The command above will run in the background; you can view the results in the file train.log.
@@ -249,6 +252,12 @@ For more configuration details, please refer the script `imagenet_config.yaml`.
 The model checkpoint file will be saved in the current folder.
 <!-- The model checkpoint will be saved in the current directory. -->

+- running on CPU:
+
+  ```bash
+  bash scripts/run_train_cpu.sh [TRAIN_DATA_DIR] [cifar10|imagenet]
+  ```
+
 ### [Distributed Training](#contents)

 - running on Ascend:
@@ -298,6 +307,21 @@ For more configuration details, please refer the script `imagenet_config.yaml`.
     accuracy: {'top_1_accuracy': 0.5871979166666667, 'top_5_accuracy': 0.8175280448717949}
     ```

+- evaluation on cifar-10 dataset when running on CPU:
+
+    Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "/username/tinydarknet/train_tinydarknet.ckpt".
+
+    ```bash
+    bash scripts/run_eval_cpu.sh [VAL_DATA_DIR] [imagenet|cifar10] [CHECKPOINT_PATH]
+    ```
+
+    You can view the results through the file "eval.log". The accuracy on the test dataset will be as follows:
+
+    ```bash
+    # grep "accuracy: " eval.log
+    accuracy: {'top_5_accuracy': 1.0, 'top_1_accuracy': 0.9829727564102564}
+    ```
+
 ## [Inference process](#contents)

 ### Export MindIR
@@ -68,8 +68,8 @@ Tiny-DarkNet is a 16-layer network proposed by Joseph Chet Redmon et al. for the classic

 # [Environment Requirements](#目录)

-- Hardware (Ascend)
-    - Prepare a hardware environment with an Ascend processor.
+- Hardware (Ascend/CPU)
+    - Prepare a hardware environment with an Ascend/CPU processor.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please visit the links below:
@@ -165,7 +165,9 @@ Tiny-DarkNet is a 16-layer network proposed by Joseph Chet Redmon et al. for the classic
 ├── scripts
     ├── run_standalone_train.sh    // shell script for standalone training on Ascend
     ├── run_distribute_train.sh    // shell script for distributed training on Ascend
+    ├── run_train_cpu.sh           // shell script for training on CPU
     ├── run_eval.sh                // shell script for evaluation on Ascend
+    ├── run_eval_cpu.sh            // shell script for evaluation on CPU
     └── run_infer_310.sh           // shell script for inference on Ascend 310
 ├── src
     ├── lr_scheduler               // learning rate policy
@@ -184,7 +186,8 @@ Tiny-DarkNet is a 16-layer network proposed by Joseph Chet Redmon et al. for the classic
 ├── train.py                  // training script
 ├── eval.py                   // evaluation script
 ├── export.py                 // export checkpoint file
-├── imagenet_config.yaml      // parameter configuration
+├── imagenet_config.yaml      // imagenet parameter configuration
+├── cifar10_config.yaml       // cifar10 parameter configuration
 ├── mindspore_hub_conf.py     // hub configuration file
 └── postprocess.py            // 310 inference postprocessing script

@@ -256,6 +259,12 @@ Tiny-DarkNet is a 16-layer network proposed by Joseph Chet Redmon et al. for the classic
 The model checkpoint file will be saved in the current folder.
 <!-- The model checkpoint will be saved in the current directory. -->

+- running on CPU:
+
+  ```bash
+  bash scripts/run_train_cpu.sh [TRAIN_DATA_DIR] [cifar10|imagenet]
+  ```
+
 ### [Distributed Training](#目录)

 - running on Ascend:
@@ -305,6 +314,21 @@ Tiny-DarkNet is a 16-layer network proposed by Joseph Chet Redmon et al. for the classic
     accuracy: {'top_1_accuracy': 0.5871979166666667, 'top_5_accuracy': 0.8175280448717949}
     ```

+- evaluating on CPU:
+
+    Before running the command below, please check the checkpoint path used for evaluation. The checkpoint file must be inside the tinydarknet folder, and the checkpoint path should be given relative to eval.py, e.g., "./ckpts/train_tinydarknet.ckpt" (with ckpts at the same level as eval.py).
+
+    ```bash
+    bash scripts/run_eval_cpu.sh [VAL_DATA_DIR] [imagenet|cifar10] [CHECKPOINT_PATH]
+    ```
+
+    You can view the results through the "eval.log" file. The accuracy on the test dataset will be as follows:
+
+    ```bash
+    # grep "accuracy: " eval.log
+    accuracy: {'top_5_accuracy': 1.0, 'top_1_accuracy': 0.9829727564102564}
+    ```
+
 ## Inference Process

 ### Export MindIR
@@ -0,0 +1,57 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+modelarts_dataset_unzip_name: ''
+# ==============================================================================
+# train-eval-export related
+dataset_name: cifar10
+ckpt_save_dir: checkpoints
+pre_trained: False
+device_id: 0
+num_classes: 10
+lr_init: 0.1
+batch_size: 32
+epoch_size: 120
+momentum: 0.9
+weight_decay: 0.0001
+image_height: 227
+image_width: 227
+train_data_dir: './dataset/imagenet_original/train/'
+val_data_dir: './dataset/imagenet_original/val/'
+keep_checkpoint_max: 1
+checkpoint_path: './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt'
+onnx_filename: 'tinydarknet.onnx'
+air_filename: 'tinydarknet.air'
+# optimizer and lr related
+lr_scheduler: 'exponential'
+lr_epochs: [70, 140, 210, 280]
+lr_gamma: 0.1
+eta_min: 0.0
+T_max: 150
+warmup_epochs: 0
+# loss related
+is_dynamic_loss_scale: False
+loss_scale: 1024
+label_smooth_factor: 0.1
+use_label_smooth: True
+
+---
+
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
+enable_profiling: 'Whether enable profiling while training, default: False'
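The `lr_scheduler: 'exponential'` entry, together with `lr_init`, `lr_epochs` and `lr_gamma`, describes a piecewise decay of the learning rate. The repo parses this file through `src/model_utils/config.py`; purely as a standalone illustration, here is a minimal sketch that reads the two YAML documents above (values, then help text) and materializes one lr value per epoch, assuming the scheduler multiplies the rate by `lr_gamma` at each milestone in `lr_epochs`:

```python
import yaml

# the config file holds two YAML documents separated by '---'
# (values first, help text second), so use safe_load_all
with open("cifar10_config.yaml") as f:
    cfg, _help = yaml.safe_load_all(f)

def exponential_lr(lr_init, lr_epochs, lr_gamma, total_epochs):
    """One lr per epoch: multiply by lr_gamma at every milestone in lr_epochs."""
    lr, schedule = lr_init, []
    for epoch in range(total_epochs):
        if epoch in lr_epochs:
            lr *= lr_gamma
        schedule.append(lr)
    return schedule

schedule = exponential_lr(cfg["lr_init"], cfg["lr_epochs"], cfg["lr_gamma"], cfg["epoch_size"])
print(schedule[0], schedule[-1])  # 0.1 before epoch 70, 0.01 afterwards (epoch_size is 120)
```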
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,13 +23,15 @@ from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.common import set_seed
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

-from src.dataset import create_dataset_imagenet
+from src.dataset import create_dataset_imagenet, create_dataset_cifar
 from src.tinydarknet import TinyDarkNet
 from src.CrossEntropySmooth import CrossEntropySmooth
 from src.model_utils.config import config
 from src.model_utils.moxing_adapter import moxing_wrapper
 from src.model_utils.device_adapter import get_device_num

 set_seed(1)

 def modelarts_pre_process():
@@ -89,27 +91,32 @@

 @moxing_wrapper(pre_process=modelarts_pre_process)
 def run_eval():
     cfg = config
-    context.set_context(mode=context.GRAPH_MODE,
-                        device_target=cfg.device_target,
-                        device_id=cfg.device_id)
-    dataset = create_dataset_imagenet(cfg.val_data_dir, 1, False)
-    if not cfg.use_label_smooth:
-        cfg.label_smooth_factor = 0.0
-    loss = CrossEntropySmooth(sparse=True, reduction="mean",
-                              smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
-    net = TinyDarkNet(num_classes=cfg.num_classes)
-    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+    if config.dataset_name == "imagenet":
+        if not cfg.use_label_smooth:
+            cfg.label_smooth_factor = 0.0
+        dataset = create_dataset_imagenet(cfg.data_path, 1, False)
+        loss = CrossEntropySmooth(sparse=True, reduction="mean",
+                                  smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
+    elif config.dataset_name == "cifar10":
+        dataset = create_dataset_cifar(dataset_path=config.data_path,
+                                       do_train=False,  # evaluation uses the no-augmentation pipeline
+                                       repeat_num=1,
+                                       batch_size=config.batch_size,
+                                       target=cfg.device_target)
+        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    else:
+        raise ValueError("Dataset is not supported.")
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
+    if config.device_target == "Ascend":
+        context.set_context(device_id=config.device_id)
+    net = TinyDarkNet(num_classes=cfg.num_classes)
     param_dict = load_checkpoint(cfg.checkpoint_path)
     print("Load checkpoint from [{}].".format(cfg.checkpoint_path))

     load_param_into_net(net, param_dict)
     net.set_train(False)
+    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

     acc = model.eval(dataset)
     print("accuracy: ", acc)
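The `top_1_accuracy` and `top_5_accuracy` metrics requested from `Model` above are the usual top-k accuracies. As a reference only (MindSpore's built-in metrics are the authoritative implementation), a minimal NumPy sketch:

```python
import numpy as np

def topk_accuracy(logits, labels, k):
    # a sample counts as correct if its true label is among
    # the k highest-scoring classes
    topk = np.argsort(logits, axis=1)[:, -k:]
    return float(np.mean([y in row for y, row in zip(labels, topk)]))

logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2]])
labels = np.array([1, 2])
print(topk_accuracy(logits, labels, 1))  # 0.5: only the first sample's argmax is correct
print(topk_accuracy(logits, labels, 2))  # 0.5: class 2 is outside the second sample's top-2
```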
@@ -0,0 +1,64 @@
+#!/usr/bin/env bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# != 3 ]
+then
+    echo "Usage: bash scripts/run_eval_cpu.sh [VAL_DATA_DIR] [cifar10|imagenet] [CHECKPOINT_PATH]"
+    exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+PATH1=$(get_real_path $1)
+if [ ! -d $PATH1 ]
+then
+    echo "error: VAL_DATA_DIR=$PATH1 is not a directory"
+    exit 1
+fi
+
+PATH2=$(get_real_path $3)
+if [ ! -f $PATH2 ]
+then
+    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
+    exit 1
+fi
+
+BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
+if [ $2 == 'imagenet' ]; then
+    CONFIG_FILE="${BASE_PATH}/imagenet_config.yaml"
+elif [ $2 == 'cifar10' ]; then
+    CONFIG_FILE="${BASE_PATH}/cifar10_config.yaml"
+else
+    echo "error: the selected dataset is neither cifar10 nor imagenet"
+    exit 1
+fi
+
+rm -rf ./eval
+mkdir ./eval
+cp -r ./src ./eval
+cp ./eval.py ./eval
+cp ./*.yaml ./eval
+env > env.log
+echo "start evaluation for device CPU"
+cd ./eval || exit
+python ./eval.py --device_target=CPU --data_path=$PATH1 --dataset_name=$2 --config_path=$CONFIG_FILE \
+    --checkpoint_path=$PATH2 > ./eval.log 2>&1 &
+cd ..
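The `get_real_path` helper above just normalizes a possibly-relative argument before use. For comparison, a hypothetical Python equivalent (not part of the repo):

```python
import os

def get_real_path(path: str) -> str:
    # absolute paths pass through; relative paths resolve against the
    # current working directory, mirroring `realpath -m $PWD/$1`
    return path if path.startswith("/") else os.path.abspath(os.path.join(os.getcwd(), path))

print(get_real_path("ckpts/train_tinydarknet.ckpt"))  # e.g. /home/user/tinydarknet/ckpts/train_tinydarknet.ckpt
```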
@@ -0,0 +1,57 @@
+#!/usr/bin/env bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ]
+then
+    echo "Usage: bash scripts/run_train_cpu.sh [TRAIN_DATA_DIR] [cifar10|imagenet]"
+    exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+PATH1=$(get_real_path $1)
+if [ ! -d $PATH1 ]
+then
+    echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
+    exit 1
+fi
+
+BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
+if [ $2 == 'imagenet' ]; then
+    CONFIG_FILE="${BASE_PATH}/imagenet_config.yaml"
+elif [ $2 == 'cifar10' ]; then
+    CONFIG_FILE="${BASE_PATH}/cifar10_config.yaml"
+else
+    echo "error: the selected dataset is neither cifar10 nor imagenet"
+    exit 1
+fi
+
+rm -rf ./train_cpu
+mkdir ./train_cpu
+cp ./train.py ./train_cpu
+cp -r ./src ./train_cpu
+cp ./*.yaml ./train_cpu
+echo "start training for device CPU"
+cd ./train_cpu || exit
+env > env.log
+python train.py --device_target=CPU --data_path=$PATH1 --dataset_name=$2 --config_path=$CONFIG_FILE --lr_init=0.01 > ./train.log 2>&1 &
+cd ..
@@ -23,6 +23,80 @@ import mindspore.dataset.transforms.c_transforms as C
 import mindspore.dataset.vision.c_transforms as vision
 from src.model_utils.config import config as imagenet_cfg

+
+def create_dataset_cifar(dataset_path,
+                         do_train,
+                         repeat_num=1,
+                         batch_size=32,
+                         target="Ascend"):
+    """
+    create a train or evaluate cifar10 dataset
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        batch_size(int): the batch size of dataset. Default: 32
+        target(str): the device target. Default: Ascend
+
+    Returns:
+        dataset
+    """
+    if target == "Ascend":
+        device_num, rank_id = _get_rank_info()
+    elif target == "CPU":
+        device_num = 1
+    else:
+        init()
+        rank_id = get_rank()
+        device_num = get_group_size()
+
+    if device_num == 1:
+        data_set = ds.Cifar10Dataset(dataset_path,
+                                     num_parallel_workers=8,
+                                     shuffle=True)
+    else:
+        data_set = ds.Cifar10Dataset(dataset_path,
+                                     num_parallel_workers=8,
+                                     shuffle=True,
+                                     num_shards=device_num,
+                                     shard_id=rank_id)
+
+    # define map operations
+    if do_train:
+        trans = [
+            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
+            vision.RandomHorizontalFlip(prob=0.5),
+            vision.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
+            vision.Resize((227, 227)),
+            vision.Rescale(1.0 / 255.0, 0.0),
+            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+            vision.CutOut(112),
+            vision.HWC2CHW()
+        ]
+    else:
+        trans = [
+            vision.Resize((227, 227)),
+            vision.Rescale(1.0 / 255.0, 0.0),
+            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+            vision.HWC2CHW()
+        ]
+
+    type_cast_op = C.TypeCast(mstype.int32)
+
+    data_set = data_set.map(operations=type_cast_op,
+                            input_columns="label",
+                            num_parallel_workers=8)
+    data_set = data_set.map(operations=trans,
+                            input_columns="image",
+                            num_parallel_workers=8)
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
 def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
                             num_parallel_workers=None, shuffle=None):
     """
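A quick way to sanity-check the new pipeline is to pull a single batch and confirm the shapes. A hypothetical smoke test, assuming the CIFAR-10 binaries sit under `./cifar-10-batches-bin`:

```python
from src.dataset import create_dataset_cifar

# do_train=True exercises the RandomCrop/CutOut augmentation branch above
ds_train = create_dataset_cifar("./cifar-10-batches-bin", do_train=True,
                                batch_size=32, target="CPU")
images, labels = next(ds_train.create_tuple_iterator())
print(images.shape)  # (32, 3, 227, 227): resized to 227x227, then HWC2CHW
print(labels.shape)  # (32,)
```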
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,12 +25,13 @@ from mindspore.communication.management import init
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.common import set_seed

-from src.dataset import create_dataset_imagenet
+from src.dataset import create_dataset_imagenet, create_dataset_cifar
 from src.tinydarknet import TinyDarkNet
 from src.CrossEntropySmooth import CrossEntropySmooth
 from src.model_utils.config import config
@@ -120,9 +121,13 @@
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def run_train():
     if config.dataset_name == "imagenet":
-        pass
+        dataset = create_dataset_imagenet(config.data_path, 1)
     elif config.dataset_name == "cifar10":
-        raise ValueError("Unsupported dataset: 'cifar10'.")
+        dataset = create_dataset_cifar(dataset_path=config.data_path,
+                                       do_train=True,
+                                       repeat_num=1,
+                                       batch_size=config.batch_size,
+                                       target=config.device_target)
     else:
         raise ValueError("Unsupported dataset.")

@@ -133,22 +138,16 @@ def run_train():
     device_num = get_device_num()

     rank = 0
-    if device_target == "Ascend":
+    if device_target == "CPU":
+        pass
+    else:
         context.set_context(device_id=get_device_id())

         if device_num > 1:
             context.reset_auto_parallel_context()
             context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
             init()
             rank = get_rank_id()
-    else:
-        raise ValueError("Unsupported platform.")
-
-    if config.dataset_name == "imagenet":
-        dataset = create_dataset_imagenet(config.train_data_dir, 1)
-    else:
-        raise ValueError("Unsupported dataset.")

     batch_num = dataset.get_dataset_size()
@@ -159,50 +158,56 @@ def run_train():
     load_param_into_net(net, param_dict)

     loss_scale_manager = None
-    lr = lr_steps_imagenet(config, batch_num)
-
-    def get_param_groups(network):
-        """ get param groups """
-        decay_params = []
-        no_decay_params = []
-        for x in network.trainable_params():
-            parameter_name = x.name
-            if parameter_name.endswith('.bias'):
-                # all bias not using weight decay
-                no_decay_params.append(x)
-            elif parameter_name.endswith('.gamma'):
-                # bn weight bias not using weight decay, be carefully for now x not include BN
-                no_decay_params.append(x)
-            elif parameter_name.endswith('.beta'):
-                # bn weight bias not using weight decay, be carefully for now x not include BN
-                no_decay_params.append(x)
-            else:
-                decay_params.append(x)
-
-        return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
-
-    if config.is_dynamic_loss_scale:
-        config.loss_scale = 1
-
-    opt = Momentum(params=get_param_groups(net),
-                   learning_rate=Tensor(lr),
-                   momentum=config.momentum,
-                   weight_decay=config.weight_decay,
-                   loss_scale=config.loss_scale)
-    if not config.use_label_smooth:
-        config.label_smooth_factor = 0.0
-    loss = CrossEntropySmooth(sparse=True, reduction="mean",
-                              smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
+    if config.dataset_name == 'imagenet':
+        lr = lr_steps_imagenet(config, batch_num)
+
+        def get_param_groups(network):
+            """ get param groups """
+            decay_params = []
+            no_decay_params = []
+            for x in network.trainable_params():
+                parameter_name = x.name
+                if parameter_name.endswith('.bias'):
+                    # biases do not use weight decay
+                    no_decay_params.append(x)
+                elif parameter_name.endswith('.gamma'):
+                    # BN gamma does not use weight decay; careful, for now x does not include BN
+                    no_decay_params.append(x)
+                elif parameter_name.endswith('.beta'):
+                    # BN beta does not use weight decay; careful, for now x does not include BN
+                    no_decay_params.append(x)
+                else:
+                    decay_params.append(x)
+
+            return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
+
+        if config.is_dynamic_loss_scale:
+            config.loss_scale = 1
+
+        opt = Momentum(params=get_param_groups(net),
+                       learning_rate=Tensor(lr),
+                       momentum=config.momentum,
+                       weight_decay=config.weight_decay,
+                       loss_scale=config.loss_scale)
+        if not config.use_label_smooth:
+            config.label_smooth_factor = 0.0
+        loss = CrossEntropySmooth(sparse=True, reduction="mean",
+                                  smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
+    else:
+        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+        # plain Momentum over all trainable parameters for cifar10; without this,
+        # `opt` would be undefined on the cifar10 path below
+        opt = Momentum(params=net.trainable_params(),
+                       learning_rate=config.lr_init,
+                       momentum=config.momentum,
+                       weight_decay=config.weight_decay,
+                       loss_scale=config.loss_scale)

     if config.is_dynamic_loss_scale:
         loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
     else:
         loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

-    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
-                  amp_level="O3", loss_scale_manager=loss_scale_manager)
+    if device_target == "CPU":
+        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, loss_scale_manager=loss_scale_manager)
+    else:
+        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
+                      amp_level="O3", loss_scale_manager=loss_scale_manager)

     config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 50, keep_checkpoint_max=config.keep_checkpoint_max)
     time_cb = TimeMonitor(data_size=batch_num)
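As a toy illustration of the grouping rule in `get_param_groups` above (hypothetical parameter names): anything ending in `.bias`, `.gamma`, or `.beta` is exempted from weight decay, and everything else decays.

```python
# pure-Python sketch of the name-based split used by get_param_groups
names = ["conv1.weight", "conv1.bias", "bn1.gamma", "bn1.beta", "fc.weight"]
no_decay = [n for n in names if n.endswith(('.bias', '.gamma', '.beta'))]
decay = [n for n in names if n not in no_decay]
print(no_decay)  # ['conv1.bias', 'bn1.gamma', 'bn1.beta']
print(decay)     # ['conv1.weight', 'fc.weight']
```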