forked from mindspore-Ecosystem/mindspore
Modify googlenet for cloud

parent a0fe698e61
commit 4de1f37f68

@@ -120,6 +120,53 @@ After installing MindSpore via the official website, you can start training and

We use the CIFAR-10 dataset by default. You can also pass `$dataset_type` to the scripts to select a different dataset. For more details, please refer to the corresponding script.
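
For example, the run scripts in this commit derive the configuration file from that argument (the script name below is illustrative, not part of this diff):

```python
# config_path="./${dataset_type}_config.yaml" in the run scripts, so:
#   bash run_eval.sh ... cifar10   ->  ./cifar10_config.yaml
#   bash run_eval.sh ... imagenet  ->  ./imagenet_config.yaml
```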

- ModelArts (If you want to run on ModelArts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)

- Train imagenet 8p on ModelArts

```python
# (1) Add "config_path='/path_to_code/imagenet_config.yaml'" on the website UI interface.
# (2) Perform a or b.
#       a. Set "enable_modelarts=True" on imagenet_config.yaml file.
#          Set "dataset_name='imagenet'" on imagenet_config.yaml file.
#          Set "train_data_path='/cache/data/ImageNet/train/'" on imagenet_config.yaml file.
#          Set other parameters you need on imagenet_config.yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "dataset_name=imagenet" on the website UI interface.
#          Add "train_data_path=/cache/data/ImageNet/train/" on the website UI interface.
#          Add other parameters on the website UI interface.
# (3) Upload a zip dataset to the S3 bucket. (You could also upload the original dataset, but that may be slow.)
# (4) Set the code directory to "/path/googlenet" on the website UI interface.
# (5) Set the startup file to "train.py" on the website UI interface.
# (6) Set "Dataset path", "Output file path", and "Job log path" to your paths on the website UI interface.
# (7) Create your job.
```
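
For reference, a sketch of the equivalent command-line launch outside the UI (paths are placeholders; the flags exist because every scalar key in the yaml is exposed as a CLI option by model_utils/config.py below):

```python
# python train.py --config_path=/path_to_code/imagenet_config.yaml \
#                 --dataset_name=imagenet --train_data_path=/cache/data/ImageNet/train/
```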

- Eval imagenet on ModelArts

```python
# (1) Add "config_path='/path_to_code/imagenet_config.yaml'" on the website UI interface.
# (2) Perform a or b.
#       a. Set "enable_modelarts=True" on imagenet_config.yaml file.
#          Set "dataset_name='imagenet'" on imagenet_config.yaml file.
#          Set "val_data_path='/cache/data/ImageNet/val/'" on imagenet_config.yaml file.
#          Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on imagenet_config.yaml file.
#          Set "checkpoint_path='/cache/checkpoint_path/model.ckpt'" on imagenet_config.yaml file.
#          Set other parameters you need on imagenet_config.yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "dataset_name=imagenet" on the website UI interface.
#          Add "val_data_path=/cache/data/ImageNet/val/" on the website UI interface.
#          Add "checkpoint_url='s3://dir_to_trained_ckpt/'" on the website UI interface.
#          Add "checkpoint_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
#          Add other parameters on the website UI interface.
# (3) Upload or copy your trained model to the S3 bucket.
# (4) Upload a zip dataset to the S3 bucket. (You could also upload the original dataset, but that may be slow.)
# (5) Set the code directory to "/path/googlenet" on the website UI interface.
# (6) Set the startup file to "eval.py" on the website UI interface.
# (7) Set "Dataset path", "Output file path", and "Job log path" to your paths on the website UI interface.
# (8) Create your job.
```
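
A comparable local sketch for evaluation (placeholder paths; `--checkpoint_path` is the same flag the run scripts pass):

```python
# python eval.py --config_path=/path_to_code/imagenet_config.yaml \
#                --dataset_name=imagenet --checkpoint_path=/path/to/model.ckpt
```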

# [Script Description](#contents)

## [Script and Sample Code](#contents)

@@ -127,6 +127,53 @@ GoogleNet chains multiple Inception modules in series, allowing the network to go deeper. The dimension-reducing

The CIFAR-10 dataset is used by default. You can also pass `$dataset_type` to the scripts to select a different dataset. For more details, please refer to the corresponding script.

- Training on ModelArts (if you want to run on ModelArts, please refer to the [modelarts](https://support.huaweicloud.com/modelarts/) documentation)

- Train ImageNet with 8 devices on ModelArts

```python
# (1) Set "config_path='/path_to_code/imagenet_config.yaml'" on the website UI interface.
# (2) Perform a or b.
#       a. Set "enable_modelarts=True" in the imagenet_config.yaml file.
#          Set "dataset_name='imagenet'" in the imagenet_config.yaml file.
#          Set "train_data_path='/cache/data/ImageNet/train/'" in the imagenet_config.yaml file.
#          Set any other parameters you need in the imagenet_config.yaml file.
#       b. Set "enable_modelarts=True" on the website UI interface.
#          Set "dataset_name=imagenet" on the website UI interface.
#          Set "train_data_path=/cache/data/ImageNet/train/" on the website UI interface.
#          Set any other parameters you need on the website UI interface.
# (3) Upload a zip dataset to the S3 bucket. (You could also upload the original dataset, but that may be slow.)
# (4) Set the code directory to "/path/googlenet" on the website UI interface.
# (5) Set the startup file to "train.py" on the website UI interface.
# (6) Set "Dataset path", "Output file path", and "Job log path" on the website UI interface.
# (7) Create the training job.
```

- Evaluate ImageNet with a single device on ModelArts

```python
# (1) Set "config_path='/path_to_code/imagenet_config.yaml'" on the website UI interface.
# (2) Perform a or b.
#       a. Set "enable_modelarts=True" in the imagenet_config.yaml file.
#          Set "dataset_name='imagenet'" in the imagenet_config.yaml file.
#          Set "val_data_path='/cache/data/ImageNet/val/'" in the imagenet_config.yaml file.
#          Set "checkpoint_url='s3://dir_to_trained_ckpt/'" in the imagenet_config.yaml file.
#          Set "checkpoint_path='/cache/checkpoint_path/model.ckpt'" in the imagenet_config.yaml file.
#          Set any other parameters you need in the imagenet_config.yaml file.
#       b. Set "enable_modelarts=True" on the website UI interface.
#          Set "dataset_name=imagenet" on the website UI interface.
#          Set "val_data_path=/cache/data/ImageNet/val/" on the website UI interface.
#          Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on the website UI interface.
#          Set "checkpoint_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
#          Set any other parameters you need on the website UI interface.
# (3) Upload your trained model to the S3 bucket.
# (4) Upload a zip dataset to the S3 bucket. (You could also upload the original dataset, but that may be slow.)
# (5) Set the code directory to "/path/googlenet" on the website UI interface.
# (6) Set the startup file to "eval.py" on the website UI interface.
# (7) Set "Dataset path", "Output file path", and "Job log path" on the website UI interface.
# (8) Create the job.
```

# Script Description

## Script and Sample Code

@@ -0,0 +1,51 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: True
modelarts_dataset_unzip_name: "cifar10"

# ==============================================================================
# options
dataset_name: "cifar10"
name: "cifar10"
pre_trained: False
num_classes: 10
lr_init: 0.1
batch_size: 128
epoch_size: 125
momentum: 0.9
weight_decay: 0.0005 #5e-4
image_height: 224
image_width: 224
train_data_path: "/cache/data/cifar10/"
val_data_path: "/cache/data/cifar10/"
keep_checkpoint_max: 10
checkpoint_path: "./train_googlenet_cifar10-125_390.ckpt"
onnx_filename: "googlenet"
air_filename: "googlenet"
ckpt_save_dir: "./ckpt/"

# export option
ckpt_file: ""
file_name: "googlenet"
file_format: "AIR"
#batch_size: 1

---

# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
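
The `---` separator splits this file into two YAML documents: the configuration itself and the per-key help text. A minimal sketch of how such a file is consumed, mirroring the `parse_yaml` helper in model_utils/config.py further below (the file name follows the `${dataset_type}_config.yaml` convention used by the run scripts):

```python
import yaml

# Load both documents: the config dict and the help dict.
with open("cifar10_config.yaml") as fin:
    cfg, cfg_helper = list(yaml.load_all(fin, Loader=yaml.FullLoader))

print(cfg["batch_size"])        # -> 128
print(cfg_helper["data_path"])  # -> "The location of the input data."
```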

@@ -16,7 +16,8 @@
##############test googlenet example on cifar10#################
python eval.py
"""
import argparse
import os
import time

import mindspore.nn as nn
from mindspore import context

@@ -25,58 +26,105 @@ from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed

from src.config import cifar_cfg, imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet

from src.googlenet import GoogleNet
from src.CrossEntropySmooth import CrossEntropySmooth

from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num

set_seed(1)

parser = argparse.ArgumentParser(description='googlenet')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
                    help='dataset name.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args()
def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not a zip file.")
        else:
            print("Zip has been extracted.")

if __name__ == '__main__':
    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

    if args_opt.dataset_name == 'cifar10':
        cfg = cifar_cfg
        dataset = create_dataset_cifar10(cfg.data_path, 1, False)
        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains at most 8 devices.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))


@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
    if config.dataset_name == 'cifar10':
        dataset = create_dataset_cifar10(config.val_data_path, 1, False, cifar_cfg=config)
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        net = GoogleNet(num_classes=cfg.num_classes)
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
                       weight_decay=cfg.weight_decay)
        net = GoogleNet(num_classes=config.num_classes)
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, config.momentum,
                       weight_decay=config.weight_decay)
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    elif args_opt.dataset_name == "imagenet":
        cfg = imagenet_cfg
        dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
        if not cfg.use_label_smooth:
            cfg.label_smooth_factor = 0.0
    elif config.dataset_name == "imagenet":
        dataset = create_dataset_imagenet(config.val_data_path, 1, False, imagenet_cfg=config)
        if not config.use_label_smooth:
            config.label_smooth_factor = 0.0
        loss = CrossEntropySmooth(sparse=True, reduction="mean",
                                  smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
        net = GoogleNet(num_classes=cfg.num_classes)
                                  smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
        net = GoogleNet(num_classes=config.num_classes)
        model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

    else:
        raise ValueError("dataset is not supported.")

    device_target = cfg.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
    device_target = config.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    if device_target == "Ascend":
        context.set_context(device_id=cfg.device_id)

    if args_opt.checkpoint_path is not None:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
    else:
        param_dict = load_checkpoint(cfg.checkpoint_path)
        print("load checkpoint from [{}].".format(cfg.checkpoint_path))
        context.set_context(device_id=get_device_id())

    param_dict = load_checkpoint(config.checkpoint_path)
    print("load checkpoint from [{}].".format(config.checkpoint_path))
    load_param_into_net(net, param_dict)

    net.set_train(False)

    acc = model.eval(dataset)
    print("accuracy: ", acc)


if __name__ == '__main__':
    run_eval()

@@ -16,7 +16,6 @@
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np

import mindspore as ms

@@ -25,26 +24,17 @@ from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.config import cifar_cfg, imagenet_cfg
from src.googlenet import GoogleNet

parser = argparse.ArgumentParser(description='Classification')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="googlenet", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                    help="device target")
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
                    help='dataset name.')
args = parser.parse_args()
from model_utils.config import config
from model_utils.device_adapter import get_device_id

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
    context.set_context(device_id=get_device_id())

if __name__ == '__main__':
    if args.dataset_name == 'cifar10':
    if config.dataset_name == 'cifar10':
        cfg = cifar_cfg
    elif args.dataset_name == 'imagenet':
    elif config.dataset_name == 'imagenet':
        cfg = imagenet_cfg
    else:
        raise ValueError("dataset is not supported.")

@@ -52,8 +42,8 @@ if __name__ == '__main__':
    net = GoogleNet(num_classes=cfg.num_classes)

    assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None."
    param_dict = load_checkpoint(args.ckpt_file)
    param_dict = load_checkpoint(config.ckpt_file)
    load_param_into_net(net, param_dict)

    input_arr = Tensor(np.ones([args.batch_size, 3, cfg.image_height, cfg.image_width]), ms.float32)
    export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
    input_arr = Tensor(np.ones([config.batch_size, 3, cfg.image_height, cfg.image_width]), ms.float32)
    export(net, input_arr, file_name=config.file_name, file_format=config.file_format)

@@ -0,0 +1,67 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: True
modelarts_dataset_unzip_name: "ImageNet"

# ==============================================================================
# options
dataset_name: "imagenet"
name: "imagenet"
pre_trained: False
num_classes: 1000
lr_init: 0.1
batch_size: 256
epoch_size: 300
momentum: 0.9
weight_decay: 0.0001 #1e-4
image_height: 224
image_width: 224
train_data_path: "/cache/data/ImageNet/train/"
val_data_path: "/cache/data/ImageNet/validation_preprocess/"
keep_checkpoint_max: 10
checkpoint_path: ""
onnx_filename: "googlenet"
air_filename: "googlenet"
ckpt_save_dir: "./ckpt/"

# optimizer and lr related
lr_scheduler: "exponential"
lr_epochs: [70, 140, 210, 280]
lr_gamma: 0.3
eta_min: 0.0
T_max: 150
warmup_epochs: 0

# loss related
is_dynamic_loss_scale: 0
loss_scale: 1024
label_smooth_factor: 0.1
use_label_smooth: True

# export option
ckpt_file: ""
file_name: "googlenet"
file_format: "AIR"
#batch_size: 1

---

# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'

file_format: "choices in ['AIR', 'ONNX', 'MINDIR']"

@@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml

class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        except yaml.YAMLError:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)

config = get_config()
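
Elsewhere in this commit the module is consumed through the `config` singleton created above; a small usage sketch (key names come from the yaml files in this commit, values reflect cifar10_config.yaml unless overridden on the CLI, e.g. `--batch_size=64`):

```python
from model_utils.config import config

print(config.batch_size)    # 128
print(config.dataset_name)  # "cifar10"
print(config)               # pformat dump of every key, via Config.__str__
```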

@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Device adapter for ModelArts"""

from .config import config

if config.enable_modelarts:
    from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Local adapter"""

import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"

@@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from .config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID', '')
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from a remote OBS url to a local directory when the first path is remote and the second is local;
    otherwise upload the local directory to the remote OBS url.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains at most 8 devices.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            # Run the main function
            run_func(*args, **kwargs)

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
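
A sketch of how `sync_data` is meant to be called on ModelArts (bucket paths are placeholders; `moxing` is only available inside ModelArts jobs):

```python
# OBS -> local cache before training:
sync_data("obs://your-bucket/datasets/ImageNet/", "/cache/data/")
# local outputs -> OBS after training:
sync_data("/cache/train/", "obs://your-bucket/outputs/")
```

In practice the training scripts never call it directly: `moxing_wrapper` does, driven by the `data_url`, `checkpoint_url`, and `train_url` keys in the config.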

@@ -30,5 +30,7 @@ then
fi
    dataset_type=$1
fi
config_path="./${dataset_type}_config.yaml"
echo "config path is : ${config_path}"

python ${BASEPATH}/../eval.py --dataset_name=$dataset_type > ./eval.log 2>&1 &
python ${BASEPATH}/../eval.py --config_path=$config_path --dataset_name=$dataset_type > ./eval.log 2>&1 &

@@ -39,6 +39,8 @@ then
fi
    dataset_type=$2
fi
config_path="./${dataset_type}_config.yaml"
echo "config path is : ${config_path}"

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH

@@ -51,4 +53,4 @@ fi
mkdir ../eval
cd ../eval || exit

python3 ${BASEPATH}/../eval.py --checkpoint_path=$1 --dataset_name=$dataset_type > ./eval.log 2>&1 &
python3 ${BASEPATH}/../eval.py --config_path=$config_path --checkpoint_path=$1 --dataset_name=$dataset_type > ./eval.log 2>&1 &

@@ -37,14 +37,15 @@ then
fi
    dataset_type=$2
fi

config_path="./${dataset_type}_config.yaml"
echo "config path is : ${config_path}"

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
PATH1=$(realpath $1)
export RANK_TABLE_FILE=$PATH1
echo "RANK_TABLE_FILE=${PATH1}"

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))

@@ -55,10 +56,12 @@ do
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp -r ./src ./train_parallel$i
    cp -r ./model_utils ./train_parallel$i
    cp -r ./*.yaml ./train_parallel$i
    cp ./train.py ./train_parallel$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
    cd ./train_parallel$i || exit
    env > env.log
    python train.py --device_id=$i --dataset_name=$dataset_type > log 2>&1 &
    python train.py --config_path=$config_path --dataset_name=$dataset_type > log 2>&1 &
    cd ..
done

@@ -53,12 +53,14 @@ then
fi
    dataset_type=$3
fi
config_path="./${dataset_type}_config.yaml"
echo "config path is : ${config_path}"


if [ $1 -gt 1 ]
then
    mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
    python3 ${BASEPATH}/../train.py --dataset_name=$dataset_type > train.log 2>&1 &
    python3 ${BASEPATH}/../train.py --config_path=$config_path --dataset_name=$dataset_type > train.log 2>&1 &
else
    python3 ${BASEPATH}/../train.py --dataset_name=$dataset_type > train.log 2>&1 &
    python3 ${BASEPATH}/../train.py --config_path=$config_path --dataset_name=$dataset_type > train.log 2>&1 &
fi

@@ -1,73 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict

cifar_cfg = edict({
    'name': 'cifar10',
    'pre_trained': False,
    'num_classes': 10,
    'lr_init': 0.1,
    'batch_size': 128,
    'epoch_size': 125,
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'image_height': 224,
    'image_width': 224,
    'data_path': './cifar10',
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 10,
    'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt',
    'onnx_filename': 'googlenet',
    'air_filename': 'googlenet'
})

imagenet_cfg = edict({
    'name': 'imagenet',
    'pre_trained': False,
    'num_classes': 1000,
    'lr_init': 0.1,
    'batch_size': 256,
    'epoch_size': 300,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'image_height': 224,
    'image_width': 224,
    'data_path': './ImageNet_Original/train/',
    'val_data_path': './ImageNet_Original/val/',
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 10,
    'checkpoint_path': None,
    'onnx_filename': 'googlenet',
    'air_filename': 'googlenet',

    # optimizer and lr related
    'lr_scheduler': 'exponential',
    'lr_epochs': [70, 140, 210, 280],
    'lr_gamma': 0.3,
    'eta_min': 0.0,
    'T_max': 150,
    'warmup_epochs': 0,

    # loss related
    'is_dynamic_loss_scale': 0,
    'loss_scale': 1024,
    'label_smooth_factor': 0.1,
    'use_label_smooth': True,
})

@@ -21,20 +21,18 @@ import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
from src.config import cifar_cfg, imagenet_cfg


def create_dataset_cifar10(data_home, repeat_num=1, training=True):
def create_dataset_cifar10(data_home, repeat_num=1, training=True, cifar_cfg=None):
    """Data operations."""
    data_dir = os.path.join(data_home, "cifar-10-batches-bin")
    if not training:
        data_dir = os.path.join(data_home, "cifar-10-verify-bin")

    rank_size, rank_id = _get_rank_info()
    if training:
        rank_size, rank_id = _get_rank_info()
        data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True)
    else:
        data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False)
        data_set = ds.Cifar10Dataset(data_dir, shuffle=False)

    resize_height = cifar_cfg.image_height
    resize_width = cifar_cfg.image_width

@@ -67,7 +65,7 @@ def create_dataset_cifar10(data_home, repeat_num=1, training=True):


def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
                            num_parallel_workers=None, shuffle=None):
                            num_parallel_workers=None, shuffle=None, imagenet_cfg=None):
    """
    create a train or eval imagenet2012 dataset for resnet50

@@ -81,14 +79,15 @@ def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
    Returns:
        dataset
    """

    device_num, rank_id = _get_rank_info()

    if device_num == 1:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
    if training:
        device_num, rank_id = _get_rank_info()
        if device_num == 1:
            data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
        else:
            data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
                                             num_shards=device_num, shard_id=rank_id)
    else:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
                                         num_shards=device_num, shard_id=rank_id)
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)

    assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
    image_size = imagenet_cfg.image_height

@@ -16,15 +16,15 @@
#################train googlenet example on cifar10########################
python train.py
"""
import argparse
import os
import time

import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager

@@ -33,11 +33,14 @@ from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed

from src.config import cifar_cfg, imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
from src.googlenet import GoogleNet
from src.CrossEntropySmooth import CrossEntropySmooth

from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num

set_seed(1)

def lr_steps_cifar10(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):

@@ -86,53 +89,107 @@ def lr_steps_imagenet(_cfg, steps_per_epoch):
    return _lr


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Classification')
    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
                        help='dataset name.')
    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
    args_opt = parser.parse_args()

    if args_opt.dataset_name == "cifar10":
        cfg = cifar_cfg
    elif args_opt.dataset_name == "imagenet":
        cfg = imagenet_cfg
    else:
        raise ValueError("Unsupported dataset.")

    # set context
    device_target = cfg.device_target

    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
    device_num = int(os.environ.get("DEVICE_NUM", 1))

    rank = 0
    if device_target == "Ascend":
        if args_opt.device_id is not None:
            context.set_context(device_id=args_opt.device_id)
def get_param_groups(network):
    """ get param groups """
    decay_params = []
    no_decay_params = []
    for x in network.trainable_params():
        parameter_name = x.name
        if parameter_name.endswith('.bias'):
            # all bias parameters do not use weight decay
            no_decay_params.append(x)
        elif parameter_name.endswith('.gamma'):
            # BN gamma does not use weight decay; be careful, x may not include BN here
            no_decay_params.append(x)
        elif parameter_name.endswith('.beta'):
            # BN beta does not use weight decay; be careful, x may not include BN here
            no_decay_params.append(x)
        else:
            context.set_context(device_id=cfg.device_id)
            decay_params.append(x)

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]


def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not a zip file.")
        else:
            print("Zip has been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains at most 8 devices.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))

    config.ckpt_save_dir = os.path.join(config.output_path, config.ckpt_save_dir)


@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
    cfg = config
    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
    device_num = get_device_num()

    if cfg.device_target == "Ascend":
        context.set_context(device_id=get_device_id())
        if device_num > 1:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            init()
            rank = get_rank()
    elif device_target == "GPU":
    elif cfg.device_target == "GPU":
        if device_num > 1:
            init()
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            rank = get_rank()
    else:
        raise ValueError("Unsupported platform.")

    if args_opt.dataset_name == "cifar10":
        dataset = create_dataset_cifar10(cfg.data_path, 1)
    elif args_opt.dataset_name == "imagenet":
        dataset = create_dataset_imagenet(cfg.data_path, 1)
    if cfg.dataset_name == "cifar10":
        dataset = create_dataset_cifar10(cfg.train_data_path, 1, cifar_cfg=cfg)
    elif cfg.dataset_name == "imagenet":
        dataset = create_dataset_imagenet(cfg.train_data_path, 1, imagenet_cfg=cfg)
    else:
        raise ValueError("Unsupported dataset.")

@@ -145,7 +202,7 @@ if __name__ == '__main__':
    load_param_into_net(net, param_dict)

    loss_scale_manager = None
    if args_opt.dataset_name == 'cifar10':
    if cfg.dataset_name == 'cifar10':
        lr = lr_steps_cifar10(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                       learning_rate=Tensor(lr),

@@ -153,31 +210,9 @@ if __name__ == '__main__':
                       weight_decay=cfg.weight_decay)
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    elif args_opt.dataset_name == 'imagenet':
    elif cfg.dataset_name == 'imagenet':
        lr = lr_steps_imagenet(cfg, batch_num)


        def get_param_groups(network):
            """ get param groups """
            decay_params = []
            no_decay_params = []
            for x in network.trainable_params():
                parameter_name = x.name
                if parameter_name.endswith('.bias'):
                    # all bias parameters do not use weight decay
                    no_decay_params.append(x)
                elif parameter_name.endswith('.gamma'):
                    # BN gamma does not use weight decay; be careful, x may not include BN here
                    no_decay_params.append(x)
                elif parameter_name.endswith('.beta'):
                    # BN beta does not use weight decay; be careful, x may not include BN here
                    no_decay_params.append(x)
                else:
                    decay_params.append(x)

            return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]


        if cfg.is_dynamic_loss_scale:
            cfg.loss_scale = 1

@@ -201,13 +236,17 @@ if __name__ == '__main__':

    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
    time_cb = TimeMonitor(data_size=batch_num)
    ckpt_save_dir = "./ckpt_" + str(rank) + "/"
    ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir,
    ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + cfg.dataset_name, directory=config.ckpt_save_dir,
                                 config=config_ck)
    loss_cb = LossMonitor()

    cbs = [time_cb, ckpoint_cb, loss_cb]
    if device_num > 1 and rank != 0:
    device_id = get_device_id()
    if device_num > 1 and device_id != 0:
        cbs = [time_cb, loss_cb]
    model.train(cfg.epoch_size, dataset, callbacks=cbs)
    print("train success")


if __name__ == '__main__':
    run_train()
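
For context, the `get_param_groups` helper above splits the network parameters so that biases and BN gamma/beta skip weight decay; a hedged sketch of the wiring (the exact optimizer call lives in a part of train.py not shown in these hunks):

```python
# Assumed wiring; Momentum, Tensor, lr, net, and cfg come from train.py above.
opt = Momentum(params=get_param_groups(net),
               learning_rate=Tensor(lr),
               momentum=cfg.momentum,
               weight_decay=cfg.weight_decay,
               loss_scale=cfg.loss_scale)
```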