forked from mindspore-Ecosystem/mindspore
cloud
This commit is contained in:
parent
0b7f3ebd89
commit
e3f07763c8
|
@ -0,0 +1,106 @@
|
|||
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: Ascend
|
||||
enable_profiling: False
|
||||
|
||||
# ==============================================================================
|
||||
#"""data config"""
|
||||
data_vocab_size: 184965
|
||||
train_num_of_parts: 21
|
||||
test_num_of_parts: 3
|
||||
batch_size: 16000
|
||||
data_field_size: 39
|
||||
data_format: 1
|
||||
|
||||
#"""model config"""
|
||||
data_emb_dim: 80
|
||||
deep_layer_args: [[1024, 512, 256, 128], "relu"]
|
||||
init_args: [-0.01, 0.01]
|
||||
weight_bias_init: ['normal', 'normal']
|
||||
keep_prob: 0.9
|
||||
convert_dtype: True
|
||||
|
||||
# """train config"""
|
||||
l2_coef: 0.00008 # 8e-5
|
||||
learning_rate: 0.0005 # 5e-4
|
||||
epsilon: 0.00000005 # 5e-8
|
||||
loss_scale: 1024.0
|
||||
train_epochs: 5
|
||||
save_checkpoint: True
|
||||
ckpt_file_name_prefix: "deepfm"
|
||||
save_checkpoint_steps: 1
|
||||
keep_checkpoint_max: 50
|
||||
eval_callback: True
|
||||
loss_callback: True
|
||||
|
||||
# train.py 'CTR Prediction'
|
||||
dataset_path: "/cache/data"
|
||||
ckpt_path: "/cache/train"
|
||||
eval_file_name: "./auc.log"
|
||||
loss_file_name: "./loss.log"
|
||||
do_eval: 'True'
|
||||
|
||||
# eval.py 'CTR Prediction'
|
||||
checkpoint_path: "/cache/train/deepfm-5_2582.ckpt"
|
||||
|
||||
# export.py "deepfm export"
|
||||
device_id: 0
|
||||
ckpt_file: "/cache/train/deepfm-5_2582.ckpt"
|
||||
file_name: "deepfm"
|
||||
file_format: "AIR"
|
||||
|
||||
# 'preprocess.'
|
||||
result_path: './preprocess_Result'
|
||||
|
||||
# 'postprocess'
|
||||
# result_path: "./result_Files"
|
||||
label_path: ''
|
||||
|
||||
# src/preprocess_data.py "Recommendation dataset"
|
||||
# data_path: "./recommendation_dataset/"
|
||||
dense_dim: 13
|
||||
slot_dim: 26
|
||||
threshold: 100
|
||||
train_line_count: 45840617
|
||||
skip_id_convert: 0
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
|
||||
device_target: "device target, support Ascend, GPU and CPU."
|
||||
dataset_path: 'Dataset path'
|
||||
batch_size: "batch size"
|
||||
ckpt_path: 'Checkpoint path'
|
||||
eval_file_name: 'Auc log file path. Default: "./auc.log"'
|
||||
loss_file_name: 'Loss log file path. Default: "./loss.log"'
|
||||
do_eval: 'Do evaluation or not, only support "True" or "False". Default: "True"'
|
||||
checkpoint_path: 'Checkpoint file path'
|
||||
device_id: "Device id"
|
||||
ckpt_file: "Checkpoint file path."
|
||||
file_name: "output file name."
|
||||
file_format: "file format"
|
||||
result_path: 'Result path'
|
||||
# result_path: "./result_Files" # 'result path'
|
||||
label_path: 'label path'
|
||||
|
||||
dense_dim: 'The number of your continuous fields'
|
||||
slot_dim: 'The number of your sparse fields, it can also be called categorical features.'
|
||||
threshold: 'Word frequency below this will be regarded as OOV. It aims to reduce the vocab size'
|
||||
train_line_count: 'The number of examples in your dataset'
|
||||
skip_id_convert: 'Skip the id convert, regarding the original id as the final id.'
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
||||
file_format: ["AIR", "ONNX", "MINDIR"]
|
||||
freeze_layer: ["", "none", "backbone"]
|
||||
skip_id_convert: [0, 1]
|
|
@ -16,49 +16,45 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.deepfm import ModelBuilder, AUCMetric
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
from src.dataset import create_dataset, DataType
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parser = argparse.ArgumentParser(description='CTR Prediction')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
|
||||
help="device target, support Ascend, GPU and CPU.")
|
||||
args_opt, _ = parser.parse_known_args()
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
device_id = get_device_id() # int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
|
||||
|
||||
def add_write(file_path, print_str):
    """Append ``print_str`` plus a trailing newline to the file at ``file_path``.

    The file is opened in ``a+`` mode with UTF-8 encoding, so earlier content
    is kept and the file is created when it does not exist yet.
    """
    with open(file_path, mode='a+', encoding='utf-8') as log_file:
        line = print_str + '\n'
        log_file.write(line)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_config = DataConfig()
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
def modelarts_process():
|
||||
pass
|
||||
|
||||
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
|
||||
epochs=1, batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format))
|
||||
if model_config.convert_dtype:
|
||||
model_config.convert_dtype = args_opt.device_target != "CPU"
|
||||
model_builder = ModelBuilder(model_config, train_config)
|
||||
@moxing_wrapper(pre_process=modelarts_process)
|
||||
def eval_deepfm():
|
||||
""" eval_deepfm """
|
||||
ds_eval = create_dataset(config.dataset_path, train_mode=False,
|
||||
epochs=1, batch_size=config.batch_size,
|
||||
data_type=DataType(config.data_format))
|
||||
if config.convert_dtype:
|
||||
config.convert_dtype = config.device_target != "CPU"
|
||||
model_builder = ModelBuilder(config, config)
|
||||
train_net, eval_net = model_builder.get_train_eval_net()
|
||||
train_net.set_train()
|
||||
eval_net.set_train(False)
|
||||
auc_metric = AUCMetric()
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
param_dict = load_checkpoint(config.checkpoint_path)
|
||||
load_param_into_net(eval_net, param_dict)
|
||||
|
||||
start = time.time()
|
||||
|
@ -68,3 +64,6 @@ if __name__ == '__main__':
|
|||
out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
|
||||
print(out_str)
|
||||
add_write('./auc.log', str(out_str))
|
||||
|
||||
if __name__ == '__main__':
|
||||
eval_deepfm()
|
||||
|
|
|
@ -13,41 +13,32 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export ckpt to model"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import export, load_checkpoint
|
||||
|
||||
from src.deepfm import ModelBuilder
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
parser = argparse.ArgumentParser(description="deepfm export")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=16000, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="deepfm", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_config = DataConfig()
|
||||
|
||||
model_builder = ModelBuilder(ModelConfig, TrainConfig)
|
||||
model_builder = ModelBuilder(config, config)
|
||||
_, network = model_builder.get_train_eval_net()
|
||||
network.set_train(False)
|
||||
|
||||
load_checkpoint(args.ckpt_file, net=network)
|
||||
load_checkpoint(config.ckpt_file, net=network)
|
||||
|
||||
batch_ids = Tensor(np.zeros([data_config.batch_size, data_config.data_field_size]).astype(np.int32))
|
||||
batch_wts = Tensor(np.zeros([data_config.batch_size, data_config.data_field_size]).astype(np.float32))
|
||||
labels = Tensor(np.zeros([data_config.batch_size, 1]).astype(np.float32))
|
||||
batch_ids = Tensor(np.zeros([config.batch_size, config.data_field_size]).astype(np.int32))
|
||||
batch_wts = Tensor(np.zeros([config.batch_size, config.data_field_size]).astype(np.float32))
|
||||
labels = Tensor(np.zeros([config.batch_size, 1]).astype(np.float32))
|
||||
|
||||
input_data = [batch_ids, batch_wts, labels]
|
||||
export(network, *input_data, file_name=args.file_name, file_format=args.file_format)
|
||||
export(network, *input_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
|
|
@ -14,13 +14,11 @@
|
|||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.deepfm import ModelBuilder
|
||||
from src.config import ModelConfig, TrainConfig
|
||||
from src.model_utils.config import config
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == 'deepfm':
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
model_builder = ModelBuilder(model_config, train_config)
|
||||
model_builder = ModelBuilder(config, config)
|
||||
_, deepfm_eval_net = model_builder.get_train_eval_net()
|
||||
return deepfm_eval_net
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
|
@ -14,27 +14,21 @@
|
|||
# ============================================================================
|
||||
"""postprocess."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from src.deepfm import AUCMetric
|
||||
from src.config import TrainConfig
|
||||
from src.model_utils.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='postprocess')
|
||||
parser.add_argument('--result_path', type=str, default="./result_Files", help='result path')
|
||||
parser.add_argument('--label_path', type=str, default=None, help='label path')
|
||||
args_opt, _ = parser.parse_known_args()
|
||||
|
||||
def get_acc():
|
||||
''' get accuracy '''
|
||||
auc_metric = AUCMetric()
|
||||
train_config = TrainConfig()
|
||||
files = os.listdir(args_opt.label_path)
|
||||
batch_size = train_config.batch_size
|
||||
files = os.listdir(config.label_path)
|
||||
batch_size = config.batch_size
|
||||
|
||||
for f in files:
|
||||
rst_file = os.path.join(args_opt.result_path, f.split('.')[0] + '_0.bin')
|
||||
label_file = os.path.join(args_opt.label_path, f)
|
||||
rst_file = os.path.join(config.result_path, f.split('.')[0] + '_0.bin')
|
||||
label_file = os.path.join(config.label_path, f)
|
||||
|
||||
logit = Tensor(np.fromfile(rst_file, np.float32).reshape(batch_size, 1))
|
||||
label = Tensor(np.fromfile(label_file, np.float32).reshape(batch_size, 1))
|
||||
|
|
|
@ -14,34 +14,27 @@
|
|||
# ============================================================================
|
||||
"""preprocess."""
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from src.config import DataConfig, TrainConfig
|
||||
from src.dataset import create_dataset, DataType
|
||||
from src.model_utils.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='preprocess.')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--result_path', type=str, default='./preprocess_Result', help='Result path')
|
||||
args_opt, _ = parser.parse_known_args()
|
||||
|
||||
def generate_bin():
|
||||
'''generate bin files'''
|
||||
data_config = DataConfig()
|
||||
train_config = TrainConfig()
|
||||
|
||||
ds = create_dataset(args_opt.dataset_path, train_mode=False,
|
||||
epochs=1, batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format))
|
||||
batch_ids_path = os.path.join(args_opt.result_path, "00_batch_ids")
|
||||
batch_wts_path = os.path.join(args_opt.result_path, "01_batch_wts")
|
||||
labels_path = os.path.join(args_opt.result_path, "02_labels")
|
||||
ds = create_dataset(config.dataset_path, train_mode=False,
|
||||
epochs=1, batch_size=config.batch_size,
|
||||
data_type=DataType(config.data_format))
|
||||
batch_ids_path = os.path.join(config.result_path, "00_batch_ids")
|
||||
batch_wts_path = os.path.join(config.result_path, "01_batch_wts")
|
||||
labels_path = os.path.join(config.result_path, "02_labels")
|
||||
|
||||
os.makedirs(batch_ids_path)
|
||||
os.makedirs(batch_wts_path)
|
||||
os.makedirs(labels_path)
|
||||
|
||||
for i, data in enumerate(ds.create_dict_iterator(output_numpy=True)):
|
||||
file_name = "criteo_bs" + str(train_config.batch_size) + "_" + str(i) + ".bin"
|
||||
file_name = "criteo_bs" + str(config.batch_size) + "_" + str(i) + ".bin"
|
||||
batch_ids = data['feat_ids']
|
||||
batch_ids.tofile(os.path.join(batch_ids_path, file_name))
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ do
|
|||
rm -rf log$i
|
||||
mkdir ./log$i
|
||||
cp *.py ./log$i
|
||||
cp *.yaml ./log$i
|
||||
cp -r src ./log$i
|
||||
cd ./log$i || exit
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
|
|
|
@ -25,6 +25,7 @@ DATA_URL=$2
|
|||
rm -rf log
|
||||
mkdir ./log
|
||||
cp *.py ./log
|
||||
cp *.yaml ./log
|
||||
cp -r src ./log
|
||||
cd ./log || exit
|
||||
env > env.log
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
|
||||
class DataConfig:
    """Constants describing the dataset (data config)."""

    # Vocabulary size and number of dataset parts for train / test.
    data_vocab_size = 184965
    train_num_of_parts = 21
    test_num_of_parts = 3

    # Batch size and number of fields per sample.
    batch_size = 16000
    data_field_size = 39

    # Numeric id of the data format; callers in this repo pass it to DataType(...).
    data_format = 1
|
||||
|
||||
class ModelConfig:
    """Constants describing the model (model config)."""

    # Values copied from DataConfig so both configs stay in agreement.
    batch_size = DataConfig.batch_size
    data_field_size = DataConfig.data_field_size
    data_vocab_size = DataConfig.data_vocab_size

    # Model-specific settings.
    data_emb_dim = 80
    deep_layer_args = [[1024, 512, 256, 128], "relu"]
    init_args = [-0.01, 0.01]
    weight_bias_init = ['normal', 'normal']
    keep_prob = 0.9
    convert_dtype = True
|
||||
|
||||
class TrainConfig:
    """Constants describing the training run (train config)."""

    # Batch size is copied from DataConfig.
    batch_size = DataConfig.batch_size

    # Optimisation settings.
    l2_coef = 8e-5
    learning_rate = 5e-4
    epsilon = 5e-8
    loss_scale = 1024.0

    train_epochs = 5

    # Checkpoint settings.
    save_checkpoint = True
    ckpt_file_name_prefix = "deepfm"
    save_checkpoint_steps = 1
    keep_checkpoint_max = 50

    # Callback switches.
    eval_callback = True
    loss_callback = True
|
|
@ -23,8 +23,7 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from .config import DataConfig
|
||||
from .model_utils.config import config
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
|
@ -49,8 +48,8 @@ class H5Dataset():
|
|||
max_length = 39
|
||||
|
||||
def __init__(self, data_path, train_mode=True,
|
||||
train_num_of_parts=DataConfig.train_num_of_parts,
|
||||
test_num_of_parts=DataConfig.test_num_of_parts):
|
||||
train_num_of_parts=config.train_num_of_parts,
|
||||
test_num_of_parts=config.test_num_of_parts):
|
||||
self._hdf_data_dir = data_path
|
||||
self._is_training = train_mode
|
||||
if self._is_training:
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
    """
    Attribute-style namespace built from a dictionary.

    Every key of the input dictionary becomes an attribute. Nested dictionaries
    are wrapped in ``Config`` recursively; lists and tuples are stored as lists
    whose dictionary elements are wrapped the same way.
    """
    def __init__(self, cfg_dict):
        for key, value in cfg_dict.items():
            if isinstance(value, dict):
                converted = Config(value)
            elif isinstance(value, (list, tuple)):
                converted = [Config(elem) if isinstance(elem, dict) else elem for elem in value]
            else:
                converted = value
            setattr(self, key, converted)

    def __str__(self):
        # Pretty-printed view of the attribute dictionary.
        return pformat(vars(self))

    def __repr__(self):
        return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    A ``--<key>`` option is registered for every entry of ``cfg`` whose value is
    neither a list nor a dict; the entry's value is used as the option default.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description (mapping from option name to help text).
        choices: Mapping from option name to the list of allowed values.
        cfg_path: Path to the default yaml config (only used in the fallback help text).

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item, default in cfg.items():
        # Containers cannot be expressed as a single CLI value; they stay yaml-only.
        if isinstance(default, (list, dict)):
            continue
        help_description = helper.get(item, "Please reference to {}".format(cfg_path))
        if isinstance(default, bool):
            # bool("False") is True, so booleans are parsed as Python literals instead.
            arg_type = ast.literal_eval
        elif default is None:
            # A yaml `null` default would give type=NoneType, which cannot convert
            # any command line string; accept the override as a plain string.
            arg_type = str
        else:
            arg_type = type(default)
        parser.add_argument("--" + item, type=arg_type, default=default, choices=choices.get(item),
                            help=help_description)
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    The file may hold up to three yaml documents, in this order: the
    configuration itself, the help description per option, and the allowed
    choices per option. Missing trailing documents default to empty dicts.

    Args:
        yaml_path: Path to the yaml config.

    Returns:
        Tuple ``(cfg, cfg_helper, cfg_choices)``.

    Raises:
        ValueError: If the file is not valid yaml, contains no document, or
            contains more than three documents.
    """
    with open(yaml_path, 'r', encoding='utf-8') as fin:
        try:
            # NOTE(review): FullLoader is kept from the original; prefer
            # yaml.safe_load_all if config_path can ever point at untrusted input.
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
        except yaml.YAMLError as err:
            # Keep ValueError for callers, but preserve the underlying cause.
            raise ValueError("Failed to parse yaml") from err

    if not cfgs:
        raise ValueError("No document found in config yaml")
    if len(cfgs) > 3:
        raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")

    # Pad with empty helper/choices documents when they are absent.
    cfg, cfg_helper, cfg_choices = (cfgs + [{}, {}])[:3]
    print(cfg_helper)
    return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
    """
    Overlay the parsed command line arguments on the yaml configuration.

    Every attribute of ``args`` is written into ``cfg`` under the same name,
    replacing an existing entry or adding a new one. ``cfg`` is modified in
    place and also returned.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    for name, value in vars(args).items():
        cfg[name] = value
    return cfg
|
||||
|
||||
|
||||
def get_config():
    """
    Build the final configuration from the yaml file and the command line.

    ``--config_path`` is read first (defaulting to ``../../default_config.yaml``
    relative to this module), the yaml at that path is parsed, one CLI option
    per scalar yaml entry is registered, and the parsed CLI values are merged
    over the yaml defaults. The result is wrapped in a ``Config``.
    """
    base_parser = argparse.ArgumentParser(description="default name", add_help=False)
    module_dir = os.path.dirname(os.path.abspath(__file__))
    default_yaml = os.path.join(module_dir, "../../default_config.yaml")
    base_parser.add_argument("--config_path", type=str, default=default_yaml,
                             help="Config file path")
    # Only --config_path is known at this point; every other option is ignored for now.
    known_args, _ = base_parser.parse_known_args()
    yaml_path = known_args.config_path
    default, helper, choices = parse_yaml(yaml_path)
    pprint(default)
    cli_args = parse_cli_to_yaml(parser=base_parser, cfg=default, helper=helper, choices=choices,
                                 cfg_path=yaml_path)
    return Config(merge(cli_args, default))
|
||||
|
||||
# Module-level configuration object, built at import time from the yaml file and the command line.
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
|
||||
|
||||
|
||||
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
|
||||
|
||||
|
||||
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('RANK_ID', '0'))
|
||||
|
||||
|
||||
def get_job_id():
    """Return the fixed job identifier used for local (non-ModelArts) runs."""
    local_job_name = "Local Job"
    return local_job_name
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from .config import config
|
||||
|
||||
# Counts sync_data() calls in this process; each call appends the current value to its lock-file name.
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
|
||||
|
||||
|
||||
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
|
||||
|
||||
|
||||
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('RANK_ID', '0'))
|
||||
|
||||
|
||||
def get_job_id():
    """Return the ``JOB_ID`` environment variable, or ``"default"`` when it is unset or empty.

    The unset case must be handled explicitly: ``os.getenv`` returns ``None``
    then, and ``None != ""`` is true, so comparing against the empty string
    alone would hand ``None`` back to the caller.
    """
    return os.getenv('JOB_ID') or "default"
|
||||
|
||||
def sync_data(from_path, to_path):
    """
    Copy ``from_path`` to ``to_path`` with ``mox.file.copy_parallel``.

    Only a process whose device id is a multiple of ``min(device_num, 8)``
    performs the copy, and only if the lock file for this call does not exist
    yet. It then creates the lock file; every process blocks until that lock
    file exists.
    """
    import moxing as mox
    import time
    global _global_sync_count
    # One lock file per call, distinguished by the running call counter.
    lock_path = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(lock_path):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(lock_path)
        except IOError:
            pass
        print("===save flag===")

    # Poll once per second until the lock file shows up.
    while not os.path.exists(lock_path):
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.

    Returns a decorator. When ``config.enable_modelarts`` is set, the decorated
    function is preceded by syncing ``data_url``, ``checkpoint_url`` and
    ``train_url`` to their local paths (each only if non-empty) and by
    ``pre_process``; afterwards ``post_process`` runs and the local output
    directory is synced back to ``train_url``. When ``config.enable_profiling``
    is set, the call is bracketed by a ``Profiler``.

    Args:
        pre_process: Optional callable run before the decorated function.
        post_process: Optional callable run after the decorated function.

    Returns:
        A decorator whose wrapped function returns whatever the decorated
        function returns.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            # NOTE(review): the source view had lost its indentation; the nesting
            # under `enable_modelarts` below is reconstructed - confirm.
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                # exist_ok avoids FileExistsError when several ranks create the
                # directory at the same time (exists() + makedirs() is racy).
                os.makedirs(config.output_path, exist_ok=True)

                if pre_process:
                    pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            # Keep the result so that decorating a function does not drop its return value.
            result = run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
            return result
        return wrapped_func
    return wrapper
|
|
@ -16,10 +16,9 @@
|
|||
import os
|
||||
import pickle
|
||||
import collections
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from .model_utils.config import config
|
||||
|
||||
class StatsDict():
|
||||
"""preprocessed data"""
|
||||
|
@ -259,35 +258,22 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, recommendat
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Recommendation dataset")
|
||||
parser.add_argument("--data_path", type=str, default="./recommendation_dataset/", help='The path of the data file')
|
||||
parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continues fields')
|
||||
parser.add_argument("--slot_dim", type=int, default=26,
|
||||
help='The number of your sparse fields, it can also be called catelogy features.')
|
||||
parser.add_argument("--threshold", type=int, default=100,
|
||||
help='Word frequency below this will be regarded as OOV. It aims to reduce the vocab size')
|
||||
parser.add_argument("--train_line_count", type=int, default=45840617,
|
||||
help='The number of examples in your dataset')
|
||||
parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1],
|
||||
help='Skip the id convert, regarding the original id as the final id.')
|
||||
data_path = config.data_path
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
data_path = args.data_path
|
||||
|
||||
target_field_size = args.dense_dim + args.slot_dim
|
||||
stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
|
||||
skip_id_convert=args.skip_id_convert)
|
||||
target_field_size = config.dense_dim + config.slot_dim
|
||||
stats = StatsDict(field_size=target_field_size, dense_dim=config.dense_dim, slot_dim=config.slot_dim,
|
||||
skip_id_convert=config.skip_id_convert)
|
||||
data_file_path = data_path + "origin_data/train.txt"
|
||||
stats_output_path = data_path + "stats_dict/"
|
||||
mkdir_path(stats_output_path)
|
||||
statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
|
||||
statsdata(data_file_path, stats_output_path, stats, dense_dim=config.dense_dim, slot_dim=config.slot_dim)
|
||||
|
||||
stats.load_dict(dict_path=stats_output_path, prefix="")
|
||||
stats.get_cat2id(threshold=args.threshold)
|
||||
stats.get_cat2id(threshold=config.threshold)
|
||||
|
||||
in_file_path = data_path + "origin_data/train.txt"
|
||||
output_path = data_path + "mindrecord/"
|
||||
mkdir_path(output_path)
|
||||
random_split_trans2mindrecord(in_file_path, output_path, stats, part_rows=2000000,
|
||||
train_line_count=args.train_line_count, line_per_sample=1000,
|
||||
test_size=0.1, seed=2020, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
|
||||
train_line_count=config.train_line_count, line_per_sample=1000,
|
||||
test_size=0.1, seed=2020, dense_dim=config.dense_dim, slot_dim=config.slot_dim)
|
||||
|
|
|
@@ -15,7 +15,6 @@
|
|||
"""train_criteo."""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
|
@@ -25,46 +24,37 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
|
|||
from mindspore.common import set_seed
|
||||
|
||||
from src.deepfm import ModelBuilder, AUCMetric
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
from src.dataset import create_dataset, DataType
|
||||
from src.callback import EvalCallBack, LossCallBack
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_num
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parser = argparse.ArgumentParser(description='CTR Prediction')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
|
||||
parser.add_argument('--eval_file_name', type=str, default="./auc.log",
|
||||
help='Auc log file path. Default: "./auc.log"')
|
||||
parser.add_argument('--loss_file_name', type=str, default="./loss.log",
|
||||
help='Loss log file path. Default: "./loss.log"')
|
||||
parser.add_argument('--do_eval', type=str, default='True',
|
||||
help='Do evaluation or not, only support "True" or "False". Default: "True"')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
|
||||
help="device target, support Ascend, GPU and CPU.")
|
||||
args_opt, _ = parser.parse_known_args()
|
||||
args_opt.do_eval = args_opt.do_eval == 'True'
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
config.do_eval = config.do_eval == 'True'
|
||||
config.rank_size = get_device_num() # int(os.environ.get("RANK_SIZE", 1))
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_config = DataConfig()
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
if rank_size > 1:
|
||||
if args_opt.device_target == "Ascend":
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train_deepfm():
|
||||
""" train_deepfm """
|
||||
if config.rank_size > 1:
|
||||
if config.device_target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
all_reduce_fusion_config=[9, 11])
|
||||
init()
|
||||
rank_id = int(os.environ.get('RANK_ID'))
|
||||
elif args_opt.device_target == "GPU":
|
||||
elif config.device_target == "GPU":
|
||||
init()
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=args_opt.device_target)
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=config.device_target)
|
||||
context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(),
|
||||
|
@@ -72,62 +62,65 @@ if __name__ == '__main__':
|
|||
gradients_mean=True)
|
||||
rank_id = get_rank()
|
||||
else:
|
||||
print("Unsupported device_target ", args_opt.device_target)
|
||||
print("Unsupported device_target ", config.device_target)
|
||||
exit()
|
||||
else:
|
||||
if args_opt.device_target == "Ascend":
|
||||
if config.device_target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
|
||||
elif args_opt.device_target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=args_opt.device_target)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
|
||||
elif config.device_target == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=config.device_target)
|
||||
context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
rank_size = None
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
config.rank_size = None
|
||||
rank_id = None
|
||||
|
||||
ds_train = create_dataset(args_opt.dataset_path,
|
||||
ds_train = create_dataset(config.dataset_path,
|
||||
train_mode=True,
|
||||
epochs=1,
|
||||
batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format),
|
||||
rank_size=rank_size,
|
||||
batch_size=config.batch_size,
|
||||
data_type=DataType(config.data_format),
|
||||
rank_size=config.rank_size,
|
||||
rank_id=rank_id)
|
||||
|
||||
steps_size = ds_train.get_dataset_size()
|
||||
|
||||
if model_config.convert_dtype:
|
||||
model_config.convert_dtype = args_opt.device_target != "CPU"
|
||||
model_builder = ModelBuilder(model_config, train_config)
|
||||
if config.convert_dtype:
|
||||
config.convert_dtype = config.device_target != "CPU"
|
||||
model_builder = ModelBuilder(config, config)
|
||||
train_net, eval_net = model_builder.get_train_eval_net()
|
||||
auc_metric = AUCMetric()
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
|
||||
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
|
||||
loss_callback = LossCallBack(loss_file_path=config.loss_file_name)
|
||||
callback_list = [time_callback, loss_callback]
|
||||
|
||||
if train_config.save_checkpoint:
|
||||
if rank_size:
|
||||
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
|
||||
args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
|
||||
if args_opt.device_target != "Ascend":
|
||||
if config.save_checkpoint:
|
||||
if config.rank_size:
|
||||
config.ckpt_file_name_prefix = config.ckpt_file_name_prefix + str(get_rank())
|
||||
config.ckpt_path = os.path.join(config.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
|
||||
if config.device_target != "Ascend":
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
|
||||
keep_checkpoint_max=train_config.keep_checkpoint_max)
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
else:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=train_config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
|
||||
directory=args_opt.ckpt_path,
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix=config.ckpt_file_name_prefix,
|
||||
directory=config.ckpt_path,
|
||||
config=config_ck)
|
||||
callback_list.append(ckpt_cb)
|
||||
|
||||
if args_opt.do_eval:
|
||||
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
|
||||
if config.do_eval:
|
||||
ds_eval = create_dataset(config.dataset_path, train_mode=False,
|
||||
epochs=1,
|
||||
batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format))
|
||||
batch_size=config.batch_size,
|
||||
data_type=DataType(config.data_format))
|
||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
|
||||
eval_file_path=args_opt.eval_file_name)
|
||||
eval_file_path=config.eval_file_name)
|
||||
callback_list.append(eval_callback)
|
||||
model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
|
||||
model.train(config.train_epochs, ds_train, callbacks=callback_list)
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_deepfm()
|
||||
|
|
Loading…
Reference in New Issue