forked from mindspore-Ecosystem/mindspore
update modelzoo ncf network.
This commit is contained in:
parent
2d6f12d485
commit
0b6f035bba
|
@ -100,6 +100,33 @@ sh scripts/run_train.sh rank_table.json
|
||||||
sh run_eval.sh
|
sh run_eval.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training and evaluation as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# run distributed training on modelarts example
|
||||||
|
# (1) First, Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||||
|
# Set other parameters on default_config.yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Set the code directory to "/path/ncf" on the website UI interface.
|
||||||
|
# (3) Set the startup file to "train.py" on the website UI interface.
|
||||||
|
# (4) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (5) Create your job.
|
||||||
|
|
||||||
|
# run evaluation on modelarts example
|
||||||
|
# (1) Copy or upload your trained model to S3 bucket.
|
||||||
|
# (2) Perform a or b.
|
||||||
|
# a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file.
|
||||||
|
# Set "checkpoint_url=/The path of checkpoint in S3/" on default_config.yaml file.
|
||||||
|
# b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# (3) Set the code directory to "/path/ncf" on the website UI interface.
|
||||||
|
# (4) Set the startup file to "eval.py" on the website UI interface.
|
||||||
|
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (6) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
# [Script Description](#contents)
|
# [Script Description](#contents)
|
||||||
|
|
||||||
## [Script and Sample Code](#contents)
|
## [Script and Sample Code](#contents)
|
||||||
|
@ -108,6 +135,9 @@ sh run_eval.sh
|
||||||
├── ModelZoo_NCF_ME
|
├── ModelZoo_NCF_ME
|
||||||
├── README.md // descriptions about NCF
|
├── README.md // descriptions about NCF
|
||||||
├── scripts
|
├── scripts
|
||||||
|
│ ├──ascend_distributed_launcher
|
||||||
|
│ ├──__init__.py // init file
|
||||||
|
│ ├──get_distribute_pretrain_cmd.py // create distribute shell script
|
||||||
│ ├──run_train.sh // shell script for train
|
│ ├──run_train.sh // shell script for train
|
||||||
│ ├──run_distribute_train.sh // shell script for distribute train
|
│ ├──run_distribute_train.sh // shell script for distribute train
|
||||||
│ ├──run_eval.sh // shell script for evaluation
|
│ ├──run_eval.sh // shell script for evaluation
|
||||||
|
@ -116,15 +146,19 @@ sh run_eval.sh
|
||||||
├── src
|
├── src
|
||||||
│ ├──dataset.py // creating dataset
|
│ ├──dataset.py // creating dataset
|
||||||
│ ├──ncf.py // ncf architecture
|
│ ├──ncf.py // ncf architecture
|
||||||
│ ├──config.py // parameter configuration
|
│ ├──config.py // parameter analysis
|
||||||
|
│ ├──device_adapter.py // device adapter
|
||||||
|
│ ├──local_adapter.py // local adapter
|
||||||
|
│ ├──moxing_adapter.py // moxing adapter
|
||||||
│ ├──movielens.py // data download file
|
│ ├──movielens.py // data download file
|
||||||
│ ├──callbacks.py // model loss and eval callback file
|
│ ├──callbacks.py // model loss and eval callback file
|
||||||
│ ├──constants.py // the constants of model
|
│ ├──constants.py // the constants of model
|
||||||
│ ├──export.py // export checkpoint files into geir/onnx
|
│ ├──export.py // export checkpoint files into geir/onnx
|
||||||
│ ├──metrics.py // the file for auc compute
|
│ ├──metrics.py // the file for auc compute
|
||||||
│ ├──stat_utils.py // the file for data process functions
|
│ ├──stat_utils.py // the file for data process functions
|
||||||
|
├── default_config.yaml // parameter configuration
|
||||||
├── train.py // training script
|
├── train.py // training script
|
||||||
├── eval.py // evaluation script
|
├── eval.py // evaluation script
|
||||||
```
|
```
|
||||||
|
|
||||||
## [Script Parameters](#contents)
|
## [Script Parameters](#contents)
|
||||||
|
@ -144,7 +178,6 @@ Parameters for both training and evaluation can be set in config.py.
|
||||||
* `--num_factors`:The Embedding size of MF model.
|
* `--num_factors`:The Embedding size of MF model.
|
||||||
* `--output_path`:The location of the output file.
|
* `--output_path`:The location of the output file.
|
||||||
* `--eval_file_name` : Eval output file.
|
* `--eval_file_name` : Eval output file.
|
||||||
* `--loss_file_name` : Loss output file.
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## [Training Process](#contents)
|
## [Training Process](#contents)
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
# Url for modelarts
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
# Path for local
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path"
|
||||||
|
device_target: "Ascend"
|
||||||
|
enable_profiling: False
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Training options
|
||||||
|
dataset: "ml-1m"
|
||||||
|
train_epochs: 14
|
||||||
|
batch_size: 256
|
||||||
|
eval_batch_size: 160000
|
||||||
|
num_neg: 4
|
||||||
|
layers: [64, 32, 16]
|
||||||
|
num_factors: 16
|
||||||
|
checkpoint_path: "./checkpoint/"
|
||||||
|
|
||||||
|
# Eval options
|
||||||
|
eval_file_name: "eval.log"
|
||||||
|
checkpoint_file_path: "./checkpoint/NCF-14_19418.ckpt"
|
||||||
|
|
||||||
|
# Export options
|
||||||
|
device_id: 0
|
||||||
|
ckpt_file: ""
|
||||||
|
file_name: ""
|
||||||
|
file_format: ""
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# Help description for each configuration
|
||||||
|
enable_modelarts: "Whether training on modelarts, default: False"
|
||||||
|
data_url: "Url for modelarts"
|
||||||
|
train_url: "Url for modelarts"
|
||||||
|
data_path: "The location of the input data."
|
||||||
|
output_path: "The location of the output file."
|
||||||
|
device_target: 'Target device type'
|
||||||
|
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||||
|
dataset: "Dataset to be trained and evaluated, choice: ['ml-1m', 'ml-20m']"
|
||||||
|
train_epochs: "The number of epochs used to train."
|
||||||
|
batch_size: "Batch size for training and evaluation"
|
||||||
|
eval_batch_size: "The batch size used for evaluation."
|
||||||
|
num_neg: "The Number of negative instances to pair with a positive instance."
|
||||||
|
layers: "The sizes of hidden layers for MLP"
|
||||||
|
num_factors: "The Embedding size of MF model."
|
||||||
|
checkpoint_path: "The location of the checkpoint file."
|
||||||
|
eval_file_name: "Eval output file."
|
||||||
|
checkpoint_file_path: "The location of the checkpoint file."
|
|
@ -15,7 +15,6 @@
|
||||||
"""Using for eval the model checkpoint"""
|
"""Using for eval the model checkpoint"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import argparse
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
|
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
@ -26,31 +25,31 @@ from src.dataset import create_dataset
|
||||||
from src.metrics import NCFMetric
|
from src.metrics import NCFMetric
|
||||||
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap, PredictWithSigmoid
|
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap, PredictWithSigmoid
|
||||||
|
|
||||||
from src.config import cfg
|
from utils.config import config
|
||||||
|
from utils.moxing_adapter import moxing_wrapper
|
||||||
|
from utils.device_adapter import get_device_id
|
||||||
|
|
||||||
logging.set_verbosity(logging.INFO)
|
logging.set_verbosity(logging.INFO)
|
||||||
|
|
||||||
|
@moxing_wrapper()
|
||||||
parser = argparse.ArgumentParser(description='NCF')
|
def run_eval():
|
||||||
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
|
|
||||||
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
|
|
||||||
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
|
||||||
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
|
|
||||||
parser.add_argument("--checkpoint_file_path", type=str, default="./checkpoint/NCF-14_19418.ckpt") # The location of the checkpoint file.
|
|
||||||
args, _ = parser.parse_known_args()
|
|
||||||
|
|
||||||
def test_eval():
|
|
||||||
"""eval method"""
|
"""eval method"""
|
||||||
if not os.path.exists(args.output_path):
|
if not os.path.exists(config.output_path):
|
||||||
os.makedirs(args.output_path)
|
os.makedirs(config.output_path)
|
||||||
|
|
||||||
layers = cfg.layers
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
num_factors = cfg.num_factors
|
device_target="Davinci",
|
||||||
|
save_graphs=False,
|
||||||
|
device_id=get_device_id())
|
||||||
|
|
||||||
|
layers = config.layers
|
||||||
|
num_factors = config.num_factors
|
||||||
topk = rconst.TOP_K
|
topk = rconst.TOP_K
|
||||||
num_eval_neg = rconst.NUM_EVAL_NEGATIVES
|
num_eval_neg = rconst.NUM_EVAL_NEGATIVES
|
||||||
|
|
||||||
ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=args.data_path,
|
ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=config.data_path,
|
||||||
dataset=args.dataset, train_epochs=0,
|
dataset=config.dataset, train_epochs=0,
|
||||||
eval_batch_size=cfg.eval_batch_size)
|
eval_batch_size=config.eval_batch_size)
|
||||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||||
|
|
||||||
ncf_net = NCFModel(num_users=num_eval_users,
|
ncf_net = NCFModel(num_users=num_eval_users,
|
||||||
|
@ -60,7 +59,7 @@ def test_eval():
|
||||||
mf_regularization=0,
|
mf_regularization=0,
|
||||||
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
||||||
mf_dim=16)
|
mf_dim=16)
|
||||||
param_dict = load_checkpoint(args.checkpoint_file_path)
|
param_dict = load_checkpoint(config.checkpoint_file_path)
|
||||||
load_param_into_net(ncf_net, param_dict)
|
load_param_into_net(ncf_net, param_dict)
|
||||||
|
|
||||||
loss_net = NetWithLossClass(ncf_net)
|
loss_net = NetWithLossClass(ncf_net)
|
||||||
|
@ -73,18 +72,12 @@ def test_eval():
|
||||||
ncf_metric.clear()
|
ncf_metric.clear()
|
||||||
out = model.eval(ds_eval)
|
out = model.eval(ds_eval)
|
||||||
|
|
||||||
eval_file_path = os.path.join(args.output_path, args.eval_file_name)
|
eval_file_path = os.path.join(config.output_path, config.eval_file_name)
|
||||||
eval_file = open(eval_file_path, "a+")
|
eval_file = open(eval_file_path, "a+")
|
||||||
eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1]))
|
eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1]))
|
||||||
eval_file.close()
|
eval_file.close()
|
||||||
print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1]))
|
print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1]))
|
||||||
|
print("=" * 100 + "Eval Finish!" + "=" * 100)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
devid = int(os.getenv('DEVICE_ID'))
|
run_eval()
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
|
||||||
device_target="Davinci",
|
|
||||||
save_graphs=True,
|
|
||||||
device_id=devid)
|
|
||||||
|
|
||||||
test_eval()
|
|
||||||
|
|
|
@ -13,37 +13,26 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""ncf export file"""
|
"""ncf export file"""
|
||||||
import argparse
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||||
|
|
||||||
import src.constants as rconst
|
import src.constants as rconst
|
||||||
from src.config import cfg
|
from utils.config import config
|
||||||
from ncf import NCFModel, PredictWithSigmoid
|
from ncf import NCFModel, PredictWithSigmoid
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='ncf export')
|
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
if config.device_target == "Ascend":
|
||||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
context.set_context(device_id=config.device_id)
|
||||||
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"], help="Dataset.")
|
|
||||||
parser.add_argument("--file_name", type=str, default="ncf", help="output file name.")
|
|
||||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
|
||||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
|
||||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
|
||||||
if args.device_target == "Ascend":
|
|
||||||
context.set_context(device_id=args.device_id)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
topk = rconst.TOP_K
|
topk = rconst.TOP_K
|
||||||
num_eval_neg = rconst.NUM_EVAL_NEGATIVES
|
num_eval_neg = rconst.NUM_EVAL_NEGATIVES
|
||||||
|
|
||||||
if args.dataset == "ml-1m":
|
if config.dataset == "ml-1m":
|
||||||
num_eval_users = 6040
|
num_eval_users = 6040
|
||||||
num_eval_items = 3706
|
num_eval_items = 3706
|
||||||
elif args.dataset == "ml-20m":
|
elif config.dataset == "ml-20m":
|
||||||
num_eval_users = 138493
|
num_eval_users = 138493
|
||||||
num_eval_items = 26744
|
num_eval_items = 26744
|
||||||
else:
|
else:
|
||||||
|
@ -51,20 +40,20 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
ncf_net = NCFModel(num_users=num_eval_users,
|
ncf_net = NCFModel(num_users=num_eval_users,
|
||||||
num_items=num_eval_items,
|
num_items=num_eval_items,
|
||||||
num_factors=cfg.num_factors,
|
num_factors=config.num_factors,
|
||||||
model_layers=cfg.layers,
|
model_layers=config.layers,
|
||||||
mf_regularization=0,
|
mf_regularization=0,
|
||||||
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
||||||
mf_dim=16)
|
mf_dim=16)
|
||||||
|
|
||||||
param_dict = load_checkpoint(args.ckpt_file)
|
param_dict = load_checkpoint(config.ckpt_file)
|
||||||
load_param_into_net(ncf_net, param_dict)
|
load_param_into_net(ncf_net, param_dict)
|
||||||
|
|
||||||
network = PredictWithSigmoid(ncf_net, topk, num_eval_neg)
|
network = PredictWithSigmoid(ncf_net, topk, num_eval_neg)
|
||||||
|
|
||||||
users = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.int32))
|
users = Tensor(np.zeros([config.eval_batch_size, 1]).astype(np.int32))
|
||||||
items = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.int32))
|
items = Tensor(np.zeros([config.eval_batch_size, 1]).astype(np.int32))
|
||||||
masks = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.float32))
|
masks = Tensor(np.zeros([config.eval_batch_size, 1]).astype(np.float32))
|
||||||
|
|
||||||
input_data = [users, items, masks]
|
input_data = [users, items, masks]
|
||||||
export(network, *input_data, file_name=args.file_name, file_format=args.file_format)
|
export(network, *input_data, file_name=config.file_name, file_format=config.file_format)
|
||||||
|
|
|
@ -0,0 +1,188 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""distribute running script"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import multiprocessing
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""
|
||||||
|
parse args .
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
args.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> parse_args()
|
||||||
|
"""
|
||||||
|
parser = ArgumentParser(description="Distributed training scripts generator for MindSpore")
|
||||||
|
|
||||||
|
parser.add_argument("--run_script_path", type=str, default="",
|
||||||
|
help="Run script path, it is better to use absolute path")
|
||||||
|
parser.add_argument("--args", type=str, default="",
|
||||||
|
help="Other arguments which will be passed to main program directly")
|
||||||
|
parser.add_argument("--hccl_config_dir", type=str, default="",
|
||||||
|
help="Hccl config path, it is better to use absolute path")
|
||||||
|
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
|
||||||
|
help="Path of the generated cmd file.")
|
||||||
|
parser.add_argument("--hccl_time_out", type=int, default=120,
|
||||||
|
help="Seconds to determine the hccl time out,"
|
||||||
|
"default: 120, which is the same as hccl default config")
|
||||||
|
parser.add_argument("--cpu_bind", action="store_true", default=False,
|
||||||
|
help="Bind cpu cores or not")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def append_cmd(cmd, s):
|
||||||
|
cmd += s
|
||||||
|
cmd += "\n"
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def append_cmd_env(cmd, key, value):
|
||||||
|
return append_cmd(cmd, "export " + str(key) + "=" + str(value))
|
||||||
|
|
||||||
|
|
||||||
|
def set_envs(cmd, logic_id, rank_id):
|
||||||
|
"""
|
||||||
|
Set environment variables.
|
||||||
|
"""
|
||||||
|
cmd = append_cmd_env(cmd, "DEVICE_ID", str(logic_id))
|
||||||
|
cmd = append_cmd_env(cmd, "RANK_ID", str(rank_id))
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def make_dirs(cmd, logic_id):
|
||||||
|
"""
|
||||||
|
Make directories and change path.
|
||||||
|
"""
|
||||||
|
cmd = append_cmd(cmd, "rm -rf LOG" + str(logic_id))
|
||||||
|
cmd = append_cmd(cmd, "mkdir ./LOG" + str(logic_id))
|
||||||
|
cmd = append_cmd(cmd, "mkdir -p ./LOG" + str(logic_id) + "/ms_log")
|
||||||
|
cmd = append_cmd(cmd, "env > ./LOG" + str(logic_id) + "/env.log")
|
||||||
|
cur_dir = os.getcwd()
|
||||||
|
cmd = append_cmd_env(cmd, "GLOG_log_dir", cur_dir + "/LOG" + str(logic_id) + "/ms_log")
|
||||||
|
cmd = append_cmd_env(cmd, "GLOG_logtostderr", "0")
|
||||||
|
cmd = append_cmd(cmd, "cd " + cur_dir + "/LOG" + str(logic_id))
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def print_info(rank_id, device_id, logic_id, cmdopt, cur_dir):
|
||||||
|
"""
|
||||||
|
Print some information about scripts.
|
||||||
|
"""
|
||||||
|
print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":")
|
||||||
|
print("rank_id:", rank_id)
|
||||||
|
print("device_id:", device_id)
|
||||||
|
print("logic_id", logic_id)
|
||||||
|
print("core_nums:", cmdopt)
|
||||||
|
print("log_file_dir: " + cur_dir + "/LOG" + str(logic_id) + "/pretraining_log.txt")
|
||||||
|
|
||||||
|
def distribute_run():
|
||||||
|
"""
|
||||||
|
distribute pretrain scripts. The number of Ascend accelerators can be automatically allocated
|
||||||
|
based on the device_num set in hccl config file, You don not need to specify that.
|
||||||
|
"""
|
||||||
|
cmd = ""
|
||||||
|
print("start", __file__)
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
run_script = args.run_script_path
|
||||||
|
|
||||||
|
print("hccl_config_dir:", args.hccl_config_dir)
|
||||||
|
print("hccl_time_out:", args.hccl_time_out)
|
||||||
|
cmd = append_cmd_env(cmd, 'HCCL_CONNECT_TIMEOUT', args.hccl_time_out)
|
||||||
|
cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir)
|
||||||
|
|
||||||
|
cores = multiprocessing.cpu_count()
|
||||||
|
print("the number of logical core:", cores)
|
||||||
|
|
||||||
|
# get device_ips
|
||||||
|
device_ips = {}
|
||||||
|
physic_logic_ids = {}
|
||||||
|
with open('/etc/hccn.conf', 'r') as fin:
|
||||||
|
for hccn_item in fin.readlines():
|
||||||
|
if hccn_item.strip().startswith('address_'):
|
||||||
|
device_id, device_ip = hccn_item.split('=')
|
||||||
|
device_id = device_id.split('_')[1]
|
||||||
|
device_ips[device_id] = device_ip.strip()
|
||||||
|
|
||||||
|
if not device_ips:
|
||||||
|
raise ValueError("There is no address in /etc/hccn.conf")
|
||||||
|
|
||||||
|
for logic_id, device_id in enumerate(sorted(device_ips.keys())):
|
||||||
|
physic_logic_ids[device_id] = logic_id
|
||||||
|
|
||||||
|
with open(args.hccl_config_dir, "r", encoding="utf-8") as fin:
|
||||||
|
hccl_config = json.loads(fin.read())
|
||||||
|
rank_size = 0
|
||||||
|
for server in hccl_config["server_list"]:
|
||||||
|
rank_size += len(server["device"])
|
||||||
|
if server["device"][0]["device_ip"] in device_ips.values():
|
||||||
|
this_server = server
|
||||||
|
|
||||||
|
cmd = append_cmd_env(cmd, "RANK_SIZE", str(rank_size))
|
||||||
|
print("total rank size:", rank_size)
|
||||||
|
print("this server rank size:", len(this_server["device"]))
|
||||||
|
avg_core_per_rank = int(int(cores) / len(this_server["device"]))
|
||||||
|
core_gap = avg_core_per_rank - 1
|
||||||
|
print("avg_core_per_rank:", avg_core_per_rank)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for instance in this_server["device"]:
|
||||||
|
# device_id is the physical id, we use logic id to specific the selected device.
|
||||||
|
# While running on a server with 8 pcs, the logic ids are equal to the device ids.
|
||||||
|
device_id = instance["device_id"]
|
||||||
|
rank_id = instance["rank_id"]
|
||||||
|
logic_id = physic_logic_ids[device_id]
|
||||||
|
start = count * int(avg_core_per_rank)
|
||||||
|
count += 1
|
||||||
|
end = start + core_gap
|
||||||
|
cmdopt = str(start) + "-" + str(end)
|
||||||
|
cur_dir = os.getcwd()
|
||||||
|
|
||||||
|
cmd = set_envs(cmd, logic_id, rank_id)
|
||||||
|
cmd = make_dirs(cmd, logic_id)
|
||||||
|
|
||||||
|
print_info(rank_id=rank_id, device_id=device_id, logic_id=logic_id, cmdopt=cmdopt, cur_dir=cur_dir)
|
||||||
|
|
||||||
|
if args.cpu_bind:
|
||||||
|
run_cmd = 'taskset -c ' + cmdopt + ' '
|
||||||
|
else:
|
||||||
|
run_cmd = ""
|
||||||
|
run_cmd += 'nohup python ' + run_script + " "
|
||||||
|
|
||||||
|
run_cmd += " " + ' '.join([str(x) for x in args.args.split(' ')[1:]])
|
||||||
|
run_cmd += ' >./log.txt 2>&1 &'
|
||||||
|
|
||||||
|
cmd = append_cmd(cmd, run_cmd)
|
||||||
|
cmd = append_cmd(cmd, "cd -")
|
||||||
|
cmd = append_cmd(cmd, "echo \"run with" +
|
||||||
|
" rank_id=" + str(rank_id) +
|
||||||
|
" device_id=" + str(device_id) +
|
||||||
|
" logic_id=" + str(logic_id) + "\"")
|
||||||
|
cmd += "\n"
|
||||||
|
|
||||||
|
with open(args.cmd_file, "w") as f:
|
||||||
|
f.write(cmd)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
distribute_run()
|
|
@ -13,35 +13,31 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
echo "Please run the script as: "
|
|
||||||
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH RANK_TABLE_FILE"
|
|
||||||
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
|
|
||||||
|
|
||||||
current_exec_path=$(pwd)
|
if [ $# -lt 1 ]; then
|
||||||
echo ${current_exec_path}
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "bash run_local_train.sh RANK_TABLE_FILE [OTHER_ARGS]"
|
||||||
|
echo "OTHER_ARGS will be passed to the training scripts directly,"
|
||||||
|
echo "for example: bash run_local_train.sh /path/hccl.json /dataset_path"
|
||||||
|
echo "It is better to use absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
export RANK_SIZE=$1
|
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||||
data_path=$2
|
|
||||||
export RANK_TABLE_FILE=$3
|
|
||||||
|
|
||||||
for((i=0;i<=RANK_SIZE;i++));
|
python3 ${BASE_PATH}/ascend_distributed_launcher/get_distribute_pretrain_cmd.py \
|
||||||
do
|
--run_script_path=${BASE_PATH}/../train.py \
|
||||||
rm ${current_exec_path}/device_$i/ -rf
|
--hccl_config_dir=$1 \
|
||||||
mkdir ${current_exec_path}/device_$i
|
--hccl_time_out=600 \
|
||||||
cd ${current_exec_path}/device_$i || exit
|
--args=" --data_path=$2 \
|
||||||
export RANK_ID=$i
|
--dataset='ml-1m' \
|
||||||
export DEVICE_ID=$i
|
--train_epochs=50 \
|
||||||
python -u ${current_exec_path}/train.py \
|
--output_path='./output/' \
|
||||||
--data_path $data_path \
|
--eval_file_name='eval.log' \
|
||||||
--dataset 'ml-1m' \
|
--checkpoint_path='./checkpoint/' \
|
||||||
--train_epochs 50 \
|
--device_target='Ascend'" \
|
||||||
--output_path './output/' \
|
--cmd_file=distributed_cmd.sh
|
||||||
--eval_file_name 'eval.log' \
|
|
||||||
--loss_file_name 'loss.log' \
|
|
||||||
--checkpoint_path './checkpoint/' \
|
|
||||||
--device_target="Ascend" \
|
|
||||||
--device_id=$i \
|
|
||||||
--is_distributed=1 \
|
|
||||||
>log_$i.log 2>&1 &
|
|
||||||
done
|
|
||||||
|
|
||||||
|
bash distributed_cmd.sh
|
||||||
|
|
|
@ -19,4 +19,4 @@ echo "for example: sh scripts/run_train.sh /dataset_path /ncf.ckpt"
|
||||||
|
|
||||||
data_path=$1
|
data_path=$1
|
||||||
ckpt_file=$2
|
ckpt_file=$2
|
||||||
python ./train.py --data_path $data_path --dataset 'ml-1m' --train_epochs 20 --batch_size 256 --output_path './output/' --loss_file_name 'loss.log' --checkpoint_path $ckpt_file
|
python ./train.py --data_path $data_path --dataset 'ml-1m' --train_epochs 20 --batch_size 256 --output_path './output/' --checkpoint_path $ckpt_file
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""
|
|
||||||
network config setting, will be used in main.py
|
|
||||||
"""
|
|
||||||
from easydict import EasyDict as edict
|
|
||||||
|
|
||||||
|
|
||||||
cfg = edict({
|
|
||||||
'dataset': 'ml-1m', # Dataset to be trained and evaluated, choice: ["ml-1m", "ml-20m"]
|
|
||||||
|
|
||||||
'data_dir': '../dataset', # The location of the input data.
|
|
||||||
|
|
||||||
'train_epochs': 14, # The number of epochs used to train.
|
|
||||||
|
|
||||||
'batch_size': 256, # Batch size for training and evaluation
|
|
||||||
|
|
||||||
'eval_batch_size': 160000, # The batch size used for evaluation.
|
|
||||||
|
|
||||||
'num_neg': 4, # The Number of negative instances to pair with a positive instance.
|
|
||||||
|
|
||||||
'layers': [64, 32, 16], # The sizes of hidden layers for MLP
|
|
||||||
|
|
||||||
'num_factors': 16 # The Embedding size of MF model.
|
|
||||||
|
|
||||||
})
|
|
|
@ -14,71 +14,61 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Training entry file"""
|
"""Training entry file"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import argparse
|
|
||||||
from absl import logging
|
from absl import logging
|
||||||
|
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
from mindspore import context, Model
|
from mindspore import context, Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication.management import get_rank, get_group_size, init
|
from mindspore.communication.management import init
|
||||||
from mindspore.common import set_seed
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap
|
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap
|
||||||
|
|
||||||
from config import cfg
|
from utils.moxing_adapter import moxing_wrapper
|
||||||
|
from utils.config import config
|
||||||
|
from utils.device_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
logging.set_verbosity(logging.INFO)
|
logging.set_verbosity(logging.INFO)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='NCF')
|
def modelarts_pre_process():
|
||||||
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
|
config.checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path)
|
||||||
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
|
|
||||||
parser.add_argument("--train_epochs", type=int, default=14) # The number of epochs used to train.
|
|
||||||
parser.add_argument("--batch_size", type=int, default=256) # Batch size for training and evaluation
|
|
||||||
parser.add_argument("--num_neg", type=int, default=4) # The Number of negative instances to pair with a positive instance.
|
|
||||||
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
|
||||||
parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file.
|
|
||||||
parser.add_argument("--checkpoint_path", type=str, default="./checkpoint/") # The location of the checkpoint file.
|
|
||||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
|
||||||
help='device where the code will be implemented. (Default: Ascend)')
|
|
||||||
parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)')
|
|
||||||
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
|
|
||||||
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
|
||||||
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
def test_train():
|
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||||
|
def run_train():
|
||||||
"""train entry method"""
|
"""train entry method"""
|
||||||
if args.is_distributed:
|
print(config)
|
||||||
if args.device_target == "Ascend":
|
print("device id: ", get_device_id())
|
||||||
init()
|
print("device num: ", get_device_num())
|
||||||
context.set_context(device_id=args.device_id)
|
print("rank id: ", get_rank_id())
|
||||||
elif args.device_target == "GPU":
|
print("job id: ", get_job_id())
|
||||||
init()
|
|
||||||
|
|
||||||
args.rank = get_rank()
|
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||||
args.group_size = get_group_size()
|
|
||||||
device_num = args.group_size
|
config.is_distributed = bool(get_device_num() > 1)
|
||||||
|
if config.is_distributed:
|
||||||
|
config.group_size = get_device_num()
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
context.set_auto_parallel_context(device_num=config.group_size, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
parameter_broadcast=True, gradients_mean=True)
|
parameter_broadcast=True, gradients_mean=True)
|
||||||
|
|
||||||
|
if config.device_target == "Ascend":
|
||||||
|
context.set_context(device_id=get_device_id())
|
||||||
|
init()
|
||||||
|
elif config.device_target == "GPU":
|
||||||
|
init()
|
||||||
else:
|
else:
|
||||||
context.set_context(device_id=args.device_id)
|
context.set_context(device_id=get_device_id())
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
|
||||||
|
|
||||||
if not os.path.exists(args.output_path):
|
layers = config.layers
|
||||||
os.makedirs(args.output_path)
|
num_factors = config.num_factors
|
||||||
|
epochs = config.train_epochs
|
||||||
|
|
||||||
layers = cfg.layers
|
ds_train, num_train_users, num_train_items = create_dataset(test_train=True, data_dir=config.data_path,
|
||||||
num_factors = cfg.num_factors
|
dataset=config.dataset, train_epochs=1,
|
||||||
epochs = args.train_epochs
|
batch_size=config.batch_size, num_neg=config.num_neg)
|
||||||
|
|
||||||
ds_train, num_train_users, num_train_items = create_dataset(test_train=True, data_dir=args.data_path,
|
|
||||||
dataset=args.dataset, train_epochs=1,
|
|
||||||
batch_size=args.batch_size, num_neg=args.num_neg)
|
|
||||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||||
|
|
||||||
ncf_net = NCFModel(num_users=num_train_users,
|
ncf_net = NCFModel(num_users=num_train_users,
|
||||||
|
@ -95,14 +85,14 @@ def test_train():
|
||||||
|
|
||||||
model = Model(train_net)
|
model = Model(train_net)
|
||||||
callback = LossMonitor(per_print_times=ds_train.get_dataset_size())
|
callback = LossMonitor(per_print_times=ds_train.get_dataset_size())
|
||||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=(4970845+args.batch_size-1)//(args.batch_size),
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=(4970845+config.batch_size-1)//(config.batch_size),
|
||||||
keep_checkpoint_max=100)
|
keep_checkpoint_max=100)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='NCF', directory=args.checkpoint_path, config=ckpt_config)
|
ckpoint_cb = ModelCheckpoint(prefix='NCF', directory=config.checkpoint_path, config=ckpt_config)
|
||||||
model.train(epochs,
|
model.train(epochs,
|
||||||
ds_train,
|
ds_train,
|
||||||
callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb],
|
callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb],
|
||||||
dataset_sink_mode=True)
|
dataset_sink_mode=True)
|
||||||
|
print("="*100 + "Training Finish!" + "="*100)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_train()
|
run_train()
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Parse arguments"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import ast
|
||||||
|
import argparse
|
||||||
|
from pprint import pprint, pformat
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""
|
||||||
|
Configuration namespace. Convert dictionary to members.
|
||||||
|
"""
|
||||||
|
def __init__(self, cfg_dict):
|
||||||
|
for k, v in cfg_dict.items():
|
||||||
|
if isinstance(v, (list, tuple)):
|
||||||
|
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||||
|
else:
|
||||||
|
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return pformat(self.__dict__)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||||
|
"""
|
||||||
|
Parse command line arguments to the configuration according to the default yaml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: Parent parser.
|
||||||
|
cfg: Base configuration.
|
||||||
|
helper: Helper description.
|
||||||
|
cfg_path: Path to the default yaml config.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||||
|
parents=[parser])
|
||||||
|
helper = {} if helper is None else helper
|
||||||
|
choices = {} if choices is None else choices
|
||||||
|
for item in cfg:
|
||||||
|
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||||
|
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||||
|
choice = choices[item] if item in choices else None
|
||||||
|
if isinstance(cfg[item], bool):
|
||||||
|
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||||
|
help=help_description)
|
||||||
|
else:
|
||||||
|
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||||
|
help=help_description)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def parse_yaml(yaml_path):
|
||||||
|
"""
|
||||||
|
Parse the yaml config file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
yaml_path: Path to the yaml config.
|
||||||
|
"""
|
||||||
|
with open(yaml_path, 'r') as fin:
|
||||||
|
try:
|
||||||
|
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||||
|
cfgs = [x for x in cfgs]
|
||||||
|
if len(cfgs) == 1:
|
||||||
|
cfg_helper = {}
|
||||||
|
cfg = cfgs[0]
|
||||||
|
cfg_choices = {}
|
||||||
|
elif len(cfgs) == 2:
|
||||||
|
cfg, cfg_helper = cfgs
|
||||||
|
cfg_choices = {}
|
||||||
|
elif len(cfgs) == 3:
|
||||||
|
cfg, cfg_helper, cfg_choices = cfgs
|
||||||
|
else:
|
||||||
|
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||||
|
print(cfg_helper)
|
||||||
|
except:
|
||||||
|
raise ValueError("Failed to parse yaml")
|
||||||
|
return cfg, cfg_helper, cfg_choices
|
||||||
|
|
||||||
|
|
||||||
|
def merge(args, cfg):
|
||||||
|
"""
|
||||||
|
Merge the base config from yaml file and command line arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Command line arguments.
|
||||||
|
cfg: Base configuration.
|
||||||
|
"""
|
||||||
|
args_var = vars(args)
|
||||||
|
for item in args_var:
|
||||||
|
cfg[item] = args_var[item]
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
"""
|
||||||
|
Get Config according to the yaml file and cli arguments.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||||
|
help="Config file path")
|
||||||
|
path_args, _ = parser.parse_known_args()
|
||||||
|
default, helper, choices = parse_yaml(path_args.config_path)
|
||||||
|
pprint(default)
|
||||||
|
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||||
|
final_config = merge(args, default)
|
||||||
|
return Config(final_config)
|
||||||
|
|
||||||
|
config = get_config()
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Device adapter for ModelArts"""
|
||||||
|
|
||||||
|
from utils.config import config
|
||||||
|
|
||||||
|
if config.enable_modelarts:
|
||||||
|
from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
else:
|
||||||
|
from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||||
|
]
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Local adapter"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_device_id():
|
||||||
|
device_id = os.getenv('DEVICE_ID', '0')
|
||||||
|
return int(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
device_num = os.getenv('RANK_SIZE', '1')
|
||||||
|
return int(device_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_id():
|
||||||
|
global_rank_id = os.getenv('RANK_ID', '0')
|
||||||
|
return int(global_rank_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id():
|
||||||
|
return "Local Job"
|
|
@ -0,0 +1,122 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Moxing adapter for ModelArts"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import functools
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.profiler import Profiler
|
||||||
|
from utils.config import config
|
||||||
|
|
||||||
|
_global_sync_count = 0
|
||||||
|
|
||||||
|
def get_device_id():
|
||||||
|
device_id = os.getenv('DEVICE_ID', '0')
|
||||||
|
return int(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
device_num = os.getenv('RANK_SIZE', '1')
|
||||||
|
return int(device_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_id():
|
||||||
|
global_rank_id = os.getenv('RANK_ID', '0')
|
||||||
|
return int(global_rank_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id():
|
||||||
|
job_id = os.getenv('JOB_ID')
|
||||||
|
job_id = job_id if job_id != "" else "default"
|
||||||
|
return job_id
|
||||||
|
|
||||||
|
def sync_data(from_path, to_path):
|
||||||
|
"""
|
||||||
|
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||||
|
Upload data from local directory to remote obs in contrast.
|
||||||
|
"""
|
||||||
|
import moxing as mox
|
||||||
|
import time
|
||||||
|
global _global_sync_count
|
||||||
|
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||||
|
_global_sync_count += 1
|
||||||
|
|
||||||
|
# Each server contains 8 devices as most.
|
||||||
|
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||||
|
print("from path: ", from_path)
|
||||||
|
print("to path: ", to_path)
|
||||||
|
mox.file.copy_parallel(from_path, to_path)
|
||||||
|
print("===finish data synchronization===")
|
||||||
|
try:
|
||||||
|
os.mknod(sync_lock)
|
||||||
|
except IOError:
|
||||||
|
pass
|
||||||
|
print("===save flag===")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if os.path.exists(sync_lock):
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||||
|
|
||||||
|
|
||||||
|
def moxing_wrapper(pre_process=None, post_process=None):
|
||||||
|
"""
|
||||||
|
Moxing wrapper to download dataset and upload outputs.
|
||||||
|
"""
|
||||||
|
def wrapper(run_func):
|
||||||
|
@functools.wraps(run_func)
|
||||||
|
def wrapped_func(*args, **kwargs):
|
||||||
|
# Download data from data_url
|
||||||
|
if config.enable_modelarts:
|
||||||
|
if config.data_url:
|
||||||
|
sync_data(config.data_url, config.data_path)
|
||||||
|
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||||
|
if config.checkpoint_url:
|
||||||
|
sync_data(config.checkpoint_url, config.load_path)
|
||||||
|
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||||
|
if config.train_url:
|
||||||
|
sync_data(config.train_url, config.output_path)
|
||||||
|
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||||
|
|
||||||
|
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||||
|
config.device_num = get_device_num()
|
||||||
|
config.device_id = get_device_id()
|
||||||
|
if not os.path.exists(config.output_path):
|
||||||
|
os.makedirs(config.output_path)
|
||||||
|
|
||||||
|
if pre_process:
|
||||||
|
pre_process()
|
||||||
|
|
||||||
|
if config.enable_profiling:
|
||||||
|
profiler = Profiler()
|
||||||
|
|
||||||
|
run_func(*args, **kwargs)
|
||||||
|
|
||||||
|
if config.enable_profiling:
|
||||||
|
profiler.analyse()
|
||||||
|
|
||||||
|
# Upload data to train_url
|
||||||
|
if config.enable_modelarts:
|
||||||
|
if post_process:
|
||||||
|
post_process()
|
||||||
|
|
||||||
|
if config.train_url:
|
||||||
|
print("Start to copy output directory")
|
||||||
|
sync_data(config.output_path, config.train_url)
|
||||||
|
return wrapped_func
|
||||||
|
return wrapper
|
Loading…
Reference in New Issue