!11949 add inceptionv3 cpu train script
From: @caojian05 Reviewed-by: @wuxuejian,@oacjiewen Signed-off-by: @wuxuejian
This commit is contained in:
commit
c144896aee
|
@ -40,6 +40,14 @@ Dataset used can refer to paper.
|
|||
- Data format: RGB images.
|
||||
- Note: Data will be processed in src/dataset.py
|
||||
|
||||
Dataset used: [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar.html)
|
||||
|
||||
- Dataset size: 175M, 60,000 32\*32 colorful images in 10 classes
|
||||
- Train: 146M, 50,000 images
|
||||
- Test: 29M, 10,000 images
|
||||
- Data format:binary files
|
||||
- Note:Data will be processed in src/dataset.py
|
||||
|
||||
# [Features](#contents)
|
||||
|
||||
## [Mixed Precision(Ascend)](#contents)
|
||||
|
@ -67,8 +75,13 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
|
|||
└─Inception-v3
|
||||
├─README.md
|
||||
├─scripts
|
||||
├─run_standalone_train_cpu.sh # launch standalone training with cpu platform
|
||||
├─run_standalone_train_gpu.sh # launch standalone training with gpu platform(1p)
|
||||
├─run_distribute_train_gpu.sh # launch distributed training with gpu platform(8p)
|
||||
├─run_standalone_train.sh # launch standalone training with ascend platform(1p)
|
||||
├─run_distribute_train.sh # launch distributed training with ascend platform(8p)
|
||||
├─run_eval_cpu.sh # launch evaluation with cpu platform
|
||||
├─run_eval_gpu.sh # launch evaluation with gpu platform
|
||||
└─run_eval.sh # launch evaluating with ascend platform
|
||||
├─src
|
||||
├─config.py # parameter configuration
|
||||
|
@ -93,6 +106,8 @@ Major parameters in train.py and config.py are:
|
|||
'batch_size' # input batchsize
|
||||
'epoch_size' # total epoch numbers
|
||||
'num_classes' # dataset class numbers
|
||||
'ds_type' # dataset type, such as: imagenet, cifar10
|
||||
'ds_sink_mode' # whether enable dataset sink mode
|
||||
'smooth_factor' # label smoothing factor
|
||||
'aux_factor' # loss factor of aux logit
|
||||
'lr_init' # initiate learning rate
|
||||
|
@ -127,6 +142,13 @@ sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
|
|||
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
|
||||
```
|
||||
|
||||
- CPU:
|
||||
|
||||
```shell
|
||||
# standalone training
|
||||
sh scripts/run_standalone_train_cpu.sh DATA_PATH
|
||||
```
|
||||
|
||||
> Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV3, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size.
|
||||
>
|
||||
> This is processor cores binding operation regarding the `device_num` and total processor numbers. If you are not expect to do it, remove the operations `taskset` in `scripts/run_distribute_train.sh`
|
||||
|
@ -137,6 +159,7 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
|
|||
# training example
|
||||
python:
|
||||
Ascend: python train.py --dataset_path DATA_PATH --platform Ascend
|
||||
CPU: python train.py --dataset_path DATA_PATH --platform CPU
|
||||
|
||||
shell:
|
||||
Ascend:
|
||||
|
@ -144,12 +167,17 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
|
|||
sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
|
||||
# standalone training example
|
||||
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
|
||||
|
||||
CPU:
|
||||
sh script/run_standalone_train_cpu.sh DATA_PATH
|
||||
```
|
||||
|
||||
### Result
|
||||
|
||||
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./log.txt` like followings.
|
||||
|
||||
#### Ascend
|
||||
|
||||
```python
|
||||
epoch: 0 step: 1251, loss is 5.7787247
|
||||
epoch time: 360760.985 ms, per step time: 288.378 ms
|
||||
|
@ -157,6 +185,18 @@ epoch: 1 step: 1251, loss is 4.392868
|
|||
epoch time: 160917.911 ms, per step time: 128.631 ms
|
||||
```
|
||||
|
||||
#### CPU
|
||||
|
||||
```bash
|
||||
epoch: 1 step: 390, loss is 2.7072601
|
||||
epoch time: 6334572.124 ms, per step time: 16242.493 ms
|
||||
epoch: 2 step: 390, loss is 2.5908582
|
||||
epoch time: 6217897.644 ms, per step time: 15943.327 ms
|
||||
epoch: 3 step: 390, loss is 2.5612416
|
||||
epoch time: 6358482.104 ms, per step time: 16303.800 ms
|
||||
...
|
||||
```
|
||||
|
||||
## [Eval process](#contents)
|
||||
|
||||
### Usage
|
||||
|
@ -169,15 +209,23 @@ You can start training using python or shell scripts. The usage of shell scripts
|
|||
sh scripts/run_eval.sh DEVICE_ID DATA_PATH PATH_CHECKPOINT
|
||||
```
|
||||
|
||||
- CPU:
|
||||
|
||||
```python
|
||||
sh scripts/run_eval_cpu.sh DATA_PATH PATH_CHECKPOINT
|
||||
```
|
||||
|
||||
### Launch
|
||||
|
||||
```python
|
||||
# eval example
|
||||
python:
|
||||
Ascend: python eval.py --dataset_path DATA_PATH --checkpoint PATH_CHECKPOINT --platform Ascend
|
||||
CPU: python eval.py --dataset_path DATA_PATH --checkpoint PATH_CHECKPOINT --platform CPU
|
||||
|
||||
shell:
|
||||
Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_PATH PATH_CHECKPOINT
|
||||
CPU: sh scripts/run_eval_cpu.sh DATA_PATH PATH_CHECKPOINT
|
||||
```
|
||||
|
||||
> checkpoint can be produced in training process.
|
||||
|
@ -236,4 +284,4 @@ In dataset.py, we set the seed inside “create_dataset" function. We also use r
|
|||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
|
@ -51,6 +51,14 @@ InceptionV3的总体网络架构如下:
|
|||
- 数据格式:RGB
|
||||
- 注:数据将在src/dataset.py中处理。
|
||||
|
||||
使用的数据集:[CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>)
|
||||
|
||||
- 数据集大小:175M,共10个类、6万张32*32彩色图像
|
||||
- 训练集:146M,共5万张图像
|
||||
- 测试集:29M,共1万张图像
|
||||
- 数据格式:二进制文件
|
||||
- 注:数据将在src/dataset.py中处理。
|
||||
|
||||
# 特性
|
||||
|
||||
## 混合精度(Ascend)
|
||||
|
@ -78,9 +86,14 @@ InceptionV3的总体网络架构如下:
|
|||
└─Inception-v3
|
||||
├─README.md
|
||||
├─scripts
|
||||
├─run_standalone_train_cpu.sh # 启动CPU训练
|
||||
├─run_standalone_train_gpu.sh # 启动GPU单机训练(单卡)
|
||||
├─run_distribute_train_gpu.sh # 启动GPU分布式训练(8卡)
|
||||
├─run_standalone_train.sh # 启动Ascend单机训练(单卡)
|
||||
├─run_distribute_train.sh # 启动Ascend分布式训练(8卡)
|
||||
├─run_eval.sh # 启动Ascend评估
|
||||
├─run_eval_cpu.sh # 启动CPU评估
|
||||
├─run_eval_gpu.sh # 启动GPU评估
|
||||
└─run_eval.sh # 启动Ascend评估
|
||||
├─src
|
||||
├─config.py # 参数配置
|
||||
├─dataset.py # 数据预处理
|
||||
|
@ -106,6 +119,8 @@ train.py和config.py中主要参数如下:
|
|||
'batch_size' # 输入张量的批次大小
|
||||
'epoch_size' # 总轮次数
|
||||
'num_classes' # 数据集类数
|
||||
'ds_type' # 数据集类型,如:imagenet, cifar10
|
||||
'ds_sink_mode' # 使能数据下沉
|
||||
'smooth_factor' # 标签平滑因子
|
||||
'aux_factor' # aux logit的损耗因子
|
||||
'lr_init' # 初始学习率
|
||||
|
@ -149,6 +164,7 @@ train.py和config.py中主要参数如下:
|
|||
# 训练示例
|
||||
python:
|
||||
Ascend: python train.py --dataset_path /dataset/train --platform Ascend
|
||||
CPU: python train.py --dataset_path DATA_PATH --platform CPU
|
||||
|
||||
shell:
|
||||
Ascend:
|
||||
|
@ -156,12 +172,17 @@ train.py和config.py中主要参数如下:
|
|||
sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
|
||||
# 单机训练
|
||||
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
|
||||
|
||||
CPU:
|
||||
sh script/run_standalone_train_cpu.sh DATA_PATH
|
||||
```
|
||||
|
||||
### 结果
|
||||
|
||||
训练结果保存在示例路径。检查点默认保存在`checkpoint`,训练日志会重定向到`./log.txt`,如下:
|
||||
|
||||
#### Ascend
|
||||
|
||||
```log
|
||||
epoch:0 step:1251, loss is 5.7787247
|
||||
Epoch time:360760.985, per step time:288.378
|
||||
|
@ -169,6 +190,18 @@ epoch:1 step:1251, loss is 4.392868
|
|||
Epoch time:160917.911, per step time:128.631
|
||||
```
|
||||
|
||||
#### CPU
|
||||
|
||||
```bash
|
||||
epoch: 1 step: 390, loss is 2.7072601
|
||||
epoch time: 6334572.124 ms, per step time: 16242.493 ms
|
||||
epoch: 2 step: 390, loss is 2.5908582
|
||||
epoch time: 6217897.644 ms, per step time: 15943.327 ms
|
||||
epoch: 3 step: 390, loss is 2.5612416
|
||||
epoch time: 6358482.104 ms, per step time: 16303.800 ms
|
||||
...
|
||||
```
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 用法
|
||||
|
@ -181,15 +214,23 @@ Epoch time:160917.911, per step time:128.631
|
|||
sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
|
||||
```
|
||||
|
||||
- CPU:
|
||||
|
||||
```python
|
||||
sh scripts/run_eval_cpu.sh DATA_PATH PATH_CHECKPOINT
|
||||
```
|
||||
|
||||
### 启动
|
||||
|
||||
``` launch
|
||||
# 评估示例
|
||||
python:
|
||||
Ascend: python eval.py --dataset_path DATA_DIR --checkpoint PATH_CHECKPOINT --platform Ascend
|
||||
CPU: python eval.py --dataset_path DATA_PATH --checkpoint PATH_CHECKPOINT --platform CPU
|
||||
|
||||
shell:
|
||||
Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
|
||||
CPU: sh scripts/run_eval_cpu.sh DATA_PATH PATH_CHECKPOINT
|
||||
```
|
||||
|
||||
> 训练过程中可以生成检查点。
|
||||
|
|
|
@ -21,33 +21,48 @@ from mindspore import context
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config_gpu as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.config import config_gpu, config_ascend, config_cpu
|
||||
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
|
||||
from src.inception_v3 import InceptionV3
|
||||
from src.loss import CrossEntropy_Val
|
||||
|
||||
CFG_DICT = {
|
||||
"Ascend": config_ascend,
|
||||
"GPU": config_gpu,
|
||||
"CPU": config_cpu,
|
||||
}
|
||||
|
||||
DS_DICT = {
|
||||
"imagenet": create_dataset_imagenet,
|
||||
"cifar10": create_dataset_cifar10,
|
||||
}
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification evaluation')
|
||||
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)')
|
||||
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
cfg = CFG_DICT[args_opt.platform]
|
||||
create_dataset = DS_DICT[cfg.ds_type]
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
|
||||
net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
|
||||
ckpt = load_checkpoint(args_opt.checkpoint)
|
||||
load_param_into_net(net, ckpt)
|
||||
net.set_train(False)
|
||||
dataset = create_dataset(args_opt.dataset_path, False, 0, 1)
|
||||
cfg.rank = 0
|
||||
cfg.group_size = 1
|
||||
dataset = create_dataset(args_opt.dataset_path, False, cfg)
|
||||
loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=cfg.num_classes)
|
||||
eval_metrics = {'Loss': nn.Loss(),
|
||||
'Top1-Acc': nn.Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': nn.Top5CategoricalAccuracy()}
|
||||
model = Model(net, loss, optimizer=None, metrics=eval_metrics)
|
||||
metrics = model.eval(dataset)
|
||||
metrics = model.eval(dataset, dataset_sink_mode=cfg.ds_sink_mode)
|
||||
print("metric: ", metrics)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATA_DIR=$1
|
||||
PATH_CHECKPOINT=$2
|
||||
python ./eval.py --platform 'CPU' --dataset_path $DATA_DIR --checkpoint $PATH_CHECKPOINT > eval.log 2>&1 &
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATA_DIR=$1
|
||||
python ./train.py --platform 'CPU' --dataset_path $DATA_DIR > train.log 2>&1 &
|
||||
|
|
@ -26,6 +26,8 @@ config_gpu = edict({
|
|||
'batch_size': 128,
|
||||
'epoch_size': 250,
|
||||
'num_classes': 1000,
|
||||
'ds_type': 'imagenet',
|
||||
'ds_sink_mode': True,
|
||||
'smooth_factor': 0.1,
|
||||
'aux_factor': 0.2,
|
||||
'lr_init': 0.00004,
|
||||
|
@ -51,6 +53,8 @@ config_ascend = edict({
|
|||
'batch_size': 128,
|
||||
'epoch_size': 250,
|
||||
'num_classes': 1000,
|
||||
'ds_type': 'imagenet',
|
||||
'ds_sink_mode': True,
|
||||
'smooth_factor': 0.1,
|
||||
'aux_factor': 0.2,
|
||||
'lr_init': 0.00004,
|
||||
|
@ -67,3 +71,30 @@ config_ascend = edict({
|
|||
'has_bias': False,
|
||||
'amp_level': 'O3'
|
||||
})
|
||||
|
||||
config_cpu = edict({
|
||||
'random_seed': 1,
|
||||
'work_nums': 8,
|
||||
'decay_method': 'cosine',
|
||||
"loss_scale": 1024,
|
||||
'batch_size': 128,
|
||||
'epoch_size': 120,
|
||||
'num_classes': 10,
|
||||
'ds_type': 'cifar10',
|
||||
'ds_sink_mode': False,
|
||||
'smooth_factor': 0.1,
|
||||
'aux_factor': 0.2,
|
||||
'lr_init': 0.00004,
|
||||
'lr_max': 0.1,
|
||||
'lr_end': 0.000004,
|
||||
'warmup_epochs': 1,
|
||||
'weight_decay': 0.00004,
|
||||
'momentum': 0.9,
|
||||
'opt_eps': 1.0,
|
||||
'keep_checkpoint_max': 10,
|
||||
'ckpt_path': './',
|
||||
'is_save_on_master': 0,
|
||||
'dropout_keep_prob': 0.8,
|
||||
'has_bias': False,
|
||||
'amp_level': 'O0',
|
||||
})
|
||||
|
|
|
@ -15,32 +15,32 @@
|
|||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import os
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from src.config import config_gpu as cfg
|
||||
|
||||
|
||||
def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1):
|
||||
def create_dataset_imagenet(dataset_path, do_train, cfg, repeat_num=1):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
rank (int): The shard ID within num_shards (default=None).
|
||||
group_size (int): Number of shards that the dataset should be divided into (default=None).
|
||||
cfg (dict): the config for creating dataset.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if group_size == 1:
|
||||
if cfg.group_size == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
|
||||
num_shards=group_size, shard_id=rank)
|
||||
num_shards=cfg.group_size, shard_id=cfg.rank)
|
||||
# define map operations
|
||||
if do_train:
|
||||
trans = [
|
||||
|
@ -67,3 +67,44 @@ def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1):
|
|||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
return data_set
|
||||
|
||||
|
||||
def create_dataset_cifar10(dataset_path, do_train, cfg, repeat_num=1):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
cfg (dict): the config for creating dataset.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
dataset_path = os.path.join(dataset_path, "cifar-10-batches-bin" if do_train else "cifar-10-verify-bin")
|
||||
if cfg.group_size == 1:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
|
||||
else:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
|
||||
num_shards=cfg.group_size, shard_id=cfg.rank)
|
||||
|
||||
# define map operations
|
||||
trans = []
|
||||
if do_train:
|
||||
trans.append(C.RandomCrop((32, 32), (4, 4, 4, 4)))
|
||||
trans.append(C.RandomHorizontalFlip(prob=0.5))
|
||||
|
||||
trans.append(C.Resize((299, 299)))
|
||||
trans.append(C.Rescale(1.0 / 255.0, 0.0))
|
||||
trans.append(C.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]))
|
||||
trans.append(C.HWC2CHW())
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(cfg.batch_size, drop_remainder=do_train)
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
return data_set
|
||||
|
|
|
@ -29,14 +29,24 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
|||
from mindspore.common.initializer import XavierUniform, initializer
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import config_gpu, config_ascend
|
||||
from src.dataset import create_dataset
|
||||
from src.config import config_gpu, config_ascend, config_cpu
|
||||
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
|
||||
from src.inception_v3 import InceptionV3
|
||||
from src.lr_generator import get_lr
|
||||
from src.loss import CrossEntropy
|
||||
|
||||
set_seed(1)
|
||||
|
||||
CFG_DICT = {
|
||||
"Ascend": config_ascend,
|
||||
"GPU": config_gpu,
|
||||
"CPU": config_cpu,
|
||||
}
|
||||
|
||||
DS_DICT = {
|
||||
"imagenet": create_dataset_imagenet,
|
||||
"cifar10": create_dataset_cifar10,
|
||||
}
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification training')
|
||||
|
@ -44,13 +54,16 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
|
||||
parser.add_argument('--is_distributed', action='store_true', default=False,
|
||||
help='distributed training')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
cfg = CFG_DICT[args_opt.platform]
|
||||
create_dataset = DS_DICT[cfg.ds_type]
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
cfg = config_ascend if args_opt.platform == 'Ascend' else config_gpu
|
||||
|
||||
# init distributed
|
||||
if args_opt.is_distributed:
|
||||
init()
|
||||
|
@ -64,7 +77,7 @@ if __name__ == '__main__':
|
|||
cfg.group_size = 1
|
||||
|
||||
# dataloader
|
||||
dataset = create_dataset(args_opt.dataset_path, True, cfg.rank, cfg.group_size)
|
||||
dataset = create_dataset(args_opt.dataset_path, True, cfg)
|
||||
batches_per_epoch = dataset.get_dataset_size()
|
||||
|
||||
# network
|
||||
|
@ -120,8 +133,8 @@ if __name__ == '__main__':
|
|||
if args_opt.is_distributed & cfg.is_save_on_master:
|
||||
if cfg.rank == 0:
|
||||
callbacks.append(ckpoint_cb)
|
||||
model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
|
||||
model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=cfg.ds_sink_mode)
|
||||
else:
|
||||
callbacks.append(ckpoint_cb)
|
||||
model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
|
||||
model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=cfg.ds_sink_mode)
|
||||
print("train success")
|
||||
|
|
Loading…
Reference in New Issue