deepfmnet test case

chenged dataset(dataset in H5 format)
add tables labrary
This commit is contained in:
wsq3 2020-09-28 16:59:14 +08:00
parent d113e7a694
commit 9ebf8e2362
9 changed files with 1154 additions and 0 deletions

View File

@ -17,3 +17,4 @@ bs4
astunparse
packaging >= 20.0
pycocotools >= 2.0.0 # for st test
tables >= 3.6.1 # for st test

View File

@ -0,0 +1,233 @@
# coding:utf-8
import os
import pickle
import collections
import argparse
import numpy as np
import pandas as pd
TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135
class DataStatsDict():
def __init__(self):
self.field_size = 39 # value_1-13; cat_1-26;
self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]
#
self.val_min_dict = {col: 0 for col in self.val_cols}
self.val_max_dict = {col: 0 for col in self.val_cols}
self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}
#
self.oov_prefix = "OOV_"
self.cat2id_dict = {}
self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
self.cat2id_dict.update({self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)})
# { "val_1": , ..., "val_13": , "OOV_cat_1": , ..., "OOV_cat_26": }
def stats_vals(self, val_list):
assert len(val_list) == len(self.val_cols)
def map_max_min(i, val):
key = self.val_cols[i]
if val != "":
if float(val) > self.val_max_dict[key]:
self.val_max_dict[key] = float(val)
if float(val) < self.val_min_dict[key]:
self.val_min_dict[key] = float(val)
for i, val in enumerate(val_list):
map_max_min(i, val)
def stats_cats(self, cat_list):
assert len(cat_list) == len(self.cat_cols)
def map_cat_count(i, cat):
key = self.cat_cols[i]
self.cat_count_dict[key][cat] += 1
for i, cat in enumerate(cat_list):
map_cat_count(i, cat)
#
def save_dict(self, output_path, prefix=""):
with open(os.path.join(output_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
pickle.dump(self.val_max_dict, file_wrt)
with open(os.path.join(output_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
pickle.dump(self.val_min_dict, file_wrt)
with open(os.path.join(output_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
pickle.dump(self.cat_count_dict, file_wrt)
def load_dict(self, dict_path, prefix=""):
with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
self.val_max_dict = pickle.load(file_wrt)
with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
self.val_min_dict = pickle.load(file_wrt)
with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
self.cat_count_dict = pickle.load(file_wrt)
print("val_max_dict.items()[:50]: {}".format(list(self.val_max_dict.items())))
print("val_min_dict.items()[:50]: {}".format(list(self.val_min_dict.items())))
def get_cat2id(self, threshold=100):
for key, cat_count_d in self.cat_count_dict.items():
new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
for cat_str, _ in new_cat_count_d.items():
self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
# print("before_all_count: {}".format( before_all_count )) # before_all_count: 33762577
# print("after_all_count: {}".format( after_all_count )) # after_all_count: 184926
print("cat2id_dict.size: {}".format(len(self.cat2id_dict)))
print("cat2id_dict.items()[:50]: {}".format(list(self.cat2id_dict.items())[:50]))
def map_cat2id(self, values, cats):
def minmax_scale_value(i, val):
# min_v = float(self.val_min_dict[ "val_{}".format(i+1) ])
max_v = float(self.val_max_dict["val_{}".format(i + 1)])
# return ( float(val) - min_v ) * 1.0 / (max_v - min_v)
return float(val) * 1.0 / max_v
id_list = []
weight_list = []
for i, val in enumerate(values):
if val == "":
id_list.append(i)
weight_list.append(0)
else:
key = "val_{}".format(i + 1)
id_list.append(self.cat2id_dict[key])
weight_list.append(minmax_scale_value(i, float(val)))
for i, cat_str in enumerate(cats):
key = "cat_{}".format(i + 1) + "_" + cat_str
if key in self.cat2id_dict:
id_list.append(self.cat2id_dict[key])
else:
id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
weight_list.append(1.0)
return id_list, weight_list
def mkdir_path(file_path):
if not os.path.exists(file_path):
os.makedirs(file_path)
def statsdata(data_source_path, output_path, data_stats1):
with open(data_source_path, encoding="utf-8") as file_in:
errorline_list = []
count = 0
for line in file_in:
count += 1
line = line.strip("\n")
items = line.split("\t")
if len(items) != 40:
errorline_list.append(count)
print("line: {}".format(line))
continue
if count % 1000000 == 0:
print("Have handle {}w lines.".format(count // 10000))
values = items[1:14]
cats = items[14:]
assert len(values) == 13, "values.size {}".format(len(values))
assert len(cats) == 26, "cats.size {}".format(len(cats))
data_stats1.stats_vals(values)
data_stats1.stats_cats(cats)
data_stats1.save_dict(output_path)
def add_write(file_path, wrt_str):
with open(file_path, 'a', encoding="utf-8") as file_out:
file_out.write(wrt_str + "\n")
def random_split_trans2h5(input_file_path, output_path, data_stats2, part_rows=2000000, test_size=0.1, seed=2020):
test_size = int(TRAIN_LINE_COUNT * test_size)
all_indices = [i for i in range(TRAIN_LINE_COUNT)]
np.random.seed(seed)
np.random.shuffle(all_indices)
print("all_indices.size: {}".format(len(all_indices)))
test_indices_set = set(all_indices[: test_size])
print("test_indices_set.size: {}".format(len(test_indices_set)))
print("----------" * 10 + "\n" * 2)
train_feature_file_name = os.path.join(output_path, "train_input_part_{}.h5")
train_label_file_name = os.path.join(output_path, "train_output_part_{}.h5")
test_feature_file_name = os.path.join(output_path, "test_input_part_{}.h5")
test_label_file_name = os.path.join(output_path, "test_output_part_{}.h5")
train_feature_list = []
train_label_list = []
test_feature_list = []
test_label_list = []
with open(input_file_path, encoding="utf-8") as file_in:
count = 0
train_part_number = 0
test_part_number = 0
for i, line in enumerate(file_in):
count += 1
if count % 1000000 == 0:
print("Have handle {}w lines.".format(count // 10000))
line = line.strip("\n")
items = line.split("\t")
if len(items) != 40:
continue
label = float(items[0])
values = items[1:14]
cats = items[14:]
assert len(values) == 13, "values.size {}".format(len(values))
assert len(cats) == 26, "cats.size {}".format(len(cats))
ids, wts = data_stats2.map_cat2id(values, cats)
if i not in test_indices_set:
train_feature_list.append(ids + wts)
train_label_list.append(label)
else:
test_feature_list.append(ids + wts)
test_label_list.append(label)
if train_label_list and (len(train_label_list) % part_rows == 0):
pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number),
key="fixed")
pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number),
key="fixed")
train_feature_list = []
train_label_list = []
train_part_number += 1
if test_label_list and (len(test_label_list) % part_rows == 0):
pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number),
key="fixed")
pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number),
key="fixed")
test_feature_list = []
test_label_list = []
test_part_number += 1
if train_label_list:
pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number),
key="fixed")
pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number),
key="fixed")
if test_label_list:
pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number),
key="fixed")
pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number), key="fixed")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Get and Process datasets')
parser.add_argument('--base_path', default="/home/wushuquan/tmp/", help='The path to save dataset')
parser.add_argument('--output_path', default="/home/wushuquan/tmp/h5dataset/",
help='The path to save h5 dataset')
args, _ = parser.parse_known_args()
base_path = args.base_path
data_path = base_path + ""
# mkdir_path(data_path)
# if not os.path.exists(base_path + "dac.tar.gz"):
# os.system(
# "wget -P {} -c https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz --no-check-certificate".format(
# base_path))
os.system("tar -zxvf {}dac.tar.gz".format(data_path))
print("********tar end***********")
data_stats = DataStatsDict()
# step 1, stats the vocab and normalize value
data_file_path = "./train.txt"
stats_output_path = base_path + "stats_dict/"
mkdir_path(stats_output_path)
statsdata(data_file_path, stats_output_path, data_stats)
print("----------" * 10)
data_stats.load_dict(dict_path=stats_output_path, prefix="")
data_stats.get_cat2id(threshold=100)
# step 2, transform data trans2h5; version 2: np.random.shuffle
in_file_path = "./train.txt"
mkdir_path(args.output_path)
random_split_trans2h5(in_file_path, args.output_path, data_stats, part_rows=2000000, test_size=0.1, seed=2020)

