!15367 Add model scaffolding.

From: @c_34
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-04-23 16:13:05 +08:00 committed by Gitee
commit 1c44e367e0
16 changed files with 925 additions and 0 deletions

View File

@ -0,0 +1,19 @@
# Model Scaffolding
## Introduction
This is a scaffolding framework for model development.
### Framework
TBD
### Desserts
TBD
## Usage
TBD

View File

@ -0,0 +1,48 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: Ascend
enable_profiling: False
# ==============================================================================
# Training options
epoch_size: 30
save_checkpoint: True
keep_checkpoint_max: 10
save_checkpoint_epochs: 5
save_checkpoint_steps: -1
ckpt_path: ""
lr: 0.1
momentum: 0.9
batch_size: 32
buffer_size: 1000
dataset_name: imagenet
# Model Description
model_name: lenet_running_8p
image_height: 32
image_width: 32
num_classes: 10
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
---
device_target: ['Ascend', 'GPU', 'CPU']

View File

@ -0,0 +1,48 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: Ascend
enable_profiling: False
# ==============================================================================
# Training options
epoch_size: 10
save_checkpoint: True
keep_checkpoint_max: 10
save_checkpoint_epochs: 2
save_checkpoint_steps: 1875
ckpt_path: ""
lr: 0.01
momentum: 0.9
batch_size: 32
buffer_size: 1000
dataset_name: mnist
# Model Description
model_name: lenet
image_height: 32
image_width: 32
num_classes: 10
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
---
device_target: ['Ascend', 'GPU', 'CPU']

View File

@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation process"""
import os
from mindspore import nn
from mindspore import context
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint
from src.moxing_adapter import moxing_wrapper
from src.config import config
from src.dataset import create_lenet_dataset
from src.foo import LeNet5
@moxing_wrapper()
def eval_lenet5():
"""Evaluation of lenet5"""
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
network = LeNet5(config.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), config.lr, config.momentum)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
load_checkpoint(config.ckpt_path, network)
ds_eval = create_lenet_dataset(os.path.join(config.data_path, "test"), config.batch_size, 1)
if ds_eval.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
acc = model.eval(ds_eval)
print("============== {} ==============".format(acc))
if __name__ == '__main__':
eval_lenet5()

View File

@ -0,0 +1,188 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""distribute running script"""
import os
import json
import multiprocessing
from argparse import ArgumentParser
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="Distributed training scripts generator for MindSpore")
parser.add_argument("--run_script_path", type=str, default="",
help="Run script path, it is better to use absolute path")
parser.add_argument("--args", type=str, default="",
help="Other arguments which will be passed to main program directly")
parser.add_argument("--hccl_config_dir", type=str, default="",
help="Hccl config path, it is better to use absolute path")
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
help="Path of the generated cmd file.")
parser.add_argument("--hccl_time_out", type=int, default=120,
help="Seconds to determine the hccl time out,"
"default: 120, which is the same as hccl default config")
parser.add_argument("--cpu_bind", action="store_true", default=False,
help="Bind cpu cores or not")
args = parser.parse_args()
return args
def append_cmd(cmd, s):
cmd += s
cmd += "\n"
return cmd
def append_cmd_env(cmd, key, value):
return append_cmd(cmd, "export " + str(key) + "=" + str(value))
def set_envs(cmd, logic_id, rank_id):
"""
Set environment variables.
"""
cmd = append_cmd_env(cmd, "DEVICE_ID", str(logic_id))
cmd = append_cmd_env(cmd, "RANK_ID", str(rank_id))
return cmd
def make_dirs(cmd, logic_id):
"""
Make directories and change path.
"""
cmd = append_cmd(cmd, "rm -rf LOG" + str(logic_id))
cmd = append_cmd(cmd, "mkdir ./LOG" + str(logic_id))
cmd = append_cmd(cmd, "mkdir -p ./LOG" + str(logic_id) + "/ms_log")
cmd = append_cmd(cmd, "env > ./LOG" + str(logic_id) + "/env.log")
cur_dir = os.getcwd()
cmd = append_cmd_env(cmd, "GLOG_log_dir", cur_dir + "/LOG" + str(logic_id) + "/ms_log")
cmd = append_cmd_env(cmd, "GLOG_logtostderr", "0")
cmd = append_cmd(cmd, "cd " + cur_dir + "/LOG" + str(logic_id))
return cmd
def print_info(rank_id, device_id, logic_id, cmdopt, cur_dir):
"""
Print some information about scripts.
"""
print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":")
print("rank_id:", rank_id)
print("device_id:", device_id)
print("logic_id", logic_id)
print("core_nums:", cmdopt)
print("log_file_dir: " + cur_dir + "/LOG" + str(logic_id) + "/pretraining_log.txt")
def distribute_run():
"""
distribute pretrain scripts. The number of Ascend accelerators can be automatically allocated
based on the device_num set in hccl config file, You don not need to specify that.
"""
cmd = ""
print("start", __file__)
args = parse_args()
run_script = args.run_script_path
print("hccl_config_dir:", args.hccl_config_dir)
print("hccl_time_out:", args.hccl_time_out)
cmd = append_cmd_env(cmd, 'HCCL_CONNECT_TIMEOUT', args.hccl_time_out)
cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir)
cores = multiprocessing.cpu_count()
print("the number of logical core:", cores)
# get device_ips
device_ips = {}
physic_logic_ids = {}
with open('/etc/hccn.conf', 'r') as fin:
for hccn_item in fin.readlines():
if hccn_item.strip().startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip.strip()
if not device_ips:
raise ValueError("There is no address in /etc/hccn.conf")
for logic_id, device_id in enumerate(sorted(device_ips.keys())):
physic_logic_ids[device_id] = logic_id
with open(args.hccl_config_dir, "r", encoding="utf-8") as fin:
hccl_config = json.loads(fin.read())
rank_size = 0
for server in hccl_config["server_list"]:
rank_size += len(server["device"])
if server["device"][0]["device_ip"] in device_ips.values():
this_server = server
cmd = append_cmd_env(cmd, "RANK_SIZE", str(rank_size))
print("total rank size:", rank_size)
print("this server rank size:", len(this_server["device"]))
avg_core_per_rank = int(int(cores) / len(this_server["device"]))
core_gap = avg_core_per_rank - 1
print("avg_core_per_rank:", avg_core_per_rank)
count = 0
for instance in this_server["device"]:
# device_id is the physical id, we use logic id to specific the selected device.
# While running on a server with 8 pcs, the logic ids are equal to the device ids.
device_id = instance["device_id"]
rank_id = instance["rank_id"]
logic_id = physic_logic_ids[device_id]
start = count * int(avg_core_per_rank)
count += 1
end = start + core_gap
cmdopt = str(start) + "-" + str(end)
cur_dir = os.getcwd()
cmd = set_envs(cmd, logic_id, rank_id)
cmd = make_dirs(cmd, logic_id)
print_info(rank_id=rank_id, device_id=device_id, logic_id=logic_id, cmdopt=cmdopt, cur_dir=cur_dir)
if args.cpu_bind:
run_cmd = 'taskset -c ' + cmdopt + ' '
else:
run_cmd = ""
run_cmd += 'nohup python ' + run_script + " "
run_cmd += " " + ' '.join([str(x) for x in args.args.split(' ')[1:]])
run_cmd += ' >./log.txt 2>&1 &'
cmd = append_cmd(cmd, run_cmd)
cmd = append_cmd(cmd, "cd -")
cmd = append_cmd(cmd, "echo \"run with" +
" rank_id=" + str(rank_id) +
" device_id=" + str(device_id) +
" logic_id=" + str(logic_id) + "\"")
cmd += "\n"
with open(args.cmd_file, "w") as f:
f.write(cmd)
if __name__ == "__main__":
distribute_run()

