huchunmei 2021-06-18 15:38:26 +08:00
parent 0b7f3ebd89
commit e3f07763c8
17 changed files with 529 additions and 212 deletions

View File

@@ -0,0 +1,106 @@
# Built-in configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: Ascend
enable_profiling: False
# ==============================================================================
#"""data config"""
data_vocab_size: 184965
train_num_of_parts: 21
test_num_of_parts: 3
batch_size: 16000
data_field_size: 39
data_format: 1
#"""model config"""
data_emb_dim: 80
deep_layer_args: [[1024, 512, 256, 128], "relu"]
init_args: [-0.01, 0.01]
weight_bias_init: ['normal', 'normal']
keep_prob: 0.9
convert_dtype: True
# """train config"""
l2_coef: 0.00008 # 8e-5
learning_rate: 0.0005 # 5e-4
epsilon: 0.00000005 # 5e-8
loss_scale: 1024.0
train_epochs: 5
save_checkpoint: True
ckpt_file_name_prefix: "deepfm"
save_checkpoint_steps: 1
keep_checkpoint_max: 50
eval_callback: True
loss_callback: True
# train.py 'CTR Prediction'
dataset_path: "/cache/data"
ckpt_path: "/cache/train"
eval_file_name: "./auc.log"
loss_file_name: "./loss.log"
do_eval: 'True'
# eval.py 'CTR Prediction'
checkpoint_path: "/cache/train/deepfm-5_2582.ckpt"
# export.py "deepfm export"
device_id: 0
ckpt_file: "/cache/train/deepfm-5_2582.ckpt"
file_name: "deepfm"
file_format: "AIR"
# 'preprocess.'
result_path: './preprocess_Result'
# 'postprocess'
# result_path: "./result_Files"
label_path: ''
# src/preprocess_data.py "Recommendation dataset"
# data_path: "./recommendation_dataset/"
dense_dim: 13
slot_dim: 26
threshold: 100
train_line_count: 45840617
skip_id_convert: 0
---
# Config description for each option
enable_modelarts: 'Whether to run training on ModelArts. Default: False'
data_url: 'Dataset URL on OBS'
train_url: 'Training output URL on OBS'
data_path: 'Local dataset path'
output_path: 'Local training output path'
device_target: "Device target; supports Ascend, GPU and CPU."
dataset_path: 'Dataset path'
batch_size: "batch size"
ckpt_path: 'Checkpoint path'
eval_file_name: 'AUC log file path. Default: "./auc.log"'
loss_file_name: 'Loss log file path. Default: "./loss.log"'
do_eval: 'Do evaluation or not, only support "True" or "False". Default: "True"'
checkpoint_path: 'Checkpoint file path'
device_id: "Device id"
ckpt_file: "Checkpoint file path."
file_name: "output file name."
file_format: "file format"
result_path: 'Result path'
# result_path: "./result_Files" # 'result path'
label_path: 'label path'
dense_dim: 'The number of continuous (dense) fields'
slot_dim: 'The number of sparse fields, also called category features.'
threshold: 'Words with frequency below this threshold are treated as OOV; this reduces the vocab size'
train_line_count: 'The number of examples in your dataset'
skip_id_convert: 'Skip the id conversion and treat the original id as the final id.'
---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ["AIR", "ONNX", "MINDIR"]
freeze_layer: ["", "none", "backbone"]
skip_id_convert: [0, 1]
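The three `---`-separated documents above are consumed in order by the new src/model_utils/config.py added in this commit: default values, per-option help text, and allowed choices. A minimal sketch of that split, assuming PyYAML is installed and this file is saved locally as default_config.yaml:

import yaml

with open("default_config.yaml", "r") as fin:
    # load_all yields one object per `---`-separated YAML document
    cfg, cfg_helper, cfg_choices = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))

print(cfg["batch_size"])           # 16000
print(cfg_helper["do_eval"])       # the help string attached to --do_eval
print(cfg_choices["file_format"])  # ['AIR', 'ONNX', 'MINDIR']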

View File

@@ -16,49 +16,45 @@
import os
import sys
import time
import argparse
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset, DataType
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
help="device target, support Ascend, GPU and CPU.")
args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
device_id = get_device_id() # int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
def add_write(file_path, print_str):
with open(file_path, 'a+', encoding='utf-8') as file_out:
file_out.write(print_str + '\n')
if __name__ == '__main__':
data_config = DataConfig()
model_config = ModelConfig()
train_config = TrainConfig()
def modelarts_process():
pass
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=1, batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
if model_config.convert_dtype:
model_config.convert_dtype = args_opt.device_target != "CPU"
model_builder = ModelBuilder(model_config, train_config)
@moxing_wrapper(pre_process=modelarts_process)
def eval_deepfm():
""" eval_deepfm """
ds_eval = create_dataset(config.dataset_path, train_mode=False,
epochs=1, batch_size=config.batch_size,
data_type=DataType(config.data_format))
if config.convert_dtype:
config.convert_dtype = config.device_target != "CPU"
model_builder = ModelBuilder(config, config)
train_net, eval_net = model_builder.get_train_eval_net()
train_net.set_train()
eval_net.set_train(False)
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
param_dict = load_checkpoint(args_opt.checkpoint_path)
param_dict = load_checkpoint(config.checkpoint_path)
load_param_into_net(eval_net, param_dict)
start = time.time()
@@ -68,3 +64,6 @@ if __name__ == '__main__':
out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
print(out_str)
add_write('./auc.log', str(out_str))
if __name__ == '__main__':
eval_deepfm()
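Although argparse is gone from this file, the old flags still work: parse_cli_to_yaml in src/model_utils/config.py (added below) auto-generates one command-line option per scalar YAML key. A sketch of the equivalence, assuming the package root is on sys.path:

# old entry point: python eval.py --checkpoint_path=/path/to.ckpt --dataset_path=/path/to/data
# new entry point: the identical command line still parses; values land on the shared namespace
from src.model_utils.config import config

print(config.checkpoint_path)  # "/cache/train/deepfm-5_2582.ckpt" unless overridden on the CLI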

View File

@@ -13,41 +13,32 @@
# limitations under the License.
# ============================================================================
"""export ckpt to model"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint
from src.deepfm import ModelBuilder
from src.config import DataConfig, ModelConfig, TrainConfig
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
parser = argparse.ArgumentParser(description="deepfm export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=16000, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="deepfm", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
if __name__ == "__main__":
data_config = DataConfig()
model_builder = ModelBuilder(ModelConfig, TrainConfig)
model_builder = ModelBuilder(config, config)
_, network = model_builder.get_train_eval_net()
network.set_train(False)
load_checkpoint(args.ckpt_file, net=network)
load_checkpoint(config.ckpt_file, net=network)
batch_ids = Tensor(np.zeros([data_config.batch_size, data_config.data_field_size]).astype(np.int32))
batch_wts = Tensor(np.zeros([data_config.batch_size, data_config.data_field_size]).astype(np.float32))
labels = Tensor(np.zeros([data_config.batch_size, 1]).astype(np.float32))
batch_ids = Tensor(np.zeros([config.batch_size, config.data_field_size]).astype(np.int32))
batch_wts = Tensor(np.zeros([config.batch_size, config.data_field_size]).astype(np.float32))
labels = Tensor(np.zeros([config.batch_size, 1]).astype(np.float32))
input_data = [batch_ids, batch_wts, labels]
export(network, *input_data, file_name=args.file_name, file_format=args.file_format)
export(network, *input_data, file_name=config.file_name, file_format=config.file_format)
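export() only traces shapes and dtypes through the network, so zero-filled tensors are sufficient placeholders; with the YAML defaults the traced input signature would be:

# batch_ids: int32   [16000, 39]   (batch_size x data_field_size)
# batch_wts: float32 [16000, 39]
# labels:    float32 [16000, 1]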

View File

@@ -14,13 +14,11 @@
# ============================================================================
"""hub config."""
from src.deepfm import ModelBuilder
from src.config import ModelConfig, TrainConfig
from src.model_utils.config import config
def create_network(name, *args, **kwargs):
if name == 'deepfm':
model_config = ModelConfig()
train_config = TrainConfig()
model_builder = ModelBuilder(model_config, train_config)
model_builder = ModelBuilder(config, config)
_, deepfm_eval_net = model_builder.get_train_eval_net()
return deepfm_eval_net
raise NotImplementedError(f"{name} is not implemented in the repo")
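A quick caller-side sketch (hypothetical snippet, with create_network from this file in scope):

net = create_network('deepfm')  # eval network built from the shared YAML-backed config
# any other name raises NotImplementedError(f"{name} is not implemented in the repo")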

View File

@@ -14,27 +14,21 @@
# ============================================================================
"""postprocess."""
import os
import argparse
import numpy as np
from mindspore import Tensor
from src.deepfm import AUCMetric
from src.config import TrainConfig
from src.model_utils.config import config
parser = argparse.ArgumentParser(description='postprocess')
parser.add_argument('--result_path', type=str, default="./result_Files", help='result path')
parser.add_argument('--label_path', type=str, default=None, help='label path')
args_opt, _ = parser.parse_known_args()
def get_acc():
''' get accuracy '''
auc_metric = AUCMetric()
train_config = TrainConfig()
files = os.listdir(args_opt.label_path)
batch_size = train_config.batch_size
files = os.listdir(config.label_path)
batch_size = config.batch_size
for f in files:
rst_file = os.path.join(args_opt.result_path, f.split('.')[0] + '_0.bin')
label_file = os.path.join(args_opt.label_path, f)
rst_file = os.path.join(config.result_path, f.split('.')[0] + '_0.bin')
label_file = os.path.join(config.label_path, f)
logit = Tensor(np.fromfile(rst_file, np.float32).reshape(batch_size, 1))
label = Tensor(np.fromfile(label_file, np.float32).reshape(batch_size, 1))

View File

@@ -14,34 +14,27 @@
# ============================================================================
"""preprocess."""
import os
import argparse
from src.config import DataConfig, TrainConfig
from src.dataset import create_dataset, DataType
from src.model_utils.config import config
parser = argparse.ArgumentParser(description='preprocess.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--result_path', type=str, default='./preprocess_Result', help='Result path')
args_opt, _ = parser.parse_known_args()
def generate_bin():
'''generate bin files'''
data_config = DataConfig()
train_config = TrainConfig()
ds = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=1, batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
batch_ids_path = os.path.join(args_opt.result_path, "00_batch_ids")
batch_wts_path = os.path.join(args_opt.result_path, "01_batch_wts")
labels_path = os.path.join(args_opt.result_path, "02_labels")
ds = create_dataset(config.dataset_path, train_mode=False,
epochs=1, batch_size=config.batch_size,
data_type=DataType(config.data_format))
batch_ids_path = os.path.join(config.result_path, "00_batch_ids")
batch_wts_path = os.path.join(config.result_path, "01_batch_wts")
labels_path = os.path.join(config.result_path, "02_labels")
os.makedirs(batch_ids_path)
os.makedirs(batch_wts_path)
os.makedirs(labels_path)
for i, data in enumerate(ds.create_dict_iterator(output_numpy=True)):
file_name = "criteo_bs" + str(train_config.batch_size) + "_" + str(i) + ".bin"
file_name = "criteo_bs" + str(config.batch_size) + "_" + str(i) + ".bin"
batch_ids = data['feat_ids']
batch_ids.tofile(os.path.join(batch_ids_path, file_name))
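The .bin files written here are the contract with postprocess.py above: raw arrays with no header, so the reader must supply dtype and shape. A self-contained round-trip sketch of that convention (shapes taken from the YAML defaults):

import numpy as np

batch_size, field_size = 16000, 39                     # batch_size, data_field_size
ids = np.zeros([batch_size, field_size], np.int32)
ids.tofile("criteo_bs16000_0.bin")                     # flat dump, no shape metadata
back = np.fromfile("criteo_bs16000_0.bin", np.int32).reshape(batch_size, field_size)
assert (back == ids).all()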

View File

@@ -30,6 +30,7 @@ do
rm -rf log$i
mkdir ./log$i
cp *.py ./log$i
cp *.yaml ./log$i
cp -r src ./log$i
cd ./log$i || exit
echo "start training for rank $i, device $DEVICE_ID"

View File

@@ -25,6 +25,7 @@ DATA_URL=$2
rm -rf log
mkdir ./log
cp *.py ./log
cp *.yaml ./log
cp -r src ./log
cd ./log || exit
env > env.log

View File

@@ -1,56 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config settings, used in train.py and eval.py
"""
class DataConfig:
"""data config"""
data_vocab_size = 184965
train_num_of_parts = 21
test_num_of_parts = 3
batch_size = 16000
data_field_size = 39
data_format = 1
class ModelConfig:
"""model config"""
batch_size = DataConfig.batch_size
data_field_size = DataConfig.data_field_size
data_vocab_size = DataConfig.data_vocab_size
data_emb_dim = 80
deep_layer_args = [[1024, 512, 256, 128], "relu"]
init_args = [-0.01, 0.01]
weight_bias_init = ['normal', 'normal']
keep_prob = 0.9
convert_dtype = True
class TrainConfig:
"""train config"""
batch_size = DataConfig.batch_size
l2_coef = 8e-5
learning_rate = 5e-4
epsilon = 5e-8
loss_scale = 1024.0
train_epochs = 5
save_checkpoint = True
ckpt_file_name_prefix = "deepfm"
save_checkpoint_steps = 1
keep_checkpoint_max = 50
eval_callback = True
loss_callback = True

View File

@@ -23,8 +23,7 @@ import numpy as np
import pandas as pd
import mindspore.dataset as ds
import mindspore.common.dtype as mstype
from .config import DataConfig
from .model_utils.config import config
class DataType(Enum):
@@ -49,8 +48,8 @@ class H5Dataset():
max_length = 39
def __init__(self, data_path, train_mode=True,
train_num_of_parts=DataConfig.train_num_of_parts,
test_num_of_parts=DataConfig.test_num_of_parts):
train_num_of_parts=config.train_num_of_parts,
test_num_of_parts=config.test_num_of_parts):
self._hdf_data_dir = data_path
self._is_training = train_mode
if self._is_training:

View File

@@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
choices: Valid choices for certain options.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please refer to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except yaml.YAMLError as err:
raise ValueError("Failed to parse yaml") from err
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
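Every scalar key in the YAML therefore becomes a command-line flag automatically, with booleans parsed via ast.literal_eval and choices enforced from the third YAML document. A hypothetical override session:

# python train.py --config_path=default_config.yaml --batch_size=8000 --enable_modelarts=True
from src.model_utils.config import config

print(config.batch_size)     # 8000 after the override above; the YAML default is 16000
print(config.device_target)  # must be one of ['Ascend', 'GPU', 'CPU']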

View File

@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID', '')  # default to '' so an unset variable also falls through to "default"
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote OBS to a local directory when from_path is a remote url and to_path is local;
conversely, upload a local directory to remote OBS.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
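train.py and eval.py wrap their entry functions with this decorator; a minimal standalone sketch of the control flow when enable_modelarts and enable_profiling are both False (so the OBS sync and profiler branches are skipped):

from src.model_utils.moxing_adapter import moxing_wrapper

def prepare():
    print("pre_process hook: runs before the training body")

@moxing_wrapper(pre_process=prepare)
def main():
    print("training body: runs between the (skipped) download and upload phases")

main()  # prints the pre_process line, then the body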

View File

@@ -16,10 +16,9 @@
import os
import pickle
import collections
import argparse
import numpy as np
from mindspore.mindrecord import FileWriter
from .model_utils.config import config
class StatsDict():
"""preprocessed data"""
@@ -259,35 +258,22 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, recommendat
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Recommendation dataset")
parser.add_argument("--data_path", type=str, default="./recommendation_dataset/", help='The path of the data file')
parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continues fields')
parser.add_argument("--slot_dim", type=int, default=26,
help='The number of your sparse fields, it can also be called catelogy features.')
parser.add_argument("--threshold", type=int, default=100,
help='Word frequency below this will be regarded as OOV. It aims to reduce the vocab size')
parser.add_argument("--train_line_count", type=int, default=45840617,
help='The number of examples in your dataset')
parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1],
help='Skip the id convert, regarding the original id as the final id.')
data_path = config.data_path
args, _ = parser.parse_known_args()
data_path = args.data_path
target_field_size = args.dense_dim + args.slot_dim
stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
skip_id_convert=args.skip_id_convert)
target_field_size = config.dense_dim + config.slot_dim
stats = StatsDict(field_size=target_field_size, dense_dim=config.dense_dim, slot_dim=config.slot_dim,
skip_id_convert=config.skip_id_convert)
data_file_path = data_path + "origin_data/train.txt"
stats_output_path = data_path + "stats_dict/"
mkdir_path(stats_output_path)
statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
statsdata(data_file_path, stats_output_path, stats, dense_dim=config.dense_dim, slot_dim=config.slot_dim)
stats.load_dict(dict_path=stats_output_path, prefix="")
stats.get_cat2id(threshold=args.threshold)
stats.get_cat2id(threshold=config.threshold)
in_file_path = data_path + "origin_data/train.txt"
output_path = data_path + "mindrecord/"
mkdir_path(output_path)
random_split_trans2mindrecord(in_file_path, output_path, stats, part_rows=2000000,
train_line_count=args.train_line_count, line_per_sample=1000,
test_size=0.1, seed=2020, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
train_line_count=config.train_line_count, line_per_sample=1000,
test_size=0.1, seed=2020, dense_dim=config.dense_dim, slot_dim=config.slot_dim)
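The script derives its whole on-disk layout from config.data_path; with the default "./recommendation_dataset/" the expected tree is:

# recommendation_dataset/
#   origin_data/train.txt   <- raw Criteo input (train_line_count=45840617 lines)
#   stats_dict/             <- created here; per-field frequency dictionaries
#   mindrecord/             <- created here; train/test MindRecord parts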

View File

@@ -15,7 +15,6 @@
"""train_criteo."""
import os
import sys
import argparse
from mindspore import context
from mindspore.context import ParallelMode
@@ -25,46 +24,37 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
from mindspore.common import set_seed
from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset, DataType
from src.callback import EvalCallBack, LossCallBack
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_num
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
parser.add_argument('--eval_file_name', type=str, default="./auc.log",
help='Auc log file path. Default: "./auc.log"')
parser.add_argument('--loss_file_name', type=str, default="./loss.log",
help='Loss log file path. Default: "./loss.log"')
parser.add_argument('--do_eval', type=str, default='True',
help='Do evaluation or not, only support "True" or "False". Default: "True"')
parser.add_argument('--device_target', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
help="device target, support Ascend, GPU and CPU.")
args_opt, _ = parser.parse_known_args()
args_opt.do_eval = args_opt.do_eval == 'True'
rank_size = int(os.environ.get("RANK_SIZE", 1))
config.do_eval = config.do_eval == 'True'
config.rank_size = get_device_num() # int(os.environ.get("RANK_SIZE", 1))
set_seed(1)
if __name__ == '__main__':
data_config = DataConfig()
model_config = ModelConfig()
train_config = TrainConfig()
def modelarts_pre_process():
pass
if rank_size > 1:
if args_opt.device_target == "Ascend":
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_deepfm():
""" train_deepfm """
if config.rank_size > 1:
if config.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
all_reduce_fusion_config=[9, 11])
init()
rank_id = int(os.environ.get('RANK_ID'))
elif args_opt.device_target == "GPU":
elif config.device_target == "GPU":
init()
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=args_opt.device_target)
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=config.device_target)
context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=get_group_size(),
@@ -72,62 +62,65 @@ if __name__ == '__main__':
gradients_mean=True)
rank_id = get_rank()
else:
print("Unsupported device_target ", args_opt.device_target)
print("Unsupported device_target ", config.device_target)
exit()
else:
if args_opt.device_target == "Ascend":
if config.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=args_opt.device_target)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
elif config.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=config.device_target)
context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
else:
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
rank_size = None
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
config.rank_size = None
rank_id = None
ds_train = create_dataset(args_opt.dataset_path,
ds_train = create_dataset(config.dataset_path,
train_mode=True,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format),
rank_size=rank_size,
batch_size=config.batch_size,
data_type=DataType(config.data_format),
rank_size=config.rank_size,
rank_id=rank_id)
steps_size = ds_train.get_dataset_size()
if model_config.convert_dtype:
model_config.convert_dtype = args_opt.device_target != "CPU"
model_builder = ModelBuilder(model_config, train_config)
if config.convert_dtype:
config.convert_dtype = config.device_target != "CPU"
model_builder = ModelBuilder(config, config)
train_net, eval_net = model_builder.get_train_eval_net()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
loss_callback = LossCallBack(loss_file_path=config.loss_file_name)
callback_list = [time_callback, loss_callback]
if train_config.save_checkpoint:
if rank_size:
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
if args_opt.device_target != "Ascend":
if config.save_checkpoint:
if config.rank_size:
config.ckpt_file_name_prefix = config.ckpt_file_name_prefix + str(get_rank())
config.ckpt_path = os.path.join(config.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
if config.device_target != "Ascend":
config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
keep_checkpoint_max=train_config.keep_checkpoint_max)
keep_checkpoint_max=config.keep_checkpoint_max)
else:
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
keep_checkpoint_max=train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
directory=args_opt.ckpt_path,
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=config.ckpt_file_name_prefix,
directory=config.ckpt_path,
config=config_ck)
callback_list.append(ckpt_cb)
if args_opt.do_eval:
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
if config.do_eval:
ds_eval = create_dataset(config.dataset_path, train_mode=False,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
batch_size=config.batch_size,
data_type=DataType(config.data_format))
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
eval_file_path=args_opt.eval_file_name)
eval_file_path=config.eval_file_name)
callback_list.append(eval_callback)
model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
model.train(config.train_epochs, ds_train, callbacks=callback_list)
if __name__ == '__main__':
train_deepfm()
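Under distributed training each rank gets its own checkpoint prefix and directory (see the save_checkpoint branch above); a sketch of the resulting naming, with a hypothetical rank value:

import os

rank = 3                                                  # illustrative rank id
prefix = "deepfm" + str(rank)                             # -> "deepfm3"
ckpt_dir = os.path.join("/cache/train", "ckpt_" + str(rank) + "/")
print(ckpt_dir, prefix)  # /cache/train/ckpt_3/ deepfm3; files like deepfm3-5_2582.ckpt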