!3931 Generating the synthetic data for the Wide&Deep model

Merge pull request !3931 from huangxinjing/synthetic_data
mindspore-ci-bot 2020-08-07 14:14:10 +08:00 committed by Gitee
commit 5cd4c3eb09
3 changed files with 199 additions and 45 deletions

README.md

@@ -10,14 +10,47 @@ WideDeep model jointly trained wide linear models and deep neural network, which
- Download the dataset and convert it to MindRecord with the following command:
```
python src/preprocess_data.py --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
```
Arguments:
* `--data_path`: The path of the data file. (Default: ./criteo_data/)
* `--data_type` {criteo,synthetic}: Currently we support the criteo dataset and a synthetic dataset. (Default: criteo)
* `--dense_dim`: The number of your continuous fields.
* `--slot_dim`: The number of your sparse fields, also called category features.
* `--threshold`: Words whose frequency is below this value are regarded as OOV; this reduces the vocab size (a minimal sketch of the idea follows this list).
* `--train_line_count`: The number of examples in your dataset.
* `--skip_id_convert`: 0 or 1. If set to 1, the id conversion is skipped and the original id is used as the final id.
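The `--threshold` behaviour can be pictured with a few lines of standalone Python. This is only a sketch of the frequency-threshold idea, not the actual `get_cat2id` implementation; the toy counts and variable names are illustrative.
```
from collections import Counter

# Toy category counts for one sparse field (illustrative only).
cat_counts = Counter(["cat_1_a", "cat_1_a", "cat_1_a", "cat_1_b"])
threshold = 2

cat2id = {}
oov_id = 0  # a single OOV bucket here for brevity; the real code keeps one per field
for cat, count in cat_counts.items():
    if count >= threshold:
        cat2id[cat] = len(cat2id) + 1  # frequent categories keep their own id
    # rare categories are dropped and later mapped to the OOV id

print(cat2id)                          # {'cat_1_a': 1}
print(cat2id.get("cat_1_b", oov_id))   # 0 -> OOV, shrinking the vocab
```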
## Dataset
The commonly used benchmark datasets are used for model training and evaluation.
### Generate the Synthetic Data
The following command will generate 40 million lines of click data, in the format of "label\tdense_feature[0]\tdense_feature[1]...\tsparse_feature[0]\tsparse_feature[1]...".
```
mkdir -p syn_data/origin_data
python src/generate_synthetic_data.py --output_file=syn_data/origin_data/train.txt --number_examples=40000000 --dense_dim=13 --slot_dim=51 --vocabulary_size=2000000000 --random_slot_values=0
```
Arguments:
* `--output_file`: The output path of the generated file.
* `--label_dim`: The number of label categories.
* `--number_examples`: The number of rows in the generated file.
* `--dense_dim`: The number of continuous features.
* `--slot_dim`: The number of category features.
* `--vocabulary_size`: The vocabulary size of the total dataset.
* `--random_slot_values`: 0 or 1. If 1, the slot ids are generated randomly. If 0, the id is set to row_index mod part_size, where part_size is the vocab size for each slot (see the sketch after this list).
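As a quick reference for the deterministic case (`--random_slot_values=0`), the sketch below mirrors the id computation used by `src/generate_synthetic_data.py`; the dimensions are taken from the command above.
```
# Deterministic slot ids for row i when --random_slot_values=0.
dense_dim, slot_dim, vocabulary_size = 13, 51, 2_000_000_000
part_size = (vocabulary_size - dense_dim) // slot_dim  # vocab share of each slot

def slot_ids(i):
    sp = i % part_size
    return [dense_dim + min(sp + j * part_size, vocabulary_size - dense_dim - 1)
            for j in range(slot_dim)]

print(slot_ids(0)[:3])  # the first three slot ids of row 0
```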
Then preprocess the generated data:
```
python src/preprocess_data.py --data_path=./syn_data/ --data_type=synthetic --dense_dim=13 --slot_dim=51 --threshold=0 --train_line_count=40000000 --skip_id_convert=1
```
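Before converting, it can help to sanity-check the generated file: every line should be tab-delimited with `1 + dense_dim + slot_dim` fields, which is exactly what `preprocess_data.py` expects. The optional check below assumes the file path and dimensions from the commands above.
```
# Optional sanity check of the synthetic file produced above.
dense_dim, slot_dim = 13, 51
expected_fields = 1 + dense_dim + slot_dim  # label + dense + sparse

with open("syn_data/origin_data/train.txt", encoding="utf-8") as f:
    for line_no, line in enumerate(f, 1):
        fields = line.rstrip("\n").split("\t")
        assert len(fields) == expected_fields, \
            "line {} has {} fields, expected {}".format(line_no, len(fields), expected_fields)
        if line_no >= 1000:  # checking a prefix is usually enough
            break
print("first {} lines look well-formed".format(line_no))
```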
## Running Code
### Code Structure
@@ -37,6 +70,7 @@ The entire code structure is as follows:
preprocess_data.py "Preprocess dataset"
wide_and_deep.py "Model structure"
callbacks.py "Callback class for training and evaluation"
generate_synthetic_data.py "Generate the synthetic data for benchmark"
metrics.py "Metric class"
|--- script/ "Run shell dir"
run_multinpu_train.sh "Run data parallel"
@@ -44,6 +78,7 @@ The entire code structure is as follows:
run_parameter_server_train.sh "Run parameter server"
```
### Train and evaluate model
To train and evaluate the model, run the following command:
```

src/generate_synthetic_data.py

@@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate the synthetic data for wide&deep model training"""
import time
import argparse
import numpy as np
def generate_data(output_path, label_dim, number_examples, dense_dim, slot_dim, vocabulary_size, random_slot_values):
"""
    This function generates synthetic web-click data. Each row in the output file is as follows:
    'label\tdense_feature[0]\tdense_feature[1]\t...\tsparse_feature[0]\tsparse_feature[1]\t...'
    Each value is delimited by '\t'.
    Args:
        output_path: string. The output file path of the synthetic data.
        label_dim: int. The number of label categories. For a 0-1 click problem, its value is 2.
        number_examples: int. The number of rows in the synthetic dataset.
        dense_dim: int. The number of continuous features.
        slot_dim: int. The number of category features.
        vocabulary_size: int. The vocabulary size of the total dataset.
        random_slot_values: bool. If True, the slot ids are generated randomly. If False, the id is set to
            row_index mod part_size, where part_size is the vocab size for each slot.
"""
part_size = (vocabulary_size - dense_dim) // slot_dim
if random_slot_values is True:
print('Each field size is supposed to be {}, so number of examples should be no less than this value'.format(
part_size))
start = time.time()
buffer_data = []
with open(output_path, 'w') as fp:
for i in range(number_examples):
example = []
label = i % label_dim
example.append(label)
dense_feature = ["{:.3f}".format(j + 0.01 * i % 10) for j in range(dense_dim)]
example.extend(dense_feature)
if random_slot_values is True:
for j in range(slot_dim):
example.append(dense_dim + np.random.randint(j * part_size, min((j + 1) * part_size,
vocabulary_size - dense_dim - 1)))
else:
sp = i % part_size
example.extend(
[dense_dim + min(sp + j * part_size, vocabulary_size - dense_dim - 1) for j in range(slot_dim)])
buffer_data.append("\t".join([str(item) for item in example]))
if (i + 1) % 10000 == 0:
end = time.time()
speed = 10000 / (end - start)
start = time.time()
print("Processed {} examples with speed {:.2f} examples/s".format(i + 1, speed))
fp.write('\n'.join(buffer_data) + '\n')
buffer_data = []
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generate Synthetic Data')
parser.add_argument("--output_file", type=str, default="./train.txt", help='The output path of the generated file')
parser.add_argument("--label_dim", type=int, default=2, help='The label category')
parser.add_argument("--number_examples", type=int, default=4000000, help='The row numbers of the generated file')
parser.add_argument("--dense_dim", type=int, default=13, help='The number of the continue feature.')
parser.add_argument("--slot_dim", type=int, default=26, help="The number of the category features")
parser.add_argument("--vocabulary_size", type=int, default=400000000,
help="The vocabulary size of the total dataset")
parser.add_argument("--random_slot_values", type=int, default=0,
help="If 1, the id is geneted by the random. If false, the id is set by "
"the row_index mod part_size, where part_size the the vocab size for each slot")
args = parser.parse_args()
args.random_slot_values = bool(args.random_slot_values)
generate_data(output_path=args.output_file, label_dim=args.label_dim, number_examples=args.number_examples,
dense_dim=args.dense_dim, slot_dim=args.slot_dim, vocabulary_size=args.vocabulary_size,
random_slot_values=args.random_slot_values)
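If you prefer to call the generator from Python instead of the CLI, a minimal usage sketch is below. The bare import path is an assumption (run it from `src/` or adjust `sys.path` accordingly), and `number_examples` is kept a multiple of 10000 because rows are flushed to disk in 10000-line chunks.
```
# Hedged usage sketch: generate a small synthetic file by importing the
# function above. The import path assumes the working directory is src/.
from generate_synthetic_data import generate_data

generate_data(output_path="./tiny_train.txt", label_dim=2, number_examples=20000,
              dense_dim=13, slot_dim=26, vocabulary_size=400000,
              random_slot_values=False)
```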

src/preprocess_data.py

@@ -22,17 +22,18 @@ import tarfile
import numpy as np
from mindspore.mindrecord import FileWriter
TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135
class CriteoStatsDict():
class StatsDict():
"""preprocessed data"""
def __init__(self):
self.field_size = 39
self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]
def __init__(self, field_size, dense_dim, slot_dim, skip_id_convert):
self.field_size = field_size
self.dense_dim = dense_dim
self.slot_dim = slot_dim
self.skip_id_convert = bool(skip_id_convert)
self.val_cols = ["val_{}".format(i + 1) for i in range(self.dense_dim)]
self.cat_cols = ["cat_{}".format(i + 1) for i in range(self.slot_dim)]
self.val_min_dict = {col: 0 for col in self.val_cols}
self.val_max_dict = {col: 0 for col in self.val_cols}
@@ -120,7 +121,14 @@ class CriteoStatsDict():
for i, cat_str in enumerate(cats):
key = "cat_{}".format(i + 1) + "_" + cat_str
if key in self.cat2id_dict:
id_list.append(self.cat2id_dict[key])
if self.skip_id_convert is True:
                        # For the synthetic data, if the generated ids lie in [0, max_vocab] but the number of
                        # examples is less than vocab_size / slot_num, the ids would still be converted to
                        # [0, real_vocab], where real_vocab is the actual vocab size rather than max_vocab.
                        # A simple way to alleviate this problem is to skip the id conversion and regard the
                        # synthetic data id as the final id.
id_list.append(cat_str)
else:
id_list.append(self.cat2id_dict[key])
else:
id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
weight_list.append(1.0)
@@ -132,7 +140,7 @@ def mkdir_path(file_path):
os.makedirs(file_path)
def statsdata(file_path, dict_output_path, criteo_stats_dict):
def statsdata(file_path, dict_output_path, criteo_stats_dict, dense_dim=13, slot_dim=26):
"""Preprocess data and save data"""
with open(file_path, encoding="utf-8") as file_in:
errorline_list = []
@@ -141,28 +149,31 @@ def statsdata(file_path, dict_output_path, criteo_stats_dict):
count += 1
line = line.strip("\n")
items = line.split("\t")
if len(items) != 40:
if len(items) != (dense_dim + slot_dim + 1):
errorline_list.append(count)
print("line: {}".format(line))
print("Found line length: {}, suppose to be {}, the line is {}".format(len(items),
dense_dim + slot_dim + 1, line))
continue
if count % 1000000 == 0:
print("Have handled {}w lines.".format(count // 10000))
values = items[1:14]
cats = items[14:]
values = items[1: dense_dim + 1]
cats = items[dense_dim + 1:]
assert len(values) == 13, "values.size: {}".format(len(values))
assert len(cats) == 26, "cats.size: {}".format(len(cats))
assert len(values) == dense_dim, "values.size: {}".format(len(values))
assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
criteo_stats_dict.stats_vals(values)
criteo_stats_dict.stats_cats(cats)
criteo_stats_dict.save_dict(dict_output_path)
def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000,
line_per_sample=1000,
test_size=0.1, seed=2020):
line_per_sample=1000, train_line_count=None,
test_size=0.1, seed=2020, dense_dim=13, slot_dim=26):
"""Random split data and save mindrecord"""
test_size = int(TRAIN_LINE_COUNT * test_size)
all_indices = [i for i in range(TRAIN_LINE_COUNT)]
if train_line_count is None:
raise ValueError("Please provide training file line count")
test_size = int(train_line_count * test_size)
all_indices = [i for i in range(train_line_count)]
np.random.seed(seed)
np.random.shuffle(all_indices)
print("all_indices.size:{}".format(len(all_indices)))
@@ -195,15 +206,15 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat
print("Have handle {}w lines.".format(count // 10000))
line = line.strip("\n")
items = line.split("\t")
if len(items) != 40:
if len(items) != (1 + dense_dim + slot_dim):
items_error_size_lineCount.append(i)
continue
label = float(items[0])
values = items[1:14]
cats = items[14:]
values = items[1:1 + dense_dim]
cats = items[1 + dense_dim:]
assert len(values) == 13, "values.size: {}".format(len(values))
assert len(cats) == 26, "cats.size: {}".format(len(cats))
assert len(values) == dense_dim, "values.size: {}".format(len(values))
assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
ids, wts = criteo_stats_dict.map_cat2id(values, cats)
@@ -251,35 +262,48 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="criteo data")
parser.add_argument("--data_path", type=str, default="./criteo_data/")
parser.add_argument("--data_type", type=str, default='criteo', choices=['criteo', 'synthetic'],
                        help='Currently we support the criteo dataset and a synthetic dataset')
parser.add_argument("--data_path", type=str, default="./criteo_data/", help='The path of the data file')
parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continues fields')
parser.add_argument("--slot_dim", type=int, default=26,
                        help='The number of your sparse fields, also called category features.')
parser.add_argument("--threshold", type=int, default=100,
                        help='Words with frequency below this value will be regarded as OOV. It aims to reduce the vocab size')
parser.add_argument("--train_line_count", type=int, help='The number of examples in your dataset')
parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1],
                        help='If set to 1, skip the id conversion and regard the original id as the final id.')
args, _ = parser.parse_known_args()
data_path = args.data_path
download_data_path = data_path + "origin_data/"
mkdir_path(download_data_path)
if args.data_type == 'criteo':
download_data_path = data_path + "origin_data/"
mkdir_path(download_data_path)
url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
file_name = download_data_path + '/' + url.split('/')[-1]
urllib.request.urlretrieve(url, filename=file_name)
url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
file_name = download_data_path + '/' + url.split('/')[-1]
urllib.request.urlretrieve(url, filename=file_name)
tar = tarfile.open(file_name)
names = tar.getnames()
for name in names:
tar.extract(name, path=download_data_path)
tar.close()
criteo_stats = CriteoStatsDict()
tar = tarfile.open(file_name)
names = tar.getnames()
for name in names:
tar.extract(name, path=download_data_path)
tar.close()
target_field_size = args.dense_dim + args.slot_dim
stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
skip_id_convert=args.skip_id_convert)
data_file_path = data_path + "origin_data/train.txt"
stats_output_path = data_path + "stats_dict/"
mkdir_path(stats_output_path)
statsdata(data_file_path, stats_output_path, criteo_stats)
statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
criteo_stats.load_dict(dict_path=stats_output_path, prefix="")
criteo_stats.get_cat2id(threshold=100)
stats.load_dict(dict_path=stats_output_path, prefix="")
stats.get_cat2id(threshold=args.threshold)
in_file_path = data_path + "origin_data/train.txt"
output_path = data_path + "mindrecord/"
mkdir_path(output_path)
random_split_trans2mindrecord(in_file_path, output_path, criteo_stats, part_rows=2000000, line_per_sample=1000,
test_size=0.1, seed=2020)
random_split_trans2mindrecord(in_file_path, output_path, stats, part_rows=2000000,
train_line_count=args.train_line_count, line_per_sample=1000,
test_size=0.1, seed=2020, dense_dim=args.dense_dim, slot_dim=args.slot_dim)