forked from mindspore-Ecosystem/mindspore

!3931 Generating the synthetic data for the Wide&Deep model

Merge pull request !3931 from huangxinjing/synthetic_data

commit 5cd4c3eb09

README.md

@@ -10,14 +10,47 @@ WideDeep model jointly trained wide linear models and deep neural network, which

- Download the dataset and convert the dataset to mindrecord, command as follows:

```
python src/preprocess_data.py
python src/preprocess_data.py --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
```

Arguments:

* `--data_path`: Dataset storage path (Default: ./criteo_data/).
* `--data_type` {criteo,synthetic}: Currently we support the criteo dataset and a synthetic dataset (Default: criteo).
* `--data_path`: The path of the data file.
* `--dense_dim`: The number of your continuous fields.
* `--slot_dim`: The number of your sparse fields; these can also be called categorical features.
* `--threshold`: Word frequencies below this value are regarded as OOV; this reduces the vocabulary size (see the sketch after this list).
* `--train_line_count`: The number of examples in your dataset.
* `--skip_id_convert`: 0 or 1. If set to 1, the code skips the id conversion and regards the original id as the final id.
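
The OOV mapping behind `--threshold` lives in the preprocessing code's `get_cat2id` routine, which is not shown in this diff. The snippet below is only a minimal sketch of the idea, using a hypothetical `build_cat2id` helper and a toy frequency dictionary:

```
# Minimal sketch (not the repository code) of how a frequency threshold
# shrinks the vocabulary: rare categories collapse into one OOV id per slot.
from collections import Counter

def build_cat2id(values_per_slot, threshold):
    """values_per_slot: {slot_name: list of raw category strings}."""
    cat2id = {}
    next_id = 0
    for slot, values in values_per_slot.items():
        freq = Counter(values)
        for cat, count in freq.items():
            if count >= threshold:          # frequent category keeps its own id
                cat2id[(slot, cat)] = next_id
                next_id += 1
        cat2id[(slot, "OOV")] = next_id     # everything rarer maps to this single id
        next_id += 1
    return cat2id

print(build_cat2id({"cat_1": ["a", "a", "b"]}, threshold=2))
# {('cat_1', 'a'): 0, ('cat_1', 'OOV'): 1}  -- 'b' appears only once, so it falls into OOV
```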

## Dataset

Commonly used benchmark datasets are used for model training and evaluation.

### Generate the synthetic Data

The following command will generate 40 million lines of click data in the format "label\tdense_feature[0]\tdense_feature[1]...\tsparse_feature[0]\tsparse_feature[1]...".

```
mkdir -p syn_data/origin_data
python src/generate_synthetic_data.py --output_file=syn_data/origin_data/train.txt --number_examples=40000000 --dense_dim=13 --slot_dim=51 --vocabulary_size=2000000000 --random_slot_values=0
```

Arguments:

* `--output_file`: The output path of the generated file.
* `--label_dim`: The number of label categories.
* `--number_examples`: The number of rows in the generated file.
* `--dense_dim`: The number of continuous features.
* `--slot_dim`: The number of categorical features.
* `--vocabulary_size`: The vocabulary size of the total dataset.
* `--random_slot_values`: 0 or 1. If 1, each id is generated at random; if 0, the id is set to row_index mod part_size, where part_size is the vocabulary size for each slot (see the sketch after this list).
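
For reference, here is a small standalone sketch (mirroring the `generate_data` function added by this commit) of the id ranges that the two `--random_slot_values` modes produce, using the 51-slot configuration from the command above:

```
# Standalone sketch of the id layout used by generate_synthetic_data.py:
# slot j owns the id range [dense_dim + j * part_size, dense_dim + (j + 1) * part_size).
import numpy as np

dense_dim, slot_dim, vocabulary_size = 13, 51, 2000000000
part_size = (vocabulary_size - dense_dim) // slot_dim

def slot_ids(row_index, random_slot_values=False):
    if random_slot_values:
        # each slot draws a random id from its own range
        return [dense_dim + np.random.randint(j * part_size,
                                              min((j + 1) * part_size, vocabulary_size - dense_dim - 1))
                for j in range(slot_dim)]
    # deterministic: the offset inside every slot range is row_index mod part_size
    sp = row_index % part_size
    return [dense_dim + min(sp + j * part_size, vocabulary_size - dense_dim - 1) for j in range(slot_dim)]

print(slot_ids(0)[:3])  # first id of each of the first three slot ranges: [13, 39215699, 78431385]
```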

Preprocess the generated data:

```
python src/preprocess_data.py --data_path=./syn_data/ --data_type=synthetic --dense_dim=13 --slot_dim=51 --threshold=0 --train_line_count=40000000 --skip_id_convert=1
```

## Running Code

### Code Structure

@@ -37,6 +70,7 @@ The entire code structure is as following:

    preprocess_data.py            "Preprocess dataset"
    wide_and_deep.py              "Model structure"
    callbacks.py                  "Callback class for training and evaluation"
    generate_synthetic_data.py    "Generate the synthetic data for benchmark"
    metrics.py                    "Metric class"
|--- script/                      "Run shell dir"
    run_multinpu_train.sh         "Run data parallel"

@@ -44,6 +78,7 @@ The entire code structure is as following:

    run_parameter_server_train.sh "Run parameter server"
```

### Train and evaluate model

To train and evaluate the model, command as follows:

```

src/generate_synthetic_data.py (new file)

@@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Generate the synthetic data for wide&deep model training"""
import time
import argparse
import numpy as np


def generate_data(output_path, label_dim, number_examples, dense_dim, slot_dim, vocabulary_size, random_slot_values):
    """
    This function generates synthetic web-click data. Each row in the output file is as follows:
    'label\tdense_feature[0]\tdense_feature[1]...\tsparse_feature[0]\tsparse_feature[1]...'
    Each value is delimited by '\t'.
    Args:
        output_path: string. The output file path of the synthetic data
        label_dim: int. The number of label categories. For a 0-1 clicking problem, its value is 2
        number_examples: int. The number of rows of the synthetic dataset
        dense_dim: int. The number of continuous features.
        slot_dim: int. The number of categorical features
        vocabulary_size: int. The value of vocabulary size
        random_slot_values: bool. If true, the id is generated at random. If false, the id is set to
            row_index mod part_size, where part_size is the vocab size for each slot
    """

    part_size = (vocabulary_size - dense_dim) // slot_dim

    if random_slot_values is True:
        print('Each field size is supposed to be {}, so number of examples should be no less than this value'.format(
            part_size))

    start = time.time()

    buffer_data = []

    with open(output_path, 'w') as fp:
        for i in range(number_examples):
            example = []
            label = i % label_dim
            example.append(label)

            dense_feature = ["{:.3f}".format(j + 0.01 * i % 10) for j in range(dense_dim)]
            example.extend(dense_feature)

            if random_slot_values is True:
                for j in range(slot_dim):
                    example.append(dense_dim + np.random.randint(j * part_size, min((j + 1) * part_size,
                                                                                    vocabulary_size - dense_dim - 1)))
            else:
                sp = i % part_size
                example.extend(
                    [dense_dim + min(sp + j * part_size, vocabulary_size - dense_dim - 1) for j in range(slot_dim)])

            buffer_data.append("\t".join([str(item) for item in example]))

            if (i + 1) % 10000 == 0:
                end = time.time()
                speed = 10000 / (end - start)
                start = time.time()
                print("Processed {} examples with speed {:.2f} examples/s".format(i + 1, speed))
                fp.write('\n'.join(buffer_data) + '\n')
                buffer_data = []


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Generate Synthetic Data')

    parser.add_argument("--output_file", type=str, default="./train.txt", help='The output path of the generated file')
    parser.add_argument("--label_dim", type=int, default=2, help='The number of label categories')
    parser.add_argument("--number_examples", type=int, default=4000000, help='The number of rows of the generated file')
    parser.add_argument("--dense_dim", type=int, default=13, help='The number of continuous features.')
    parser.add_argument("--slot_dim", type=int, default=26, help="The number of categorical features")
    parser.add_argument("--vocabulary_size", type=int, default=400000000,
                        help="The vocabulary size of the total dataset")
    parser.add_argument("--random_slot_values", type=int, default=0,
                        help="If 1, the id is generated at random. If 0, the id is set to "
                             "the row_index mod part_size, where part_size is the vocab size for each slot")
    args = parser.parse_args()
    args.random_slot_values = bool(args.random_slot_values)

    generate_data(output_path=args.output_file, label_dim=args.label_dim, number_examples=args.number_examples,
                  dense_dim=args.dense_dim, slot_dim=args.slot_dim, vocabulary_size=args.vocabulary_size,
                  random_slot_values=args.random_slot_values)
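
As a quick, illustrative check (not part of the commit), the first few lines of a generated file can be compared against the expected layout of 1 label, `dense_dim` dense values and `slot_dim` ids per row; the path and dimensions below assume the README command above was used:

```
# Sanity check (illustrative only) for a file produced by generate_synthetic_data.py:
# every line should contain 1 label + dense_dim dense values + slot_dim ids, tab separated.
dense_dim, slot_dim = 13, 51

with open("syn_data/origin_data/train.txt", encoding="utf-8") as fp:
    for _, line in zip(range(5), fp):                      # inspect only the first five lines
        fields = line.rstrip("\n").split("\t")
        assert len(fields) == 1 + dense_dim + slot_dim, "unexpected field count: {}".format(len(fields))
        label, dense, ids = fields[0], fields[1:1 + dense_dim], fields[1 + dense_dim:]
        print(label, dense[:2], ids[:2])
```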

src/preprocess_data.py

@@ -22,17 +22,18 @@ import tarfile
 import numpy as np
 from mindspore.mindrecord import FileWriter

 TRAIN_LINE_COUNT = 45840617
 TEST_LINE_COUNT = 6042135


-class CriteoStatsDict():
+class StatsDict():
     """preprocessed data"""

-    def __init__(self):
-        self.field_size = 39
-        self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
-        self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]
+    def __init__(self, field_size, dense_dim, slot_dim, skip_id_convert):
+        self.field_size = field_size
+        self.dense_dim = dense_dim
+        self.slot_dim = slot_dim
+        self.skip_id_convert = bool(skip_id_convert)
+
+        self.val_cols = ["val_{}".format(i + 1) for i in range(self.dense_dim)]
+        self.cat_cols = ["cat_{}".format(i + 1) for i in range(self.slot_dim)]

         self.val_min_dict = {col: 0 for col in self.val_cols}
         self.val_max_dict = {col: 0 for col in self.val_cols}
@@ -120,7 +121,14 @@ class CriteoStatsDict():
         for i, cat_str in enumerate(cats):
             key = "cat_{}".format(i + 1) + "_" + cat_str
             if key in self.cat2id_dict:
-                id_list.append(self.cat2id_dict[key])
+                if self.skip_id_convert is True:
+                    # For the synthetic data, the generated ids lie in [0, max_vocab], but when the number of
+                    # examples is less than vocab_size / slot_nums the ids would still be remapped to
+                    # [0, real_vocab], where real_vocab is the actually observed vocab size rather than max_vocab.
+                    # A simple way to alleviate this problem is to skip the id conversion and regard the
+                    # synthetic data id as the final id.
+                    id_list.append(cat_str)
+                else:
+                    id_list.append(self.cat2id_dict[key])
             else:
                 id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
             weight_list.append(1.0)
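
A toy illustration (not the class itself, with a made-up `cat2id_dict` and a hypothetical OOV key naming) of what the `skip_id_convert` branch above changes:

```
# Toy sketch of the two branches above: with synthetic data the raw category string is
# already a numeric id, so it can be kept as-is; otherwise it is looked up in cat2id_dict.
cat2id_dict = {"cat_1_7": 0, "OOV_cat_1": 1}   # invented mapping for illustration only

def convert(cat_str, slot, skip_id_convert):
    key = "cat_{}".format(slot) + "_" + cat_str
    if key in cat2id_dict:
        return cat_str if skip_id_convert else cat2id_dict[key]
    return cat2id_dict["OOV_cat_{}".format(slot)]

print(convert("7", 1, skip_id_convert=False))  # 0   -- remapped id
print(convert("7", 1, skip_id_convert=True))   # '7' -- original synthetic id kept
```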

@@ -132,7 +140,7 @@ def mkdir_path(file_path):
         os.makedirs(file_path)


-def statsdata(file_path, dict_output_path, criteo_stats_dict):
+def statsdata(file_path, dict_output_path, criteo_stats_dict, dense_dim=13, slot_dim=26):
     """Preprocess data and save data"""
     with open(file_path, encoding="utf-8") as file_in:
         errorline_list = []

@@ -141,28 +149,31 @@ def statsdata(file_path, dict_output_path, criteo_stats_dict):
             count += 1
             line = line.strip("\n")
             items = line.split("\t")
-            if len(items) != 40:
+            if len(items) != (dense_dim + slot_dim + 1):
                 errorline_list.append(count)
-                print("line: {}".format(line))
+                print("Found line length: {}, supposed to be {}, the line is {}".format(len(items),
+                                                                                        dense_dim + slot_dim + 1, line))
                 continue
             if count % 1000000 == 0:
                 print("Have handled {}w lines.".format(count // 10000))
-            values = items[1:14]
-            cats = items[14:]
+            values = items[1: dense_dim + 1]
+            cats = items[dense_dim + 1:]

-            assert len(values) == 13, "values.size: {}".format(len(values))
-            assert len(cats) == 26, "cats.size: {}".format(len(cats))
+            assert len(values) == dense_dim, "values.size: {}".format(len(values))
+            assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
             criteo_stats_dict.stats_vals(values)
             criteo_stats_dict.stats_cats(cats)
     criteo_stats_dict.save_dict(dict_output_path)


 def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000,
-                                  line_per_sample=1000,
-                                  test_size=0.1, seed=2020):
+                                  line_per_sample=1000, train_line_count=None,
+                                  test_size=0.1, seed=2020, dense_dim=13, slot_dim=26):
     """Random split data and save mindrecord"""
-    test_size = int(TRAIN_LINE_COUNT * test_size)
-    all_indices = [i for i in range(TRAIN_LINE_COUNT)]
+    if train_line_count is None:
+        raise ValueError("Please provide training file line count")
+    test_size = int(train_line_count * test_size)
+    all_indices = [i for i in range(train_line_count)]
     np.random.seed(seed)
     np.random.shuffle(all_indices)
     print("all_indices.size:{}".format(len(all_indices)))

@@ -195,15 +206,15 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat
                print("Have handled {}w lines.".format(count // 10000))
            line = line.strip("\n")
            items = line.split("\t")
-            if len(items) != 40:
+            if len(items) != (1 + dense_dim + slot_dim):
                items_error_size_lineCount.append(i)
                continue
            label = float(items[0])
-            values = items[1:14]
-            cats = items[14:]
+            values = items[1:1 + dense_dim]
+            cats = items[1 + dense_dim:]

-            assert len(values) == 13, "values.size: {}".format(len(values))
-            assert len(cats) == 26, "cats.size: {}".format(len(cats))
+            assert len(values) == dense_dim, "values.size: {}".format(len(values))
+            assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))

            ids, wts = criteo_stats_dict.map_cat2id(values, cats)

@@ -251,35 +262,48 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="criteo data")
-    parser.add_argument("--data_path", type=str, default="./criteo_data/")
+    parser.add_argument("--data_type", type=str, default='criteo', choices=['criteo', 'synthetic'],
+                        help='Currently we support criteo dataset and synthetic dataset')
+    parser.add_argument("--data_path", type=str, default="./criteo_data/", help='The path of the data file')
+    parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continuous fields')
+    parser.add_argument("--slot_dim", type=int, default=26,
+                        help='The number of your sparse fields, also called categorical features.')
+    parser.add_argument("--threshold", type=int, default=100,
+                        help='Word frequency below this will be regarded as OOV. It aims to reduce the vocab size')
+    parser.add_argument("--train_line_count", type=int, help='The number of examples in your dataset')
+    parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1],
+                        help='Skip the id conversion, regarding the original id as the final id.')

     args, _ = parser.parse_known_args()
     data_path = args.data_path

-    download_data_path = data_path + "origin_data/"
-    mkdir_path(download_data_path)
+    if args.data_type == 'criteo':
+        download_data_path = data_path + "origin_data/"
+        mkdir_path(download_data_path)

-    url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
-    file_name = download_data_path + '/' + url.split('/')[-1]
-    urllib.request.urlretrieve(url, filename=file_name)
+        url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
+        file_name = download_data_path + '/' + url.split('/')[-1]
+        urllib.request.urlretrieve(url, filename=file_name)

-    tar = tarfile.open(file_name)
-    names = tar.getnames()
-    for name in names:
-        tar.extract(name, path=download_data_path)
-    tar.close()

-    criteo_stats = CriteoStatsDict()
+        tar = tarfile.open(file_name)
+        names = tar.getnames()
+        for name in names:
+            tar.extract(name, path=download_data_path)
+        tar.close()
+    target_field_size = args.dense_dim + args.slot_dim
+    stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
+                      skip_id_convert=args.skip_id_convert)
     data_file_path = data_path + "origin_data/train.txt"
     stats_output_path = data_path + "stats_dict/"
     mkdir_path(stats_output_path)
-    statsdata(data_file_path, stats_output_path, criteo_stats)
+    statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim)

-    criteo_stats.load_dict(dict_path=stats_output_path, prefix="")
-    criteo_stats.get_cat2id(threshold=100)
+    stats.load_dict(dict_path=stats_output_path, prefix="")
+    stats.get_cat2id(threshold=args.threshold)

     in_file_path = data_path + "origin_data/train.txt"
     output_path = data_path + "mindrecord/"
     mkdir_path(output_path)
-    random_split_trans2mindrecord(in_file_path, output_path, criteo_stats, part_rows=2000000, line_per_sample=1000,
-                                  test_size=0.1, seed=2020)
+    random_split_trans2mindrecord(in_file_path, output_path, stats, part_rows=2000000,
+                                  train_line_count=args.train_line_count, line_per_sample=1000,
+                                  test_size=0.1, seed=2020, dense_dim=args.dense_dim, slot_dim=args.slot_dim)