Fix issue: #I3ZBQD. Port vgg19/resnetv2 to MA.

This commit is contained in:
sora-city 2021-08-15 10:28:34 +08:00
parent 962c3f4ae3
commit de43ed5420
14 changed files with 953 additions and 0 deletions

View File

@ -0,0 +1,136 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../cifar10_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
# --------------------------------------------------------------------------------------------------------------------
def get_config_static(config_path="../cifar10_config.yaml"):
"""
Get Config according to the yaml file and cli arguments.
"""
if not config_path.startswith("/"):
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, config_path)
final_config, _, _ = parse_yaml(config_path)
return Config(final_config)

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .moxing_adapter import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .config import get_config
config = get_config()
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
# Run the main function
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -0,0 +1,105 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: True
modelarts_dataset_unzip_name: "cifar10"
# ==============================================================================
# options
num_classes: 10
lr: 0.01
lr_init: 0.01
lr_max: 0.1
lr_epochs: '30,60,90,120'
lr_scheduler: "step"
warmup_epochs: 5
batch_size: 64
max_epoch: 70
momentum: 0.9
weight_decay: 0.0005 # 5e-4
loss_scale: 1.0
label_smooth: 0
label_smooth_factor: 0
buffer_size: 10
image_size: '224,224'
pad_mode: 'same'
padding: 0
has_bias: False
batch_norm: True
keep_checkpoint_max: 10
initialize_mode: "XavierUniform"
has_dropout: False
# train options
dataset: "cifar10"
data_dir: ""
pre_trained: ""
lr_gamma: 0.1
eta_min: 0.0
T_max: 90
log_interval: 100
ckpt_path: "outputs/"
ckpt_interval: 5
is_save_on_master: 1
is_distributed: 0
# eval options
per_batch_size: 32
graph_ckpt: 1
log_path: "outputs/"
# postprocess options
result_dir: ""
label_dir: ""
dataset_name: "cifar10"
# preprocess options
result_path: "./preprocess_Result/"
# export options
ckpt_file: ""
file_name: "vgg19"
file_format: "AIR"
---
# Help description for each configuration
device_target: "device where the code will be implemented."
dataset: "choices in ['cifar10', 'imagenet2012']"
data_dir: "data dir"
pre_trained: "model_path, local pretrained model to load"
lr_gamma: "decrease lr by a factor of exponential lr_scheduler"
eta_min: "eta_min in cosine_annealing scheduler"
T_max: "T-max in cosine_annealing scheduler"
log_interval: "logging interval"
ckpt_path: "checkpoint save location"
ckpt_interval: "ckpt_interval"
is_save_on_master: "save ckpt on master or all rank"
is_distributed: "if multi device"
# eval options
per_batch_size: "batch size for per npu"
graph_ckpt: "graph ckpt or feed ckpt"
log_path: "path to save log"
# postprocess options
result_dir: "result files path."
label_dir: "image file path."
dataset_name: "choices in ['cifar10', 'imagenet2012']"
# preprocess options
result_path: "result path"
# export options
ckpt_file: "vgg19 ckpt file."
file_name: "vgg19 output file name."
file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']"

View File

@ -0,0 +1,104 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: True
modelarts_dataset_unzip_name: "ImageNet"
# ==============================================================================
# options
num_classes: 1000
lr: 0.04
lr_init: 0.01
lr_max: 0.1
lr_epochs: '30,60,90,120'
lr_scheduler: 'cosine_annealing'
warmup_epochs: 0
batch_size: 64
max_epoch: 90
momentum: 0.9
weight_decay: 0.0001 # 1e-4
loss_scale: 1024
label_smooth: 1
label_smooth_factor: 0.1
buffer_size: 10
image_size: '224,224'
pad_mode: 'pad'
padding: 1
has_bias: False
batch_norm: False
keep_checkpoint_max: 10
initialize_mode: "KaimingNormal"
has_dropout: True
# train option
dataset: "imagenet2012"
data_dir: ""
pre_trained: ""
lr_gamma: 0.1
eta_min: 0.0
T_max: 90
log_interval: 100
ckpt_path: "outputs/"
ckpt_interval: 5
is_save_on_master: 1
is_distributed: 0
# eval options
per_batch_size: 32
graph_ckpt: 1
log_path: "outputs/"
# postprocess options
result_dir: ""
label_dir: ""
dataset_name: "imagenet2012"
# preprocess options
result_path: "./preprocess_Result/"
# export options
ckpt_file: ""
file_name: "vgg19"
file_format: "AIR"
---
# Help description for each configuration
device_target: "device where the code will be implemented."
dataset: "choices in ['cifar10', 'imagenet2012']"
data_dir: "data dir"
pre_trained: "model_path, local pretrained model to load"
lr_gamma: "decrease lr by a factor of exponential lr_scheduler"
eta_min: "eta_min in cosine_annealing scheduler"
T_max: "T-max in cosine_annealing scheduler"
log_interval: "logging interval"
ckpt_path: "checkpoint save location"
ckpt_interval: "ckpt_interval"
is_save_on_master: "save ckpt on master or all rank"
is_distributed: "if multi device"
# eval options
per_batch_size: "batch size for per npu"
graph_ckpt: "graph ckpt or feed ckpt"
log_path: "path to save log"
# postprocess options
result_dir: "result files path."
label_dir: "image file path."
dataset_name: "choices in ['cifar10', 'imagenet2012']"
# preprocess options
result_path: "result path"
# export options
ckpt_file: "vgg19 ckpt file."
file_name: "vgg19 output file name."
file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']"

