modify model_zoo resnext50 network for cloud
This commit is contained in:
parent 0459e53040
commit cd105f7ed8
@@ -61,6 +61,35 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will

- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

If you want to run on ModelArts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/); you can start training and evaluation as follows:

```python
# run distributed training on modelarts example
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" on yaml file.
#          Set other parameters you need on yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add other parameters on the website UI interface.
# (2) Set the code directory to "/path/resnext50" on the website UI interface.
# (3) Set the startup file to "train.py" on the website UI interface.
# (4) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (5) Create your job.

# run evaluation on modelarts example
# (1) Copy or upload your trained model to S3 bucket.
# (2) Perform a or b.
#       a. Set "enable_modelarts=True" on yaml file.
#          Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on yaml file.
#          Set "checkpoint_url=/The path of checkpoint in S3/" on yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
#          Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
# (3) Set the code directory to "/path/resnext50" on the website UI interface.
# (4) Set the startup file to "eval.py" on the website UI interface.
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (6) Create your job.
```
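
For reference, a minimal sketch (not part of this commit) of how the yaml/UI options above surface inside the scripts; it assumes the `src/model_utils` package added in this diff:

```python
# Illustrative only: enable_modelarts, checkpoint_url and checkpoint_file_path set on the yaml
# file or the web UI all become attributes of the shared config object.
from src.model_utils.config import config

if config.enable_modelarts:
    # checkpoint_url is synced to /cache/checkpoint_path by the moxing adapter before eval starts
    print(config.checkpoint_url, "->", config.checkpoint_file_path)
```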

# [Script description](#contents)

## [Script and sample code](#contents)

@@ -95,10 +124,16 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
├─linear_warmup.py # linear warmup learning rate
├─warmup_cosine_annealing.py # learning rate each step
├─warmup_step_lr.py # warmup step learning rate
├─eval.py # eval net
├── model_utils
├──config.py # parameter configuration
├──device_adapter.py # device adapter
├──local_adapter.py # local adapter
├──moxing_adapter.py # moxing adapter
├── default_config.yaml # parameter configuration
├──eval.py # eval net
├──train.py # train net
├──export.py # export mindir script
├──mindspore_hub_conf.py # mindspore hub interface
├──mindspore_hub_conf.py # mindspore hub interface

```

@@ -138,7 +173,7 @@ Parameters for both training and evaluating can be set in config.py.
You can start training by python script:

```script
python train.py --data_dir ~/imagenet/train/ --platform Ascend --is_distributed 0
python train.py --data_path ~/imagenet/train/ --device_target Ascend --run_distribute 0
```

or shell script:

@@ -179,14 +214,14 @@ You can find checkpoint file together with result in log.
You can start evaluation by python script:

```script
python eval.py --data_dir ~/imagenet/val/ --platform Ascend --pretrained resnext.ckpt
python eval.py --data_path ~/imagenet/val/ --device_target Ascend --checkpoint_file_path resnext.ckpt
```

or shell script:

```script
# Evaluation
sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM
sh run_eval.sh DEVICE_ID DATA_PATH CHECKPOINT_FILE_PATH DEVICE_TARGET
```

PLATFORM is Ascend or GPU, default is Ascend.

@@ -210,7 +245,7 @@ acc=93.88%(TOP5)
## [Model Export](#contents)

```shell
python export.py --device_target [PLATFORM] --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT]
python export.py --device_target [PLATFORM] --checkpoint_file_path [CKPT_PATH] --file_format [EXPORT_FORMAT]
```

The `ckpt_file` parameter is required.
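
For context, a rough sketch of what the export step does internally (based on the `export.py` changes later in this diff; the checkpoint filename is a placeholder):

```python
# Sketch only: load a trained checkpoint and export it, mirroring export.py in this commit.
import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from src.image_classification import get_network

net = get_network(num_classes=1000, platform="Ascend")
load_param_into_net(net, load_checkpoint("resnext.ckpt"))  # placeholder checkpoint path
net.set_train(False)
dummy_input = Tensor(np.random.uniform(-1.0, 1.0, (1, 3, 224, 224)).astype(np.float32))
export(net, dummy_input, file_name="resnext50", file_format="MINDIR")
```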

@@ -68,6 +68,38 @@ The overall ResNeXt architecture is as follows:

- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

To train the model on ModelArts, refer to the official ModelArts documentation (https://support.huaweicloud.com/modelarts/)
and start training and inference as follows:

```python
# Example of distributed training on ModelArts:
# (1) Choose either option a or b.
#       a. Set "enable_modelarts=True" in the yaml file.
#          Set the other parameters the network needs in the yaml file.
#       b. Add "enable_modelarts=True" on the ModelArts web UI.
#          Set the other parameters the network needs on the ModelArts web UI.
# (2) Set the code directory to "/path/resnext50" on the ModelArts web UI.
# (3) Set the startup file to "train.py" on the ModelArts web UI.
# (4) On the ModelArts web UI, set the "Dataset path",
#     the "Output file path" and the "Job log path".
# (5) Create your training job.

# Example of evaluation on ModelArts
# (1) Copy the trained model to the corresponding location in the bucket.
# (2) Choose either option a or b.
#       a. Set "enable_modelarts=True" in the yaml file.
#          Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" in the yaml file.
#          Set "checkpoint_url=/The path of checkpoint in S3/" in the yaml file.
#       b. Add "enable_modelarts=True" on the ModelArts web UI.
#          Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the ModelArts web UI.
#          Add "checkpoint_url=/The path of checkpoint in S3/" on the ModelArts web UI.
# (3) Set the code directory to "/path/resnext50" on the ModelArts web UI.
# (4) Set the startup file to "eval.py" on the ModelArts web UI.
# (5) On the ModelArts web UI, set the "Dataset path",
#     the "Output file path" and the "Job log path".
# (6) Create your evaluation job.
```

# Script description

## Script and sample code

@@ -102,9 +134,14 @@ The overall ResNeXt architecture is as follows:
├─linear_warmup.py # linear warmup learning rate
├─warmup_cosine_annealing.py # learning rate for each step
├─warmup_step_lr.py # warmup step learning rate
├─eval.py # eval net
├─model_utils
├──config.py # parameter configuration
├──device_adapter.py # device adapter
├──local_adapter.py # local adapter
├──moxing_adapter.py # ModelArts (moxing) adapter
├──eval.py # eval net
├──train.py # train net
├──mindspore_hub_conf.py # MindSpore Hub interface
├──mindspore_hub_conf.py # MindSpore Hub interface
```

## Script parameters

@@ -144,7 +181,7 @@ The overall ResNeXt architecture is as follows:
You can start training with the Python script:

```shell
python train.py --data_dir ~/imagenet/train/ --platform Ascend --is_distributed 0
python train.py --data_path ~/imagenet/train/ --device_target Ascend --run_distribute 0
```

or with the shell script:

@@ -185,17 +222,17 @@ sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train
You can start evaluation with the Python script:

```shell
python eval.py --data_dir ~/imagenet/val/ --platform Ascend --pretrained resnext.ckpt
python eval.py --data_path ~/imagenet/val/ --device_target Ascend --checkpoint_file_path resnext.ckpt
```

or with the shell script:

```shell
# Evaluation
sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM
sh run_eval.sh DEVICE_ID DATA_PATH CHECKPOINT_FILE_PATH PLATFORM
```

PLATFORM is Ascend or GPU, default is Ascend.
DEVICE_TARGET is Ascend or GPU, default is Ascend.

#### Example

@@ -216,7 +253,7 @@ acc=93.88%(TOP5)
## Model Export

```shell
python export.py --device_target [PLATFORM] --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT]
python export.py --device_target [PLATFORM] --checkpoint_file_path [CKPT_PATH] --file_format [EXPORT_FORMAT]
```

The `ckpt_file` parameter is required.

@@ -0,0 +1,72 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
run_distribute: False
enable_profiling: False
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
checkpoint_path: './checkpoint/'
checkpoint_file_path: ''

# ==============================================================================
# Training options
image_size: [224,224]
num_classes: 1000
batch_size: 1

lr: 0.4
lr_scheduler: 'cosine_annealing'
lr_epochs: [30,60,90,120]
lr_gamma: 0.1
eta_min: 0
T_max: 150
max_epoch: 150
warmup_epochs: 1

weight_decay: 0.0001
momentum: 0.9
is_dynamic_loss_scale: 0
loss_scale: 1024
label_smooth: 1
label_smooth_factor: 0.1
per_batch_size: 128

ckpt_interval: 5
ckpt_save_max: 5
is_save_on_master: 1
rank_save_ckpt_flag: 0
outputs_dir: ""
log_path: './output_log'

# Export options
device_id: 0
width: 224
height: 224
file_name: "resnext101"
file_format: "AIR"
result_path: ""
label_path: ""

---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."
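
The file above holds two YAML documents separated by `---`: the configuration values and the per-key help text. A minimal sketch of reading both documents (this mirrors `parse_yaml` in the `model_utils/config.py` added below; the file path is assumed):

```python
# Sketch: the first document holds the values, the second the help strings.
import yaml

with open("default_config.yaml", "r") as fin:
    cfg, cfg_helper = list(yaml.load_all(fin, Loader=yaml.FullLoader))

print(cfg["per_batch_size"])     # 128
print(cfg_helper["batch_size"])  # 'Batch size for training and evaluation'
```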

@@ -15,7 +15,6 @@
"""Eval"""
import os
import time
import argparse
import datetime
import glob
import numpy as np

@@ -33,7 +32,8 @@ from src.utils.auto_mixed_precision import auto_mixed_precision
from src.utils.var_init import load_pretrain_model
from src.image_classification import get_network
from src.dataset import classification_dataset
from src.config import config
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper


class ParameterReduce(nn.Cell):

@@ -50,52 +50,24 @@ class ParameterReduce(nn.Cell):
return ret


def parse_args(cloud_args=None):
"""parse_args"""
parser = argparse.ArgumentParser('mindspore classification test')
parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')

# dataset related
parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir')
parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu')
# network related
parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
parser.add_argument('--pretrained', default='', type=str, help='fully path of pretrained model to load. '
'If it is a direction, it will test all ckpt')

# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')

# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')

args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config.image_size
args.num_classes = config.num_classes
args.rank = config.rank
args.group_size = config.group_size

args.image_size = list(map(int, args.image_size.split(',')))

# init distributed
if args.is_distributed:
if args.platform == "Ascend":
def set_parameters():
"""set_parameters"""
if config.run_distribute:
if config.device_target == "Ascend":
init()
elif args.platform == "GPU":
elif config.device_target == "GPU":
init("nccl")
args.rank = get_rank()
args.group_size = get_group_size()
config.rank = get_rank()
config.group_size = get_group_size()
else:
args.rank = 0
args.group_size = 1
config.rank = 0
config.group_size = 1

args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
config.outputs_dir = os.path.join(config.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

args.logger = get_logger(args.outputs_dir, args.rank)
return args
config.logger = get_logger(config.outputs_dir, config.rank)
return config


def get_top5_acc(top5_arg, gt_class):

@@ -105,38 +77,25 @@ def get_top5_acc(top5_arg, gt_class):
sub_count += 1
return sub_count

def merge_args(args, cloud_args):
"""merge_args"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args


def get_result(args, model, top1_correct, top5_correct, img_tot):
def get_result(model, top1_correct, top5_correct, img_tot):
"""calculate top1 and top5 value."""
results = [[top1_correct], [top5_correct], [img_tot]]
args.logger.info('before results={}'.format(results))
if args.is_distributed:
config.logger.info('before results=%s', results)
if config.run_distribute:
model_md5 = model.replace('/', '')
tmp_dir = '/cache'
if not os.path.exists(tmp_dir):
os.mkdir(tmp_dir)
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5)
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(config.rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(config.rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(config.rank, model_md5)
np.save(top1_correct_npy, top1_correct)
np.save(top5_correct_npy, top5_correct)
np.save(img_tot_npy, img_tot)
while True:
rank_ok = True
for other_rank in range(args.group_size):
for other_rank in range(config.group_size):
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)

@@ -149,7 +108,7 @@ def get_result(args, model, top1_correct, top5_correct, img_tot):
top1_correct_all = 0
top5_correct_all = 0
img_tot_all = 0
for other_rank in range(args.group_size):
for other_rank in range(config.group_size):
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)

@@ -161,53 +120,53 @@ def get_result(args, model, top1_correct, top5_correct, img_tot):
else:
results = np.array(results)

args.logger.info('after results={}'.format(results))
config.logger.info('after results=%s', results)
return results


@moxing_wrapper()
def test(cloud_args=None):
"""test"""
args = parse_args(cloud_args)
set_parameters()
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.platform, save_graphs=False)
device_target=config.device_target, save_graphs=False)
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))

# init distributed
if args.is_distributed:
if config.run_distribute:
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=config.group_size,
gradients_mean=True)

args.logger.save_args(args)
config.logger.save_args(config)

# network
args.logger.important_info('start create network')
if os.path.isdir(args.pretrained):
models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
config.logger.important_info('start create network')
if os.path.isdir(config.pretrained):
models = list(glob.glob(os.path.join(config.pretrained, '*.ckpt')))
print(models)
if args.graph_ckpt:
if config.graph_ckpt:
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0])
else:
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
args.models = sorted(models, key=f)
config.models = sorted(models, key=f)
else:
args.models = [args.pretrained,]
config.models = [config.checkpoint_file_path,]

for model in args.models:
de_dataset = classification_dataset(args.data_dir, image_size=args.image_size,
per_batch_size=args.per_batch_size,
max_epoch=1, rank=args.rank, group_size=args.group_size,
for model in config.models:
de_dataset = classification_dataset(config.data_path, image_size=config.image_size,
per_batch_size=config.per_batch_size,
max_epoch=1, rank=config.rank, group_size=config.group_size,
mode='eval')
eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True, num_epochs=1)
network = get_network(num_classes=args.num_classes, platform=args.platform)
network = get_network(num_classes=config.num_classes, platform=config.device_target)

load_pretrain_model(model, network, args)
load_pretrain_model(model, network, config)

img_tot = 0
top1_correct = 0
top5_correct = 0
if args.platform == "Ascend":
if config.device_target == "Ascend":
network.to_float(mstype.float16)
else:
auto_mixed_precision(network)

@@ -224,26 +183,26 @@ def test(cloud_args=None):
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += args.per_batch_size
img_tot += config.per_batch_size

if args.rank == 0 and it == 0:
if config.rank == 0 and it == 0:
t_end = time.time()
it = 1
if args.rank == 0:
if config.rank == 0:
time_used = time.time() - t_end
fps = (img_tot - args.per_batch_size) * args.group_size / time_used
args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
results = get_result(args, model, top1_correct, top5_correct, img_tot)
fps = (img_tot - config.per_batch_size) * config.group_size / time_used
config.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
results = get_result(model, top1_correct, top5_correct, img_tot)
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
args.logger.info('after allreduce eval: top1_correct={}, tot={},'
'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1))
args.logger.info('after allreduce eval: top5_correct={}, tot={},'
'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5))
if args.is_distributed:
config.logger.info('after allreduce eval: top1_correct={}, tot={},'
'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1))
config.logger.info('after allreduce eval: top5_correct={}, tot={},'
'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5))
if config.run_distribute:
release()
@@ -15,40 +15,28 @@
"""
resnext export mindir.
"""
import argparse
import numpy as np
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config
from src.model_utils.config import config
from src.image_classification import get_network
from src.utils.auto_mixed_precision import auto_mixed_precision

parser = argparse.ArgumentParser(description='checkpoint export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument('--width', type=int, default=224, help='input width')
parser.add_argument('--height', type=int, default=224, help='input height')
parser.add_argument("--file_name", type=str, default="resnext50", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=config.device_id)

if __name__ == '__main__':
network = get_network(num_classes=config.num_classes, platform=args.device_target)
network = get_network(num_classes=config.num_classes, platform=config.device_target)

param_dict = load_checkpoint(args.ckpt_file)
param_dict = load_checkpoint(config.checkpoint_file_path)
load_param_into_net(network, param_dict)
if args.device_target == "Ascend":
if config.device_target == "Ascend":
network.to_float(mstype.float16)
else:
auto_mixed_precision(network)
network.set_train(False)
input_shp = [args.batch_size, 3, args.height, args.width]
input_shp = [config.batch_size, 3, config.height, config.width]
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
export(network, input_array, file_name=args.file_name, file_format=args.file_format)
export(network, input_array, file_name=config.file_name, file_format=config.file_format)
@@ -15,16 +15,10 @@
"""post process for 310 inference"""
import os
import json
import argparse
import numpy as np
from src.config import config
from src.model_utils.config import config

batch_size = 1
parser = argparse.ArgumentParser(description="resnet inference")
parser.add_argument("--result_path", type=str, required=True, help="result files path.")
parser.add_argument("--label_path", type=str, required=True, help="image file path.")
args = parser.parse_args()


def get_result(result_path, label_path):
files = os.listdir(result_path)

@@ -48,4 +42,4 @@ def get_result(result_path, label_path):


if __name__ == '__main__':
get_result(args.result_path, args.label_path)
get_result(config.result_path, config.label_path)
@@ -44,14 +44,17 @@ do
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cp *.yaml ./LOG$i
cp ./src ./LOG$i
cd ./LOG$i || exit
echo "start training for rank $i, device $DEVICE_ID"

env > env.log
taskset -c $cmdopt python ../train.py \
--is_distribute=1 \
--run_distribute=1 \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &
--checkpoint_file_path=$PATH_CHECKPOINT \
--data_path=$DATA_DIR \
--output_path './output' > log.txt 2>&1 &
cd ../
done

@@ -24,7 +24,8 @@ fi

mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py \
--is_distribute=1 \
--platform="GPU" \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &
--run_distribute=1 \
--device_target="GPU" \
--checkpoint_file_path=$PATH_CHECKPOINT \
--data_path=$DATA_DIR \
--output_path './output' > log.txt 2>&1 &

@@ -24,6 +24,6 @@ then
fi

python eval.py \
--pretrained=$PATH_CHECKPOINT \
--platform=$PLATFORM \
--data_dir=$DATA_DIR > log.txt 2>&1 &
--checkpoint_file_path=$PATH_CHECKPOINT \
--device_target=$PLATFORM \
--data_path=$DATA_DIR > log.txt 2>&1 &

@@ -23,8 +23,9 @@ then
fi

python train.py \
--is_distribute=0 \
--run_distribute=0 \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &
--checkpoint_file_path=$PATH_CHECKPOINT \
--data_path=$DATA_DIR \
--output_path './output' > log.txt 2>&1 &

@@ -23,8 +23,9 @@ then
fi

python train.py \
--is_distribute=0 \
--pretrained=$PATH_CHECKPOINT \
--platform="GPU" \
--data_dir=$DATA_DIR > log.txt 2>&1 &
--run_distribute=0 \
--checkpoint_file_path=$PATH_CHECKPOINT \
--device_target="GPU" \
--data_path=$DATA_DIR \
--output_path './output' > log.txt 2>&1 &
@@ -1,45 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config"""
from easydict import EasyDict as ed

config = ed({
    "image_size": '224,224',
    "num_classes": 1000,

    "lr": 0.4,
    "lr_scheduler": 'cosine_annealing',
    "lr_epochs": '30,60,90,120',
    "lr_gamma": 0.1,
    "eta_min": 0,
    "T_max": 150,
    "max_epoch": 150,
    "warmup_epochs": 1,

    "weight_decay": 0.0001,
    "momentum": 0.9,
    "is_dynamic_loss_scale": 0,
    "loss_scale": 1024,
    "label_smooth": 1,
    "label_smooth_factor": 0.1,

    "ckpt_interval": 5,
    "ckpt_save_max": 5,
    "ckpt_path": 'outputs/',
    "is_save_on_master": 1,

    "rank": 0,
    "group_size": 1
})
@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml

_config_path = "./default_config.yaml"

class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
            else:
                raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml")
            print(cfg_helper)
        except:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser, default, helper, path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)

config = get_config()
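
A hypothetical usage example (not part of the commit): every key in `default_config.yaml` becomes an attribute of `config`, and any key can be overridden on the command line of whichever script imports it.

```python
# Illustrative only; run e.g.:  python train.py --per_batch_size 64 --lr 0.1
from src.model_utils.config import config

print(config.device_target)    # 'Ascend' unless overridden
print(config.per_batch_size)   # 128 from the yaml, 64 with the override above
print(config.lr_epochs)        # [30, 60, 90, 120]
```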

@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Device adapter for ModelArts"""

from src.model_utils.config import config

if config.enable_modelarts:
    from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
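
Illustrative use of the adapter (assumed caller, not part of the commit): scripts import the helpers from one place and get the ModelArts or local implementation depending on `enable_modelarts`.

```python
# Sketch only: the same call works locally (env vars / defaults) and on ModelArts.
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id

print(get_device_id(), get_device_num(), get_rank_id())  # e.g. 0 1 0 on a single local device
```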

@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Local adapter"""

import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
@@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from src.model_utils.config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from remote OBS to a local directory if the first path is a remote URL and the second is local.
    Upload data from a local directory to remote OBS otherwise.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains at most 8 devices.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            run_func(*args, **kwargs)

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
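
A sketch of how the decorator is meant to be used (train.py and eval.py in this diff apply it in the same way; the hook below is illustrative only):

```python
# Illustrative only: on ModelArts the wrapper syncs data_url/checkpoint_url/train_url to the
# local /cache paths before the run and copies output_path back to train_url afterwards.
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

def report_inputs():                        # optional pre_process hook
    print("data ready under", config.data_path)

@moxing_wrapper(pre_process=report_inputs)
def main():
    print("training with", config.per_batch_size, "images per device")

if __name__ == "__main__":
    main()
```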

@@ -15,7 +15,6 @@
"""train ImageNet."""
import os
import time
import argparse
import datetime

import mindspore.nn as nn

@@ -36,7 +35,8 @@ from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.utils.var_init import load_pretrain_model
from src.image_classification import get_network
from src.config import config
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

set_seed(1)


@@ -104,155 +104,101 @@ class ProgressMonitor(Callback):
self.args.logger.info('end network train...')


def parse_args(cloud_args=None):
def set_parameters():
"""parameters"""
parser = argparse.ArgumentParser('mindspore classification training')
parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')

# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=128, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')

# distributed related
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')

args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config.image_size
args.num_classes = config.num_classes
args.lr = config.lr
args.lr_scheduler = config.lr_scheduler
args.lr_epochs = config.lr_epochs
args.lr_gamma = config.lr_gamma
args.eta_min = config.eta_min
args.T_max = config.T_max
args.max_epoch = config.max_epoch
args.warmup_epochs = config.warmup_epochs
args.weight_decay = config.weight_decay
args.momentum = config.momentum
args.is_dynamic_loss_scale = config.is_dynamic_loss_scale
args.loss_scale = config.loss_scale
args.label_smooth = config.label_smooth
args.label_smooth_factor = config.label_smooth_factor
args.ckpt_interval = config.ckpt_interval
args.ckpt_save_max = config.ckpt_save_max
args.ckpt_path = config.ckpt_path
args.is_save_on_master = config.is_save_on_master
args.rank = config.rank
args.group_size = config.group_size
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.image_size = list(map(int, args.image_size.split(',')))

context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.platform, save_graphs=False)
device_target=config.device_target, save_graphs=False)
# init distributed
if args.is_distributed:
if config.run_distribute:
init()
args.rank = get_rank()
args.group_size = get_group_size()
config.rank = get_rank()
config.group_size = get_group_size()
else:
args.rank = 0
args.group_size = 1
config.rank = 0
config.group_size = 1

if args.is_dynamic_loss_scale == 1:
args.loss_scale = 1  # for dynamic loss scale can not set loss scale in momentum opt
if config.is_dynamic_loss_scale == 1:
config.loss_scale = 1  # for dynamic loss scale can not set loss scale in momentum opt

# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
config.rank_save_ckpt_flag = 0
if config.is_save_on_master:
if config.rank == 0:
config.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
config.rank_save_ckpt_flag = 1

# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
return args
config.outputs_dir = os.path.join(config.output_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
config.logger = get_logger(config.outputs_dir, config.rank)
return config

def merge_args(args, cloud_args):
"""dictionary"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args


def train(cloud_args=None):
@moxing_wrapper()
def train():
"""training process"""
args = parse_args(cloud_args)
set_parameters()
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))

# init distributed
if args.is_distributed:
if config.run_distribute:
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=config.group_size,
gradients_mean=True)
# dataloader
de_dataset = classification_dataset(args.data_dir, args.image_size,
args.per_batch_size, 1,
args.rank, args.group_size, num_parallel_workers=8)
de_dataset = classification_dataset(config.data_path, config.image_size,
config.per_batch_size, 1,
config.rank, config.group_size, num_parallel_workers=8)
de_dataset.map_model = 4  # !!!important
args.steps_per_epoch = de_dataset.get_dataset_size()
config.steps_per_epoch = de_dataset.get_dataset_size()

args.logger.save_args(args)
config.logger.save_args(config)

# network
args.logger.important_info('start create network')
config.logger.important_info('start create network')
# get network and init
network = get_network(num_classes=args.num_classes, platform=args.platform)
network = get_network(num_classes=config.num_classes, platform=config.device_target)

load_pretrain_model(args.pretrained, network, args)
load_pretrain_model(config.checkpoint_file_path, network, config)

# lr scheduler
lr = get_lr(args)
lr = get_lr(config)

# optimizer
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
momentum=config.momentum,
weight_decay=config.weight_decay,
loss_scale=config.loss_scale)


# loss
if not args.label_smooth:
args.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
if not config.label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)

if args.is_dynamic_loss_scale == 1:
if config.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O3")

# checkpoint save
progress_cb = ProgressMonitor(args)
progress_cb = ProgressMonitor(config)
callbacks = [progress_cb,]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
if config.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_interval * config.steps_per_epoch,
keep_checkpoint_max=config.ckpt_save_max)
save_ckpt_path = os.path.join(config.outputs_dir, 'ckpt_' + str(config.rank) + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='{}'.format(args.rank))
prefix='{}'.format(config.rank))
callbacks.append(ckpt_cb)

model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)
model.train(config.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)


if __name__ == "__main__":