View File

@ -0,0 +1,110 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Defined callback for DeepFM.
"""
import time
from mindspore.train.callback import Callback
def add_write(file_path, out_str):
with open(file_path, 'a+', encoding='utf-8') as file_out:
file_out.write(out_str + '\n')
class EvalCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note
If per_print_times is 0 do not print loss.
"""
def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
super(EvalCallBack, self).__init__()
self.model = model
self.eval_dataset = eval_dataset
self.aucMetric = auc_metric
self.aucMetric.clear()
self.eval_file_path = eval_file_path
def epoch_end(self, run_context):
start_time = time.time()
out = self.model.eval(self.eval_dataset)
eval_time = int(time.time() - start_time)
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
out_str = "{} EvalCallBack metric{}; eval_time{}s".format(
time_str, out.values(), eval_time)
print(out_str)
add_write(self.eval_file_path, out_str)
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note
If per_print_times is 0 do not print loss.
Args
loss_file_path (str) The file absolute path, to save as loss_file;
per_print_times (int) Print loss every times. Default 1.
"""
def __init__(self, loss_file_path, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self.loss_file_path = loss_file_path
self._per_print_times = per_print_times
self.loss = 0
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
with open(self.loss_file_path, "a+") as loss_file:
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
print("epoch: {} step: {}, loss is {}\n".format(
cb_params.cur_epoch_num, cur_step_in_epoch, loss))
self.loss = loss
class TimeMonitor(Callback):
"""
Time monitor for calculating cost of each epoch.
Args
data_size (int) step size of an epoch.
"""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
self.per_step_time = 0
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / self.data_size
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
self.per_step_time = per_step_mseconds
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
step_mseconds = (time.time() - self.step_time) * 1000
print(f"step time {step_mseconds}", flush=True)

View File

@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
class DataConfig:
"""
Define parameters of dataset.
"""
data_vocab_size = 184965
train_num_of_parts = 21
test_num_of_parts = 3
batch_size = 1000
data_field_size = 39
# dataset format, 1: mindrecord, 2: tfrecord, 3: h5
data_format = 3
class ModelConfig:
"""
Define parameters of model.
"""
batch_size = DataConfig.batch_size
data_field_size = DataConfig.data_field_size
data_vocab_size = DataConfig.data_vocab_size
data_emb_dim = 80
deep_layer_args = [[400, 400, 512], "relu"]
init_args = [-0.01, 0.01]
weight_bias_init = ['normal', 'normal']
keep_prob = 0.9
class TrainConfig:
"""
Define parameters of training.
"""
batch_size = DataConfig.batch_size
l2_coef = 1e-6
learning_rate = 1e-5
epsilon = 1e-8
loss_scale = 1024.0
train_epochs = 3
save_checkpoint = True
ckpt_file_name_prefix = "deepfm"
save_checkpoint_steps = 1
keep_checkpoint_max = 15
eval_callback = True
loss_callback = True

View File

@ -0,0 +1,298 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Create train or eval dataset.
"""
import os
import math
from enum import Enum
import pandas as pd
import numpy as np
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype
from .config import DataConfig
class DataType(Enum):
"""
Enumerate supported dataset format.
"""
MINDRECORD = 1
TFRECORD = 2
H5 = 3
class H5Dataset():
"""
Create dataset with H5 format.
Args:
data_path (str): Dataset directory.
train_mode (bool): Whether dataset is used for train or eval (default=True).
train_num_of_parts (int): The number of train data file (default=21).
test_num_of_parts (int): The number of test data file (default=3).
"""
max_length = 39
def __init__(self, data_path, train_mode=True,
train_num_of_parts=DataConfig.train_num_of_parts,
test_num_of_parts=DataConfig.test_num_of_parts):
self._hdf_data_dir = data_path
self._is_training = train_mode
if self._is_training:
self._file_prefix = 'train'
self._num_of_parts = train_num_of_parts
else:
self._file_prefix = 'test'
self._num_of_parts = test_num_of_parts
self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, self._num_of_parts)
print("data_size: {}".format(self.data_size))
def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts):
size = 0
for part in range(num_of_parts):
_y = pd.read_hdf(os.path.join(hdf_data_dir, f'{file_prefix}_output_part_{str(part)}.h5'))
size += _y.shape[0]
return size
def _iterate_hdf_files_(self, num_of_parts=None,
shuffle_block=False):
"""
iterate among hdf files(blocks). when the whole data set is finished, the iterator restarts
from the beginning, thus the data stream will never stop
:param train_mode: True or false,false is eval_mode,
this file iterator will go through the train set
:param num_of_parts: number of files
:param shuffle_block: shuffle block files at every round
:return: input_hdf_file_name, output_hdf_file_name, finish_flag
"""
parts = np.arange(num_of_parts)
while True:
if shuffle_block:
for _ in range(int(shuffle_block)):
np.random.shuffle(parts)
for i, p in enumerate(parts):
yield os.path.join(self._hdf_data_dir, f'{self._file_prefix}_input_part_{str(p)}.h5'), \
os.path.join(self._hdf_data_dir, f'{self._file_prefix}_output_part_{str(p)}.h5'), \
i + 1 == len(parts)
def _generator(self, X, y, batch_size, shuffle=True):
"""
should be accessed only in private
:param X:
:param y:
:param batch_size:
:param shuffle:
:return:
"""
number_of_batches = np.ceil(1. * X.shape[0] / batch_size)
counter = 0
finished = False
sample_index = np.arange(X.shape[0])
if shuffle:
for _ in range(int(shuffle)):
np.random.shuffle(sample_index)
assert X.shape[0] > 0
while True:
batch_index = sample_index[batch_size * counter: batch_size * (counter + 1)]
X_batch = X[batch_index]
y_batch = y[batch_index]
counter += 1
yield X_batch, y_batch, finished
if counter == number_of_batches:
counter = 0
finished = True
def batch_generator(self, batch_size=1000,
random_sample=False, shuffle_block=False):
"""
:param train_mode: True or false,false is eval_mode,
:param batch_size
:param num_of_parts: number of files
:param random_sample: if True, will shuffle
:param shuffle_block: shuffle file blocks at every round
:return:
"""
for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts,
shuffle_block):
start = stop = None
X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values
y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values
data_gen = self._generator(X_all, y_all, batch_size,
shuffle=random_sample)
finished = False
while not finished:
X, y, finished = data_gen.__next__()
X_id = X[:, 0:self.max_length]
X_va = X[:, self.max_length:]
yield np.array(X_id.astype(dtype=np.int32)), \
np.array(X_va.astype(dtype=np.float32)), \
np.array(y.astype(dtype=np.float32))
def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000):
"""
Get dataset with h5 format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000)
Returns:
Dataset.
"""
data_para = {'batch_size': batch_size}
if train_mode:
data_para['random_sample'] = True
data_para['shuffle_block'] = True
h5_dataset = H5Dataset(data_path=directory, train_mode=train_mode)
numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)
def _iter_h5_data():
train_eval_gen = h5_dataset.batch_generator(**data_para)
for _ in range(0, numbers_of_batch, 1):
yield train_eval_gen.__next__()
ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"], num_samples=3000)
ds = ds.repeat(epochs)
return ds
def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
Get dataset with mindrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
file_suffix_name = '00' if train_mode else '0'
shuffle = train_mode
if rank_size is not None and rank_id is not None:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
num_parallel_workers=8)
else:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
shuffle=shuffle, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
columns_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def _get_tf_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
Get dataset with tfrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
dataset_files = []
file_prefixt_name = 'train' if train_mode else 'test'
shuffle = train_mode
for (dir_path, _, filenames) in os.walk(directory):
for filename in filenames:
if file_prefixt_name in filename and 'tfrecord' in filename:
dataset_files.append(os.path.join(dir_path, filename))
schema = de.Schema()
schema.add_column('feat_ids', de_type=mstype.int32)
schema.add_column('feat_vals', de_type=mstype.float32)
schema.add_column('label', de_type=mstype.float32)
if rank_size is not None and rank_id is not None:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
schema=schema, num_parallel_workers=8,
num_shards=rank_size, shard_id=rank_id,
shard_equal_rows=True, num_samples=3000)
else:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
schema=schema, num_parallel_workers=8, num_samples=3000)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (
np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
column_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def create_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
data_type=DataType.TFRECORD, line_per_sample=1000,
rank_size=None, rank_id=None):
"""
Get dataset.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
data_type (DataType): The type of dataset which is one of H5, TFRECORE, MINDRECORD (default=TFRECORD).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
if data_type == DataType.MINDRECORD:
return _get_mindrecord_dataset(directory, train_mode, epochs,
batch_size, line_per_sample,
rank_size, rank_id)
if data_type == DataType.TFRECORD:
return _get_tf_dataset(directory, train_mode, epochs, batch_size,
line_per_sample, rank_size=rank_size, rank_id=rank_id)
if rank_size is not None and rank_size > 1:
raise ValueError('Please use mindrecord dataset.')
return _get_h5_dataset(directory, train_mode, epochs, batch_size)

View File

@ -0,0 +1,370 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_training """
import os
import numpy as np
from sklearn.metrics import roc_auc_score
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam
from mindspore.nn.metrics import Metric
from mindspore import nn, ParameterTuple, Parameter
from mindspore.common.initializer import Uniform, initializer, Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from .callback import EvalCallBack, LossCallBack
np_type = np.float32
ms_type = mstype.float32
class AUCMetric(Metric):
"""AUC metric for DeepFM model."""
def __init__(self):
super(AUCMetric, self).__init__()
self.pred_probs = []
self.true_labels = []
def clear(self):
"""Clear the internal evaluation result."""
self.pred_probs = []
self.true_labels = []
def update(self, *inputs):
batch_predict = inputs[1].asnumpy()
batch_label = inputs[2].asnumpy()
self.pred_probs.extend(batch_predict.flatten().tolist())
self.true_labels.extend(batch_label.flatten().tolist())
def eval(self):
if len(self.true_labels) != len(self.pred_probs):
raise RuntimeError('true_labels.size() is not equal to pred_probs.size()')
auc = roc_auc_score(self.true_labels, self.pred_probs)
return auc
def init_method(method, shape, name, max_val=0.01):
"""
The method of init parameters.
Args:
method (str): The method uses to initialize parameter.
shape (list): The shape of parameter.
name (str): The name of parameter.
max_val (float): Max value in parameter when uses 'random' or 'uniform' to initialize parameter.
Returns:
Parameter.
"""
if method in ['random', 'uniform']:
params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
elif method == "one":
params = Parameter(initializer("ones", shape, ms_type), name=name)
elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal":
params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
return params
def init_var_dict(init_args, values):
"""
Init parameter.
Args:
init_args (list): Define max and min value of parameters.
values (list): Define name, shape and init method of parameters.
Returns:
dict, a dict ot Parameter.
"""
var_map = {}
_, _max_val = init_args
for key, shape, init_flag in values:
if key not in var_map.keys():
if init_flag in ['random', 'uniform']:
var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key)
elif init_flag == "one":
var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
elif init_flag == "zero":
var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
elif init_flag == 'normal':
var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key)
return var_map
class DenseLayer(nn.Cell):
"""
Dense Layer for Deep Layer of DeepFM Model;
Containing: activation, matmul, bias_add;
Args:
input_dim (int): the shape of weight at 0-aixs;
output_dim (int): the shape of weight at 1-aixs, and shape of bias
weight_bias_init (list): weight and bias init method, "random", "uniform", "one", "zero", "normal";
act_str (str): activation function method, "relu", "sigmoid", "tanh";
keep_prob (float): Dropout Layer keep_prob_rate;
scale_coef (float): input scale coefficient;
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
self.bias = init_method(bias_init, [output_dim], name="bias")
self.act_func = self._init_activation(act_str)
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
self.dropout = Dropout(keep_prob=keep_prob)
self.mul = P.Mul()
self.realDiv = P.RealDiv()
self.scale_coef = scale_coef
def _init_activation(self, act_str):
act_str = act_str.lower()
if act_str == "relu":
act_func = P.ReLU()
elif act_str == "sigmoid":
act_func = P.Sigmoid()
elif act_str == "tanh":
act_func = P.Tanh()
return act_func
def construct(self, x):
x = self.act_func(x)
if self.training:
x = self.dropout(x)
x = self.mul(x, self.scale_coef)
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
wx = self.matmul(x, weight)
wx = self.cast(wx, mstype.float32)
wx = self.realDiv(wx, self.scale_coef)
output = self.bias_add(wx, self.bias)
return output
class DeepFMModel(nn.Cell):
"""
From paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction"
Args:
batch_size (int): smaple_number of per step in training; (int, batch_size=128)
filed_size (int): input filed number, or called id_feature number; (int, filed_size=39)
vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000)
emb_dim (int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100)
deep_layer_args (list): Deep Layer args, layer_dim_list, layer_activator;
(int, deep_layer_args=[[100, 100, 100], "relu"])
init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds])
weight_bias_init (list): weight, bias init method for deep layers;
(list[str], weight_bias_init=['random', 'zero'])
keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
"""
def __init__(self, config):
super(DeepFMModel, self).__init__()
self.batch_size = config.batch_size
self.field_size = config.data_field_size
self.vocab_size = config.data_vocab_size
self.emb_dim = config.data_emb_dim
self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args
self.init_args = config.init_args
self.weight_bias_init = config.weight_bias_init
self.keep_prob = config.keep_prob
init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
('V_l2', [self.vocab_size, self.emb_dim], 'normal'),
('b', [1], 'normal')]
var_map = init_var_dict(self.init_args, init_acts)
self.fm_w = var_map["W_l2"]
self.fm_b = var_map["b"]
self.embedding_table = var_map["V_l2"]
# Deep Layers
self.deep_input_dims = self.field_size * self.emb_dim + 1
self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
# FM, linear Layers
self.Gatherv2 = P.GatherV2()
self.Mul = P.Mul()
self.ReduceSum = P.ReduceSum(keep_dims=False)
self.Reshape = P.Reshape()
self.Square = P.Square()
self.Shape = P.Shape()
self.Tile = P.Tile()
self.Concat = P.Concat(axis=1)
self.Cast = P.Cast()
def construct(self, id_hldr, wt_hldr):
"""
Args:
id_hldr: batch ids; [bs, field_size]
wt_hldr: batch weights; [bs, field_size]
"""
mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Linear layer
fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0)
wx = self.Mul(fm_id_weight, mask)
linear_out = self.ReduceSum(wx, 1)
# FM layer
fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0)
vx = self.Mul(fm_id_embs, mask)
v1 = self.ReduceSum(vx, 1)
v1 = self.Square(v1)
v2 = self.Square(vx)
v2 = self.ReduceSum(v2, 1)
fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
fm_out = self.Reshape(fm_out, (-1, 1))
# Deep layer
b = self.Reshape(self.fm_b, (1, 1))
b = self.Tile(b, (self.batch_size, 1))
deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.Concat((deep_in, b))
deep_in = self.dense_layer_1(deep_in)
deep_in = self.dense_layer_2(deep_in)
deep_in = self.dense_layer_3(deep_in)
deep_out = self.dense_layer_4(deep_in)
out = linear_out + fm_out + deep_out
return out, fm_id_weight, fm_id_embs
class NetWithLossClass(nn.Cell):
"""
NetWithLossClass definition.
"""
def __init__(self, network, l2_coef=1e-6):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.loss = P.SigmoidCrossEntropyWithLogits()
self.network = network
self.l2_coef = l2_coef
self.Square = P.Square()
self.ReduceMean_false = P.ReduceMean(keep_dims=False)
self.ReduceSum_false = P.ReduceSum(keep_dims=False)
def construct(self, batch_ids, batch_wts, label):
predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts)
log_loss = self.loss(predict, label)
mean_log_loss = self.ReduceMean_false(log_loss)
l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight))
l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs))
l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5
loss = mean_log_loss + l2_loss_all
return loss
class TrainStepWrap(nn.Cell):
"""
TrainStepWrap definition
"""
def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = loss_scale
def construct(self, batch_ids, batch_wts, label):
weights = self.weights
loss = self.network(batch_ids, batch_wts, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
return F.depend(loss, self.optimizer(grads))
class PredictWithSigmoid(nn.Cell):
"""
Eval model with sigmoid.
"""
def __init__(self, network):
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
self.network = network
self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels):
logits, _, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels
class ModelBuilder:
"""
Model builder for DeepFM.
Args:
model_config (ModelConfig): Model configuration.
train_config (TrainConfig): Train configuration.
"""
def __init__(self, model_config, train_config):
self.model_config = model_config
self.train_config = train_config
def get_callback_list(self, model=None, eval_dataset=None):
"""
Get callbacks which contains checkpoint callback, eval callback and loss callback.
Args:
model (Cell): The network is added callback (default=None).
eval_dataset (Dataset): Dataset for eval (default=None).
"""
callback_list = []
if self.train_config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps,
keep_checkpoint_max=self.train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix,
directory=self.train_config.output_path,
config=config_ck)
callback_list.append(ckpt_cb)
if self.train_config.eval_callback:
if model is None:
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format(
self.train_config.eval_callback, model))
if eval_dataset is None:
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() "
"args eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset))
auc_metric = AUCMetric()
eval_callback = EvalCallBack(model, eval_dataset, auc_metric,
eval_file_path=os.path.join(self.train_config.output_path,
self.train_config.eval_file_name))
callback_list.append(eval_callback)
if self.train_config.loss_callback:
loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path,
self.train_config.loss_file_name))
callback_list.append(loss_callback)
if callback_list:
return callback_list
return None
def get_train_eval_net(self):
deepfm_net = DeepFMModel(self.model_config)
loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef)
train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate,
eps=self.train_config.epsilon,
loss_scale=self.train_config.loss_scale)
eval_net = PredictWithSigmoid(deepfm_net)
return train_net, eval_net

View File

@ -0,0 +1,80 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_criteo."""
import os
import pytest
from mindspore import context
from mindspore.train.model import Model
from mindspore.common import set_seed
from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset, DataType
from src.callback import EvalCallBack, LossCallBack, TimeMonitor
set_seed(1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_deepfm():
data_config = DataConfig()
train_config = TrainConfig()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
rank_size = None
rank_id = None
dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/"
print("dataset_path:", dataset_path)
ds_train = create_dataset(dataset_path,
train_mode=True,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format),
rank_size=rank_size,
rank_id=rank_id)
model_builder = ModelBuilder(ModelConfig, TrainConfig)
train_net, eval_net = model_builder.get_train_eval_net()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
loss_file_name = './loss.log'
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
loss_callback = LossCallBack(loss_file_path=loss_file_name)
callback_list = [time_callback, loss_callback]
eval_file_name = './auc.log'
ds_eval = create_dataset(dataset_path, train_mode=False,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
eval_file_path=eval_file_name)
callback_list.append(eval_callback)
print("train_config.train_epochs:", train_config.train_epochs)
model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
export_loss_value = 0.51
print("loss_callback.loss:", loss_callback.loss)
assert loss_callback.loss < export_loss_value
export_per_step_time = 10.4
print("time_callback:", time_callback.per_step_time)
assert time_callback.per_step_time < export_per_step_time
print("*******test case pass!********")