View File

@ -0,0 +1,136 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../cifar10_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
# --------------------------------------------------------------------------------------------------------------------
def get_config_static(config_path="../cifar10_config.yaml"):
"""
Get Config according to the yaml file and cli arguments.
"""
if not config_path.startswith("/"):
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, config_path)
final_config, _, _ = parse_yaml(config_path)
return Config(final_config)

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .moxing_adapter import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .config import get_config
config = get_config()
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
# Run the main function
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""postprocess for 310 inference"""
import os
import json
import numpy as np
from mindspore.nn import Top1CategoricalAccuracy, Top5CategoricalAccuracy
from model_utils.moxing_adapter import config
if __name__ == '__main__':
top1_acc = Top1CategoricalAccuracy()
rst_path = config.result_dir
if config.dataset_name == "cifar10":
labels = np.load(config.label_dir, allow_pickle=True)
for idx, label in enumerate(labels):
f_name = os.path.join(rst_path, "VGG19_data_bs" + str(config.batch_size) + "_" + str(idx) + "_0.bin")
pred = np.fromfile(f_name, np.float32)
pred = pred.reshape(config.batch_size, int(pred.shape[0] / config.batch_size))
top1_acc.update(pred, labels[idx])
print("acc: ", top1_acc.eval())
else:
top5_acc = Top5CategoricalAccuracy()
file_list = os.listdir(rst_path)
with open(config.label_dir, "r") as label:
labels = json.load(label)
for f in file_list:
label = f.split("_0.bin")[0] + ".JPEG"
pred = np.fromfile(os.path.join(rst_path, f), np.float32)
pred = pred.reshape(config.batch_size, int(pred.shape[0] / config.batch_size))
top1_acc.update(pred, [labels[label],])
top5_acc.update(pred, [labels[label],])
print("Top1 acc: ", top1_acc.eval())
print("Top5 acc: ", top5_acc.eval())

View File

@ -0,0 +1,64 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""preprocess"""
import os
import json
import numpy as np
from src.dataset import vgg_create_dataset
from model_utils.moxing_adapter import config
def create_label(result_path, dir_path):
print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
dirs = os.listdir(dir_path)
file_list = []
for file in dirs:
file_list.append(file)
file_list = sorted(file_list)
total = 0
img_label = {}
for i, file_dir in enumerate(file_list):
files = os.listdir(os.path.join(dir_path, file_dir))
for f in files:
img_label[f] = i
total += len(files)
json_file = os.path.join(result_path, "imagenet_label.json")
with open(json_file, "w+") as label:
json.dump(img_label, label)
print("[INFO] Completed! Total {} data.".format(total))
config.per_batch_size = config.batch_size
config.image_size = list(map(int, config.image_size.split(',')))
if __name__ == "__main__":
if config.dataset == "cifar10":
dataset = vgg_create_dataset(config.data_dir, config.image_size, config.per_batch_size, training=False)
img_path = os.path.join(config.result_path, "00_data")
os.makedirs(img_path)
label_list = []
for idx, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
file_name = "VGG19_data_bs" + str(config.per_batch_size) + "_" + str(idx) + ".bin"
file_path = os.path.join(img_path, file_name)
data["image"].tofile(file_path)
label_list.append(data["label"])
np.save(os.path.join(config.result_path, "cifar10_label_ids.npy"), label_list)
print("=" * 20, "export bin files finished", "=" * 20)
else:
create_label(config.result_path, config.data_dir)