Fix issue: #I3ZBQD. Port vgg19/resnetv2 to MA.
This commit is contained in:
parent
962c3f4ae3
commit
de43ed5420
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../cifar10_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def get_config_static(config_path="../cifar10_config.yaml"):
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
if not config_path.startswith("/"):
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_path = os.path.join(current_dir, config_path)
|
||||
final_config, _, _ = parse_yaml(config_path)
|
||||
return Config(final_config)
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .moxing_adapter import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import get_config
|
||||
|
||||
config = get_config()
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
# Run the main function
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -0,0 +1,105 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
need_modelarts_dataset_unzip: True
|
||||
modelarts_dataset_unzip_name: "cifar10"
|
||||
|
||||
# ==============================================================================
|
||||
# options
|
||||
num_classes: 10
|
||||
lr: 0.01
|
||||
lr_init: 0.01
|
||||
lr_max: 0.1
|
||||
lr_epochs: '30,60,90,120'
|
||||
lr_scheduler: "step"
|
||||
warmup_epochs: 5
|
||||
batch_size: 64
|
||||
max_epoch: 70
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0005 # 5e-4
|
||||
loss_scale: 1.0
|
||||
label_smooth: 0
|
||||
label_smooth_factor: 0
|
||||
buffer_size: 10
|
||||
image_size: '224,224'
|
||||
pad_mode: 'same'
|
||||
padding: 0
|
||||
has_bias: False
|
||||
batch_norm: True
|
||||
keep_checkpoint_max: 10
|
||||
initialize_mode: "XavierUniform"
|
||||
has_dropout: False
|
||||
|
||||
# train options
|
||||
dataset: "cifar10"
|
||||
data_dir: ""
|
||||
pre_trained: ""
|
||||
lr_gamma: 0.1
|
||||
eta_min: 0.0
|
||||
T_max: 90
|
||||
log_interval: 100
|
||||
ckpt_path: "outputs/"
|
||||
ckpt_interval: 5
|
||||
is_save_on_master: 1
|
||||
is_distributed: 0
|
||||
|
||||
# eval options
|
||||
per_batch_size: 32
|
||||
graph_ckpt: 1
|
||||
log_path: "outputs/"
|
||||
|
||||
# postprocess options
|
||||
result_dir: ""
|
||||
label_dir: ""
|
||||
dataset_name: "cifar10"
|
||||
|
||||
# preprocess options
|
||||
result_path: "./preprocess_Result/"
|
||||
|
||||
# export options
|
||||
ckpt_file: ""
|
||||
file_name: "vgg19"
|
||||
file_format: "AIR"
|
||||
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
device_target: "device where the code will be implemented."
|
||||
dataset: "choices in ['cifar10', 'imagenet2012']"
|
||||
data_dir: "data dir"
|
||||
pre_trained: "model_path, local pretrained model to load"
|
||||
lr_gamma: "decrease lr by a factor of exponential lr_scheduler"
|
||||
eta_min: "eta_min in cosine_annealing scheduler"
|
||||
T_max: "T-max in cosine_annealing scheduler"
|
||||
log_interval: "logging interval"
|
||||
ckpt_path: "checkpoint save location"
|
||||
ckpt_interval: "ckpt_interval"
|
||||
is_save_on_master: "save ckpt on master or all rank"
|
||||
is_distributed: "if multi device"
|
||||
|
||||
# eval options
|
||||
per_batch_size: "batch size for per npu"
|
||||
graph_ckpt: "graph ckpt or feed ckpt"
|
||||
log_path: "path to save log"
|
||||
|
||||
# postprocess options
|
||||
result_dir: "result files path."
|
||||
label_dir: "image file path."
|
||||
dataset_name: "choices in ['cifar10', 'imagenet2012']"
|
||||
|
||||
# preprocess options
|
||||
result_path: "result path"
|
||||
|
||||
# export options
|
||||
ckpt_file: "vgg19 ckpt file."
|
||||
file_name: "vgg19 output file name."
|
||||
file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']"
|
|
@ -0,0 +1,104 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
need_modelarts_dataset_unzip: True
|
||||
modelarts_dataset_unzip_name: "ImageNet"
|
||||
|
||||
# ==============================================================================
|
||||
# options
|
||||
num_classes: 1000
|
||||
lr: 0.04
|
||||
lr_init: 0.01
|
||||
lr_max: 0.1
|
||||
lr_epochs: '30,60,90,120'
|
||||
lr_scheduler: 'cosine_annealing'
|
||||
warmup_epochs: 0
|
||||
batch_size: 64
|
||||
max_epoch: 90
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
loss_scale: 1024
|
||||
label_smooth: 1
|
||||
label_smooth_factor: 0.1
|
||||
buffer_size: 10
|
||||
image_size: '224,224'
|
||||
pad_mode: 'pad'
|
||||
padding: 1
|
||||
has_bias: False
|
||||
batch_norm: False
|
||||
keep_checkpoint_max: 10
|
||||
initialize_mode: "KaimingNormal"
|
||||
has_dropout: True
|
||||
|
||||
# train option
|
||||
dataset: "imagenet2012"
|
||||
data_dir: ""
|
||||
pre_trained: ""
|
||||
lr_gamma: 0.1
|
||||
eta_min: 0.0
|
||||
T_max: 90
|
||||
log_interval: 100
|
||||
ckpt_path: "outputs/"
|
||||
ckpt_interval: 5
|
||||
is_save_on_master: 1
|
||||
is_distributed: 0
|
||||
|
||||
# eval options
|
||||
per_batch_size: 32
|
||||
graph_ckpt: 1
|
||||
log_path: "outputs/"
|
||||
|
||||
# postprocess options
|
||||
result_dir: ""
|
||||
label_dir: ""
|
||||
dataset_name: "imagenet2012"
|
||||
|
||||
# preprocess options
|
||||
result_path: "./preprocess_Result/"
|
||||
|
||||
# export options
|
||||
ckpt_file: ""
|
||||
file_name: "vgg19"
|
||||
file_format: "AIR"
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
device_target: "device where the code will be implemented."
|
||||
dataset: "choices in ['cifar10', 'imagenet2012']"
|
||||
data_dir: "data dir"
|
||||
pre_trained: "model_path, local pretrained model to load"
|
||||
lr_gamma: "decrease lr by a factor of exponential lr_scheduler"
|
||||
eta_min: "eta_min in cosine_annealing scheduler"
|
||||
T_max: "T-max in cosine_annealing scheduler"
|
||||
log_interval: "logging interval"
|
||||
ckpt_path: "checkpoint save location"
|
||||
ckpt_interval: "ckpt_interval"
|
||||
is_save_on_master: "save ckpt on master or all rank"
|
||||
is_distributed: "if multi device"
|
||||
|
||||
# eval options
|
||||
per_batch_size: "batch size for per npu"
|
||||
graph_ckpt: "graph ckpt or feed ckpt"
|
||||
log_path: "path to save log"
|
||||
|
||||
# postprocess options
|
||||
result_dir: "result files path."
|
||||
label_dir: "image file path."
|
||||
dataset_name: "choices in ['cifar10', 'imagenet2012']"
|
||||
|
||||
# preprocess options
|
||||
result_path: "result path"
|
||||
|
||||
# export options
|
||||
ckpt_file: "vgg19 ckpt file."
|
||||
file_name: "vgg19 output file name."
|
||||
file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']"
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../cifar10_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def get_config_static(config_path="../cifar10_config.yaml"):
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
if not config_path.startswith("/"):
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_path = os.path.join(current_dir, config_path)
|
||||
final_config, _, _ = parse_yaml(config_path)
|
||||
return Config(final_config)
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .moxing_adapter import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import get_config
|
||||
|
||||
config = get_config()
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
# Run the main function
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""postprocess for 310 inference"""
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
from mindspore.nn import Top1CategoricalAccuracy, Top5CategoricalAccuracy
|
||||
|
||||
from model_utils.moxing_adapter import config
|
||||
|
||||
if __name__ == '__main__':
|
||||
top1_acc = Top1CategoricalAccuracy()
|
||||
rst_path = config.result_dir
|
||||
if config.dataset_name == "cifar10":
|
||||
labels = np.load(config.label_dir, allow_pickle=True)
|
||||
for idx, label in enumerate(labels):
|
||||
f_name = os.path.join(rst_path, "VGG19_data_bs" + str(config.batch_size) + "_" + str(idx) + "_0.bin")
|
||||
pred = np.fromfile(f_name, np.float32)
|
||||
pred = pred.reshape(config.batch_size, int(pred.shape[0] / config.batch_size))
|
||||
top1_acc.update(pred, labels[idx])
|
||||
print("acc: ", top1_acc.eval())
|
||||
else:
|
||||
top5_acc = Top5CategoricalAccuracy()
|
||||
file_list = os.listdir(rst_path)
|
||||
with open(config.label_dir, "r") as label:
|
||||
labels = json.load(label)
|
||||
for f in file_list:
|
||||
label = f.split("_0.bin")[0] + ".JPEG"
|
||||
pred = np.fromfile(os.path.join(rst_path, f), np.float32)
|
||||
pred = pred.reshape(config.batch_size, int(pred.shape[0] / config.batch_size))
|
||||
top1_acc.update(pred, [labels[label],])
|
||||
top5_acc.update(pred, [labels[label],])
|
||||
print("Top1 acc: ", top1_acc.eval())
|
||||
print("Top5 acc: ", top5_acc.eval())
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""preprocess"""
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
from src.dataset import vgg_create_dataset
|
||||
|
||||
from model_utils.moxing_adapter import config
|
||||
|
||||
|
||||
def create_label(result_path, dir_path):
|
||||
print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
|
||||
dirs = os.listdir(dir_path)
|
||||
file_list = []
|
||||
for file in dirs:
|
||||
file_list.append(file)
|
||||
file_list = sorted(file_list)
|
||||
|
||||
total = 0
|
||||
img_label = {}
|
||||
for i, file_dir in enumerate(file_list):
|
||||
files = os.listdir(os.path.join(dir_path, file_dir))
|
||||
for f in files:
|
||||
img_label[f] = i
|
||||
total += len(files)
|
||||
|
||||
json_file = os.path.join(result_path, "imagenet_label.json")
|
||||
with open(json_file, "w+") as label:
|
||||
json.dump(img_label, label)
|
||||
|
||||
print("[INFO] Completed! Total {} data.".format(total))
|
||||
|
||||
config.per_batch_size = config.batch_size
|
||||
config.image_size = list(map(int, config.image_size.split(',')))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if config.dataset == "cifar10":
|
||||
dataset = vgg_create_dataset(config.data_dir, config.image_size, config.per_batch_size, training=False)
|
||||
img_path = os.path.join(config.result_path, "00_data")
|
||||
os.makedirs(img_path)
|
||||
label_list = []
|
||||
for idx, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
|
||||
file_name = "VGG19_data_bs" + str(config.per_batch_size) + "_" + str(idx) + ".bin"
|
||||
file_path = os.path.join(img_path, file_name)
|
||||
data["image"].tofile(file_path)
|
||||
label_list.append(data["label"])
|
||||
np.save(os.path.join(config.result_path, "cifar10_label_ids.npy"), label_list)
|
||||
print("=" * 20, "export bin files finished", "=" * 20)
|
||||
else:
|
||||
create_label(config.result_path, config.data_dir)
|
Loading…
Reference in New Issue