View File

@ -0,0 +1,35 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
rm -rf EVAL
mkdir EVAL
cd EVAL
if [ $# -ge 1 ]; then
if [ $1 == 'imagenet' ]; then
CONFIG_FILE="${BASE_PATH}/../config_imagenet.yaml"
elif [ $1 == 'default' ]; then
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
else
echo "Unrecognized parameter"
exit 1
fi
else
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
fi
python ${BASE_PATH}/../eval.py --config_path=$CONFIG_FILE

View File

@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 1 ]; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_local_train.sh RANK_TABLE_FILE [OTHER_ARGS]"
echo "OTHER_ARGS will be passed to the training scripts directly,"
echo "for example: bash run_local_train.sh /path/hccl.json --data_dir=/path/data_dir --epochs=40"
echo "It is better to use absolute path."
echo "=============================================================================================================="
exit 1
fi
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
OTHER_ARGS=$*
echo ${OTHER_ARGS[0]}
python3 ${BASE_PATH}/ascend_distributed_launcher/get_distribute_pretrain_cmd.py \
--run_script_path=${BASE_PATH}/../train.py \
--hccl_config_dir=$1 \
--hccl_time_out=600 \
--args="$*" \
--cmd_file=distributed_cmd.sh
bash distributed_cmd.sh

View File

@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()

View File

@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset"""
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common import dtype as mstype
def create_lenet_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path, num_parallel_workers=num_parallel_workers)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,64 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model Structure"""
import mindspore.nn as nn
from mindspore.common.initializer import Normal
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Number of classes. Default: 10.
num_channel (int): Number of channels. Default: 1.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10, num_channel=1, include_top=True):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.include_top = include_top
if self.include_top:
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
def construct(self, x):
"""construct lenet5"""
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
if not self.include_top:
return x
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -0,0 +1,63 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training process"""
import os
from mindspore import nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.moxing_adapter import moxing_wrapper
from src.config import config
from src.dataset import create_lenet_dataset
from src.foo import LeNet5
@moxing_wrapper()
def train_lenet5():
"""
Train lenet5
"""
config.ckpt_path = config.output_path
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
ds_train = create_lenet_dataset(os.path.join(config.data_path, "train"), config.batch_size, num_parallel_workers=1)
if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
network = LeNet5(config.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), config.lr, config.momentum)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory=None if config.ckpt_path == "" else config.ckpt_path,
config=config_ck)
if config.device_target != "Ascend":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
else:
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
print("============== Starting Training ==============")
model.train(config.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
if __name__ == '__main__':
train_lenet5()