!4569 remove criteo keyword
Merge pull request !4569 from hanjun996/master
commit 77e83fb24d
@@ -8,13 +8,14 @@ WideDeep model jointly trained wide linear models and deep neural network, which
 
 - Install [MindSpore](https://www.mindspore.cn/install/en).
 
-- Download the dataset and convert the dataset to mindrecord, command as follows:
+- Place the raw dataset under a certain path, such as ./recommendation_dataset/origin_data. If you use the [criteo dataset](https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz), please download it and unzip it to ./recommendation_dataset/origin_data.
+
+- Convert the dataset to mindrecord, command as follows:
 
 ```
-python src/preprocess_data.py --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
+python src/preprocess_data.py --data_path=./recommendation_dataset --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
 ```
 
 Arguments:
-* `--data_type` {criteo,synthetic}: Currently we support criteo dataset and synthetic dataset. (Default: ./criteo_data/)
 * `--data_path`: The path of the data file.
 * `--dense_dim`: The number of your continuous fields.
 * `--slot_dim`: The number of your sparse fields; they can also be called category features.
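The README now asks the user to place the raw data under ./recommendation_dataset/origin_data themselves, since the automatic download is removed from preprocess_data.py later in this diff. A minimal sketch of fetching and unpacking the Criteo archive into that layout, mirroring the urllib/tarfile calls this PR deletes (paths follow the README defaults; adjust as needed):

```python
import os
import tarfile
import urllib.request

# Target directory from the README above.
origin_data = "./recommendation_dataset/origin_data/"
os.makedirs(origin_data, exist_ok=True)

url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
archive = os.path.join(origin_data, url.split("/")[-1])

urllib.request.urlretrieve(url, filename=archive)  # large download
with tarfile.open(archive) as tar:
    tar.extractall(path=origin_data)               # unpacks the raw text files
```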
@@ -17,8 +17,6 @@ import os
 import pickle
 import collections
 import argparse
-import urllib.request
-import tarfile
 import numpy as np
 from mindspore.mindrecord import FileWriter
 
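Only the download-related imports are dropped here; the FileWriter import stays because the script still serializes samples into MindRecord files. For readers unfamiliar with that API, a rough standalone sketch is below. The schema field names and the sample values are illustrative assumptions, not necessarily the ones this script writes:

```python
import numpy as np
from mindspore.mindrecord import FileWriter

# Hypothetical schema: one label plus flattened id/weight vectors per sample.
schema = {
    "label": {"type": "float32", "shape": [-1]},
    "feat_ids": {"type": "int32", "shape": [-1]},
    "feat_vals": {"type": "float32", "shape": [-1]},
}

writer = FileWriter(file_name="./train.mindrecord", shard_num=1)
writer.add_schema(schema, "recommendation dataset schema")
writer.write_raw_data([{
    "label": np.array([0.0], dtype=np.float32),
    "feat_ids": np.array([1, 2, 3], dtype=np.int32),
    "feat_vals": np.array([0.5, 1.0, 1.0], dtype=np.float32),
}])
writer.commit()
```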
@@ -140,7 +138,7 @@ def mkdir_path(file_path):
         os.makedirs(file_path)
 
 
-def statsdata(file_path, dict_output_path, criteo_stats_dict, dense_dim=13, slot_dim=26):
+def statsdata(file_path, dict_output_path, recommendation_dataset_stats_dict, dense_dim=13, slot_dim=26):
     """Preprocess data and save data"""
     with open(file_path, encoding="utf-8") as file_in:
         errorline_list = []
@@ -161,13 +159,13 @@ def statsdata(file_path, dict_output_path, criteo_stats_dict, dense_dim=13, slot
 
             assert len(values) == dense_dim, "values.size: {}".format(len(values))
             assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
-            criteo_stats_dict.stats_vals(values)
-            criteo_stats_dict.stats_cats(cats)
-    criteo_stats_dict.save_dict(dict_output_path)
+            recommendation_dataset_stats_dict.stats_vals(values)
+            recommendation_dataset_stats_dict.stats_cats(cats)
+    recommendation_dataset_stats_dict.save_dict(dict_output_path)
 
 
-def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000,
-                                  line_per_sample=1000, train_line_count=None,
+def random_split_trans2mindrecord(input_file_path, output_file_path, recommendation_dataset_stats_dict,
+                                  part_rows=2000000, line_per_sample=1000, train_line_count=None,
                                   test_size=0.1, seed=2020, dense_dim=13, slot_dim=26):
     """Random split data and save mindrecord"""
     if train_line_count is None:
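The asserts in this hunk encode the expected record layout: each input line carries a label, dense_dim continuous values, and slot_dim categorical fields. A self-contained illustration of that layout, tab-separated as in the raw Criteo format (the fake values below are made up for demonstration):

```python
# Illustrative only: the record layout the asserts above check.
DENSE_DIM, SLOT_DIM = 13, 26

# A fake line: label, 13 dense values, 26 hashed categorical fields.
line = "0\t" + "\t".join(str(i) for i in range(DENSE_DIM)) + "\t" + "\t".join(["a1b2c3d4"] * SLOT_DIM)

items = line.strip("\n").split("\t")
label = items[0]                  # click label, "0" or "1"
values = items[1:1 + DENSE_DIM]   # dense (continuous) fields
cats = items[1 + DENSE_DIM:]      # sparse (categorical) fields

assert len(values) == DENSE_DIM, "values.size: {}".format(len(values))
assert len(cats) == SLOT_DIM, "cats.size: {}".format(len(cats))
```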
@@ -216,7 +214,7 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat
             assert len(values) == dense_dim, "values.size: {}".format(len(values))
             assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
 
-            ids, wts = criteo_stats_dict.map_cat2id(values, cats)
+            ids, wts = recommendation_dataset_stats_dict.map_cat2id(values, cats)
 
             ids_list.extend(ids)
             wts_list.extend(wts)
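Only the stats object is renamed here; map_cat2id itself is not shown in this diff. As a loose, assumption-labeled illustration of what such a mapping commonly produces for Wide&Deep-style models: dense fields become (field id, value weight) pairs and categorical fields become (vocabulary id, 1.0) pairs, so ids and wts line up one-to-one.

```python
# Purely illustrative stand-in for map_cat2id; the real implementation lives in
# the stats dict class and may differ.
def map_cat2id_sketch(values, cats, cat2id, dense_dim=13):
    ids, wts = [], []
    for i, v in enumerate(values):
        ids.append(i)                                  # dense field keeps its field index as id
        wts.append(float(v) if v else 0.0)             # weight = (possibly normalized) dense value
    for i, c in enumerate(cats):
        ids.append(cat2id.get((i, c), dense_dim + i))  # rare/unseen categories fall back to a per-field default id
        wts.append(1.0)                                # categorical features get unit weight
    return ids, wts
```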
@@ -261,10 +259,8 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="criteo data")
-    parser.add_argument("--data_type", type=str, default='criteo', choices=['criteo', 'synthetic'],
-                        help='Currently we support criteo dataset and synthetic dataset')
-    parser.add_argument("--data_path", type=str, default="./criteo_data/", help='The path of the data file')
+    parser = argparse.ArgumentParser(description="Recommendation dataset")
+    parser.add_argument("--data_path", type=str, default="./recommendation_dataset/", help='The path of the data file')
     parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continues fields')
     parser.add_argument("--slot_dim", type=int, default=26,
                         help='The number of your sparse fields, it can also be called catelogy features.')
@@ -277,19 +273,6 @@ if __name__ == '__main__':
     args, _ = parser.parse_known_args()
     data_path = args.data_path
 
-    if args.data_type == 'criteo':
-        download_data_path = data_path + "origin_data/"
-        mkdir_path(download_data_path)
-
-        url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
-        file_name = download_data_path + '/' + url.split('/')[-1]
-        urllib.request.urlretrieve(url, filename=file_name)
-
-        tar = tarfile.open(file_name)
-        names = tar.getnames()
-        for name in names:
-            tar.extract(name, path=download_data_path)
-        tar.close()
     target_field_size = args.dense_dim + args.slot_dim
     stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
                       skip_id_convert=args.skip_id_convert)
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """
-Criteo data process
+Recommendation dataset process
 """
 
 import os
@@ -27,7 +27,7 @@ import pandas as pd
 TRAIN_LINE_COUNT = 45840617
 TEST_LINE_COUNT = 6042135
 
-class CriteoStatsDict():
+class RecommendationDatasetStatsDict():
     """create data dict"""
     def __init__(self):
         self.field_size = 39  # value_1-13; cat_1-26;
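This hunk only renames CriteoStatsDict; its internals are not part of the diff. As a hedged sketch of what a stats dict of this kind typically accumulates, consistent with the "stats the vocab and normalize value" step and the threshold=100 vocabulary cutoff used later: min/max tracking for dense fields and frequency counting for categorical fields. This is an illustration under those assumptions, not the repository's implementation.

```python
import collections
import pickle

class StatsSketch:
    """Illustrative stand-in for RecommendationDatasetStatsDict."""

    def __init__(self, dense_dim=13, slot_dim=26):
        self.val_min = [float("inf")] * dense_dim       # per-field minimum for normalization
        self.val_max = [float("-inf")] * dense_dim      # per-field maximum for normalization
        self.cat_counts = [collections.Counter() for _ in range(slot_dim)]  # per-field category frequencies

    def stats_vals(self, values):
        for i, v in enumerate(values):
            if v:  # an empty string means a missing dense value
                self.val_min[i] = min(self.val_min[i], float(v))
                self.val_max[i] = max(self.val_max[i], float(v))

    def stats_cats(self, cats):
        for i, c in enumerate(cats):
            self.cat_counts[i][c] += 1

    def save_dict(self, output_path):
        # output_path is a directory prefix, matching how the script builds paths.
        with open(output_path + "stats.pkl", "wb") as f:
            pickle.dump((self.val_min, self.val_max, self.cat_counts), f)
```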
@@ -135,7 +135,7 @@ def mkdir_path(file_path):
         os.makedirs(file_path)
 #
 
-def statsdata(data_file_path, output_path, criteo_stats):
+def statsdata(data_file_path, output_path, recommendation_dataset_stats):
     """data status"""
     with open(data_file_path, encoding="utf-8") as file_in:
         errorline_list = []
@@ -157,9 +157,9 @@ def statsdata(data_file_path, output_path, criteo_stats):
             cats = items[14:]
             assert len(values) == 13, "value.size: {}".format(len(values))
             assert len(cats) == 26, "cat.size: {}".format(len(cats))
-            criteo_stats.stats_vals(values)
-            criteo_stats.stats_cats(cats)
-    criteo_stats.save_dict(output_path)
+            recommendation_dataset_stats.stats_vals(values)
+            recommendation_dataset_stats.stats_cats(cats)
+    recommendation_dataset_stats.save_dict(output_path)
 #
 
 
@@ -169,7 +169,8 @@ def add_write(file_path, wr_str):
 #
 
 
-def random_split_trans2h5(in_file_path, output_path, criteo_stats, part_rows=2000000, test_size=0.1, seed=2020):
+def random_split_trans2h5(in_file_path, output_path, recommendation_dataset_stats,
+                          part_rows=2000000, test_size=0.1, seed=2020):
     """random split trans2h5"""
     test_size = int(TRAIN_LINE_COUNT * test_size)
     # train_size = TRAIN_LINE_COUNT - test_size
@@ -207,7 +208,7 @@ def random_split_trans2h5(in_file_path, output_path, criteo_stats, part_rows=200
             cats = items[14:]
             assert len(values) == 13, "value.size: {}".format(len(values))
             assert len(cats) == 26, "cat.size: {}".format(len(cats))
-            ids, wts = criteo_stats.map_cat2id(values, cats)
+            ids, wts = recommendation_dataset_stats.map_cat2id(values, cats)
             if i not in test_indices_set:
                 train_feature_list.append(ids + wts)
                 train_label_list.append(label)
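The surrounding code (test_size = int(TRAIN_LINE_COUNT * test_size), the seed argument, and the `if i not in test_indices_set` branch) implies a deterministic random split over line indices. A small sketch of one way such a split can be drawn, assuming numpy shuffling as hinted by the "version 2: np.random.shuffle" comment in this file:

```python
import numpy as np

TRAIN_LINE_COUNT = 45840617  # total lines in the raw Criteo training file (use a smaller count for a quick try)
test_fraction, seed = 0.1, 2020

test_size = int(TRAIN_LINE_COUNT * test_fraction)
np.random.seed(seed)
indices = np.arange(TRAIN_LINE_COUNT)
np.random.shuffle(indices)
test_indices_set = set(indices[:test_size].tolist())

# Line i is routed to the test split iff i is in test_indices_set,
# mirroring the `if i not in test_indices_set:` check above.
```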
@@ -253,16 +254,17 @@ if __name__ == "__main__":
                         help="The path to save dataset")
     args, _ = parser.parse_known_args()
     base_path = args.raw_data_path
-    criteo_stat = CriteoStatsDict()
+    recommendation_dataset_stat = RecommendationDatasetStatsDict()
     # step 1, stats the vocab and normalize value
     datafile_path = base_path + "train_small.txt"
     stats_out_path = base_path + "stats_dict/"
     mkdir_path(stats_out_path)
-    statsdata(datafile_path, stats_out_path, criteo_stat)
+    statsdata(datafile_path, stats_out_path, recommendation_dataset_stat)
     print("------" * 10)
-    criteo_stat.load_dict(dict_path=stats_out_path, prefix="")
-    criteo_stat.get_cat2id(threshold=100)
+    recommendation_dataset_stat.load_dict(dict_path=stats_out_path, prefix="")
+    recommendation_dataset_stat.get_cat2id(threshold=100)
     # step 2, transform data trans2h5; version 2: np.random.shuffle
     infile_path = base_path + "train_small.txt"
     mkdir_path(args.output_path)
-    random_split_trans2h5(infile_path, args.output_path, criteo_stat, part_rows=2000000, test_size=0.1, seed=2020)
+    random_split_trans2h5(infile_path, args.output_path, recommendation_dataset_stat,
+                          part_rows=2000000, test_size=0.1, seed=2020)