commit
1c44e367e0
|
@ -0,0 +1,19 @@
|
|||
# Model Scaffolding
|
||||
|
||||
## Introduction
|
||||
|
||||
This is a scaffolding framework for model development.
|
||||
|
||||
### Framework
|
||||
|
||||
TBD
|
||||
|
||||
### Desserts
|
||||
|
||||
TBD
|
||||
|
||||
## Usage
|
||||
|
||||
TBD
|
||||
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: Ascend
|
||||
enable_profiling: False
|
||||
|
||||
# ==============================================================================
|
||||
# Training options
|
||||
epoch_size: 30
|
||||
save_checkpoint: True
|
||||
keep_checkpoint_max: 10
|
||||
save_checkpoint_epochs: 5
|
||||
save_checkpoint_steps: -1
|
||||
ckpt_path: ""
|
||||
lr: 0.1
|
||||
momentum: 0.9
|
||||
batch_size: 32
|
||||
buffer_size: 1000
|
||||
dataset_name: imagenet
|
||||
|
||||
# Model Description
|
||||
model_name: lenet_running_8p
|
||||
image_height: 32
|
||||
image_width: 32
|
||||
num_classes: 10
|
||||
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
|
||||
device_target: 'Target device type'
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
|
||||
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
|
@ -0,0 +1,48 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: Ascend
|
||||
enable_profiling: False
|
||||
|
||||
# ==============================================================================
|
||||
# Training options
|
||||
epoch_size: 10
|
||||
save_checkpoint: True
|
||||
keep_checkpoint_max: 10
|
||||
save_checkpoint_epochs: 2
|
||||
save_checkpoint_steps: 1875
|
||||
ckpt_path: ""
|
||||
lr: 0.01
|
||||
momentum: 0.9
|
||||
batch_size: 32
|
||||
buffer_size: 1000
|
||||
dataset_name: mnist
|
||||
|
||||
# Model Description
|
||||
model_name: lenet
|
||||
image_height: 32
|
||||
image_width: 32
|
||||
num_classes: 10
|
||||
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
|
||||
device_target: 'Target device type'
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
|
||||
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation process"""
|
||||
|
||||
import os
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore import context
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
|
||||
from src.moxing_adapter import moxing_wrapper
|
||||
from src.config import config
|
||||
from src.dataset import create_lenet_dataset
|
||||
from src.foo import LeNet5
|
||||
|
||||
|
||||
@moxing_wrapper()
|
||||
def eval_lenet5():
|
||||
"""Evaluation of lenet5"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
|
||||
network = LeNet5(config.num_classes)
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
net_opt = nn.Momentum(network.trainable_params(), config.lr, config.momentum)
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
print("============== Starting Testing ==============")
|
||||
load_checkpoint(config.ckpt_path, network)
|
||||
ds_eval = create_lenet_dataset(os.path.join(config.data_path, "test"), config.batch_size, 1)
|
||||
if ds_eval.get_dataset_size() == 0:
|
||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||
|
||||
acc = model.eval(ds_eval)
|
||||
print("============== {} ==============".format(acc))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
eval_lenet5()
|
|
@ -0,0 +1,188 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""distribute running script"""
|
||||
import os
|
||||
import json
|
||||
import multiprocessing
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
parse args .
|
||||
|
||||
Args:
|
||||
|
||||
Returns:
|
||||
args.
|
||||
|
||||
Examples:
|
||||
>>> parse_args()
|
||||
"""
|
||||
parser = ArgumentParser(description="Distributed training scripts generator for MindSpore")
|
||||
|
||||
parser.add_argument("--run_script_path", type=str, default="",
|
||||
help="Run script path, it is better to use absolute path")
|
||||
parser.add_argument("--args", type=str, default="",
|
||||
help="Other arguments which will be passed to main program directly")
|
||||
parser.add_argument("--hccl_config_dir", type=str, default="",
|
||||
help="Hccl config path, it is better to use absolute path")
|
||||
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
|
||||
help="Path of the generated cmd file.")
|
||||
parser.add_argument("--hccl_time_out", type=int, default=120,
|
||||
help="Seconds to determine the hccl time out,"
|
||||
"default: 120, which is the same as hccl default config")
|
||||
parser.add_argument("--cpu_bind", action="store_true", default=False,
|
||||
help="Bind cpu cores or not")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def append_cmd(cmd, s):
|
||||
cmd += s
|
||||
cmd += "\n"
|
||||
return cmd
|
||||
|
||||
|
||||
def append_cmd_env(cmd, key, value):
|
||||
return append_cmd(cmd, "export " + str(key) + "=" + str(value))
|
||||
|
||||
|
||||
def set_envs(cmd, logic_id, rank_id):
|
||||
"""
|
||||
Set environment variables.
|
||||
"""
|
||||
cmd = append_cmd_env(cmd, "DEVICE_ID", str(logic_id))
|
||||
cmd = append_cmd_env(cmd, "RANK_ID", str(rank_id))
|
||||
return cmd
|
||||
|
||||
|
||||
def make_dirs(cmd, logic_id):
|
||||
"""
|
||||
Make directories and change path.
|
||||
"""
|
||||
cmd = append_cmd(cmd, "rm -rf LOG" + str(logic_id))
|
||||
cmd = append_cmd(cmd, "mkdir ./LOG" + str(logic_id))
|
||||
cmd = append_cmd(cmd, "mkdir -p ./LOG" + str(logic_id) + "/ms_log")
|
||||
cmd = append_cmd(cmd, "env > ./LOG" + str(logic_id) + "/env.log")
|
||||
cur_dir = os.getcwd()
|
||||
cmd = append_cmd_env(cmd, "GLOG_log_dir", cur_dir + "/LOG" + str(logic_id) + "/ms_log")
|
||||
cmd = append_cmd_env(cmd, "GLOG_logtostderr", "0")
|
||||
cmd = append_cmd(cmd, "cd " + cur_dir + "/LOG" + str(logic_id))
|
||||
return cmd
|
||||
|
||||
|
||||
def print_info(rank_id, device_id, logic_id, cmdopt, cur_dir):
|
||||
"""
|
||||
Print some information about scripts.
|
||||
"""
|
||||
print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":")
|
||||
print("rank_id:", rank_id)
|
||||
print("device_id:", device_id)
|
||||
print("logic_id", logic_id)
|
||||
print("core_nums:", cmdopt)
|
||||
print("log_file_dir: " + cur_dir + "/LOG" + str(logic_id) + "/pretraining_log.txt")
|
||||
|
||||
def distribute_run():
|
||||
"""
|
||||
distribute pretrain scripts. The number of Ascend accelerators can be automatically allocated
|
||||
based on the device_num set in hccl config file, You don not need to specify that.
|
||||
"""
|
||||
cmd = ""
|
||||
print("start", __file__)
|
||||
args = parse_args()
|
||||
|
||||
run_script = args.run_script_path
|
||||
|
||||
print("hccl_config_dir:", args.hccl_config_dir)
|
||||
print("hccl_time_out:", args.hccl_time_out)
|
||||
cmd = append_cmd_env(cmd, 'HCCL_CONNECT_TIMEOUT', args.hccl_time_out)
|
||||
cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir)
|
||||
|
||||
cores = multiprocessing.cpu_count()
|
||||
print("the number of logical core:", cores)
|
||||
|
||||
# get device_ips
|
||||
device_ips = {}
|
||||
physic_logic_ids = {}
|
||||
with open('/etc/hccn.conf', 'r') as fin:
|
||||
for hccn_item in fin.readlines():
|
||||
if hccn_item.strip().startswith('address_'):
|
||||
device_id, device_ip = hccn_item.split('=')
|
||||
device_id = device_id.split('_')[1]
|
||||
device_ips[device_id] = device_ip.strip()
|
||||
|
||||
if not device_ips:
|
||||
raise ValueError("There is no address in /etc/hccn.conf")
|
||||
|
||||
for logic_id, device_id in enumerate(sorted(device_ips.keys())):
|
||||
physic_logic_ids[device_id] = logic_id
|
||||
|
||||
with open(args.hccl_config_dir, "r", encoding="utf-8") as fin:
|
||||
hccl_config = json.loads(fin.read())
|
||||
rank_size = 0
|
||||
for server in hccl_config["server_list"]:
|
||||
rank_size += len(server["device"])
|
||||
if server["device"][0]["device_ip"] in device_ips.values():
|
||||
this_server = server
|
||||
|
||||
cmd = append_cmd_env(cmd, "RANK_SIZE", str(rank_size))
|
||||
print("total rank size:", rank_size)
|
||||
print("this server rank size:", len(this_server["device"]))
|
||||
avg_core_per_rank = int(int(cores) / len(this_server["device"]))
|
||||
core_gap = avg_core_per_rank - 1
|
||||
print("avg_core_per_rank:", avg_core_per_rank)
|
||||
|
||||
count = 0
|
||||
for instance in this_server["device"]:
|
||||
# device_id is the physical id, we use logic id to specific the selected device.
|
||||
# While running on a server with 8 pcs, the logic ids are equal to the device ids.
|
||||
device_id = instance["device_id"]
|
||||
rank_id = instance["rank_id"]
|
||||
logic_id = physic_logic_ids[device_id]
|
||||
start = count * int(avg_core_per_rank)
|
||||
count += 1
|
||||
end = start + core_gap
|
||||
cmdopt = str(start) + "-" + str(end)
|
||||
cur_dir = os.getcwd()
|
||||
|
||||
cmd = set_envs(cmd, logic_id, rank_id)
|
||||
cmd = make_dirs(cmd, logic_id)
|
||||
|
||||
print_info(rank_id=rank_id, device_id=device_id, logic_id=logic_id, cmdopt=cmdopt, cur_dir=cur_dir)
|
||||
|
||||
if args.cpu_bind:
|
||||
run_cmd = 'taskset -c ' + cmdopt + ' '
|
||||
else:
|
||||
run_cmd = ""
|
||||
run_cmd += 'nohup python ' + run_script + " "
|
||||
|
||||
run_cmd += " " + ' '.join([str(x) for x in args.args.split(' ')[1:]])
|
||||
run_cmd += ' >./log.txt 2>&1 &'
|
||||
|
||||
cmd = append_cmd(cmd, run_cmd)
|
||||
cmd = append_cmd(cmd, "cd -")
|
||||
cmd = append_cmd(cmd, "echo \"run with" +
|
||||
" rank_id=" + str(rank_id) +
|
||||
" device_id=" + str(device_id) +
|
||||
" logic_id=" + str(logic_id) + "\"")
|
||||
cmd += "\n"
|
||||
|
||||
with open(args.cmd_file, "w") as f:
|
||||
f.write(cmd)
|
||||
|
||||
if __name__ == "__main__":
|
||||
distribute_run()
|
|
@ -0,0 +1,35 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
rm -rf EVAL
|
||||
mkdir EVAL
|
||||
cd EVAL
|
||||
|
||||
if [ $# -ge 1 ]; then
|
||||
if [ $1 == 'imagenet' ]; then
|
||||
CONFIG_FILE="${BASE_PATH}/../config_imagenet.yaml"
|
||||
elif [ $1 == 'default' ]; then
|
||||
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
|
||||
else
|
||||
echo "Unrecognized parameter"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
|
||||
fi
|
||||
|
||||
python ${BASE_PATH}/../eval.py --config_path=$CONFIG_FILE
|
|
@ -0,0 +1,39 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_local_train.sh RANK_TABLE_FILE [OTHER_ARGS]"
|
||||
echo "OTHER_ARGS will be passed to the training scripts directly,"
|
||||
echo "for example: bash run_local_train.sh /path/hccl.json --data_dir=/path/data_dir --epochs=40"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
exit 1
|
||||
fi
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
OTHER_ARGS=$*
|
||||
echo ${OTHER_ARGS[0]}
|
||||
|
||||
python3 ${BASE_PATH}/ascend_distributed_launcher/get_distribute_pretrain_cmd.py \
|
||||
--run_script_path=${BASE_PATH}/../train.py \
|
||||
--hccl_config_dir=$1 \
|
||||
--hccl_time_out=600 \
|
||||
--args="$*" \
|
||||
--cmd_file=distributed_cmd.sh
|
||||
|
||||
bash distributed_cmd.sh
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Dataset"""
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as CV
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore.dataset.vision import Inter
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
def create_lenet_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
|
||||
"""
|
||||
create dataset for train or test
|
||||
"""
|
||||
# define dataset
|
||||
mnist_ds = ds.MnistDataset(data_path, num_parallel_workers=num_parallel_workers)
|
||||
|
||||
resize_height, resize_width = 32, 32
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
rescale_nml = 1 / 0.3081
|
||||
shift_nml = -1 * 0.1307 / 0.3081
|
||||
|
||||
# define map operations
|
||||
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
|
||||
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
|
||||
rescale_op = CV.Rescale(rescale, shift)
|
||||
hwc2chw_op = CV.HWC2CHW()
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
|
||||
# apply map operations on images
|
||||
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
|
||||
|
||||
# apply DatasetOps
|
||||
buffer_size = 10000
|
||||
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
|
||||
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True, num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.repeat(repeat_size)
|
||||
|
||||
return mnist_ds
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Model Structure"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import Normal
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
"""
|
||||
Lenet network
|
||||
|
||||
Args:
|
||||
num_class (int): Number of classes. Default: 10.
|
||||
num_channel (int): Number of channels. Default: 1.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor
|
||||
Examples:
|
||||
>>> LeNet(num_class=10)
|
||||
|
||||
"""
|
||||
def __init__(self, num_class=10, num_channel=1, include_top=True):
|
||||
super(LeNet5, self).__init__()
|
||||
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
|
||||
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.include_top = include_top
|
||||
if self.include_top:
|
||||
self.flatten = nn.Flatten()
|
||||
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
|
||||
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
|
||||
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
"""construct lenet5"""
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
if not self.include_top:
|
||||
return x
|
||||
x = self.flatten(x)
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler = Profiler()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler.analyse()
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Training process"""
|
||||
|
||||
import os
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore import context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
|
||||
from src.moxing_adapter import moxing_wrapper
|
||||
from src.config import config
|
||||
from src.dataset import create_lenet_dataset
|
||||
from src.foo import LeNet5
|
||||
|
||||
|
||||
@moxing_wrapper()
|
||||
def train_lenet5():
|
||||
"""
|
||||
Train lenet5
|
||||
"""
|
||||
config.ckpt_path = config.output_path
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
ds_train = create_lenet_dataset(os.path.join(config.data_path, "train"), config.batch_size, num_parallel_workers=1)
|
||||
if ds_train.get_dataset_size() == 0:
|
||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||
|
||||
network = LeNet5(config.num_classes)
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
net_opt = nn.Momentum(network.trainable_params(), config.lr, config.momentum)
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
|
||||
directory=None if config.ckpt_path == "" else config.ckpt_path,
|
||||
config=config_ck)
|
||||
|
||||
if config.device_target != "Ascend":
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
else:
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(config.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_lenet5()
|
Loading…
Reference in New Issue