From 8cae48091b63b0c3b3c0414d912b8921ec171c57 Mon Sep 17 00:00:00 2001 From: wukesong Date: Fri, 29 May 2020 15:58:18 +0800 Subject: [PATCH] pre_process --- .../wide_and_deep/src/preprocess_data.py | 278 ++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 model_zoo/wide_and_deep/src/preprocess_data.py diff --git a/model_zoo/wide_and_deep/src/preprocess_data.py b/model_zoo/wide_and_deep/src/preprocess_data.py new file mode 100644 index 00000000000..35d13b841da --- /dev/null +++ b/model_zoo/wide_and_deep/src/preprocess_data.py @@ -0,0 +1,278 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""Download the raw Criteo dataset and preprocess it into MindRecord files."""
import os
import pickle
import collections
import argparse
import numpy as np
from mindspore.mindrecord import FileWriter

TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135


class CriteoStatsDict():
    """Accumulate Criteo dataset statistics and map raw features to (id, weight).

    Dense columns val_1..val_13 get running min/max statistics; categorical
    columns cat_1..cat_26 get per-value frequency counts.  The vocabulary
    (cat2id_dict) assigns ids to the 13 dense columns first, then one
    out-of-vocabulary (OOV) slot per categorical column, then (after
    get_cat2id) every sufficiently frequent category value.
    """

    def __init__(self):
        self.field_size = 39  # 13 dense + 26 categorical columns
        self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
        self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]

        # NOTE(review): min/max start at 0, not +/-inf, so a column whose
        # values are all negative keeps max == 0 — preserved from original.
        self.val_min_dict = {col: 0 for col in self.val_cols}
        self.val_max_dict = {col: 0 for col in self.val_cols}

        self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}

        self.oov_prefix = "OOV"

        # Ids 0..12: dense columns; ids 13..38: one OOV slot per cat column.
        self.cat2id_dict = {}
        self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
        self.cat2id_dict.update(
            {self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)})

    def stats_vals(self, val_list):
        """Update the running min/max of each dense column with one row."""
        assert len(val_list) == len(self.val_cols)
        for i, val in enumerate(val_list):
            if val == "":
                continue  # missing dense value: statistics untouched
            key = self.val_cols[i]
            fval = float(val)  # convert once (original converted up to 3x)
            if fval > self.val_max_dict[key]:
                self.val_max_dict[key] = fval
            if fval < self.val_min_dict[key]:
                self.val_min_dict[key] = fval

    def stats_cats(self, cat_list):
        """Update the per-value frequency count of each categorical column."""
        assert len(cat_list) == len(self.cat_cols)
        for i, cat in enumerate(cat_list):
            self.cat_count_dict[self.cat_cols[i]][cat] += 1

    def save_dict(self, dict_path, prefix=""):
        """Pickle the collected statistics into dict_path ('{prefix}*.pkl')."""
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_max_dict, file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_min_dict, file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.cat_count_dict, file_wrt)

    def load_dict(self, dict_path, prefix=""):
        """Load statistics previously written by save_dict."""
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_max_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_min_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.cat_count_dict = pickle.load(file_wrt)
        print("val_max_dict.items()[:50]:{}".format(list(self.val_max_dict.items())))
        print("val_min_dict.items()[:50]:{}".format(list(self.val_min_dict.items())))

    def get_cat2id(self, threshold=100):
        """Assign ids to category values seen strictly more than threshold times."""
        for key, cat_count_d in self.cat_count_dict.items():
            # Strictly greater than threshold, matching the original filter.
            for cat_str, cnt in cat_count_d.items():
                if cnt > threshold:
                    self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
        print("cat2id_dict.size:{}".format(len(self.cat2id_dict)))
        print("cat2id.dict.items()[:50]:{}".format(list(self.cat2id_dict.items())[:50]))

    def map_cat2id(self, values, cats):
        """Convert one row (13 dense strings, 26 cat strings) to (ids, weights).

        Dense features keep their fixed column id and a max-scaled weight
        (0 when missing); categorical features map to their vocabulary id,
        falling back to the column's OOV id, with weight 1.0.
        """
        id_list = []
        weight_list = []
        for i, val in enumerate(values):
            key = "val_{}".format(i + 1)
            # Same id as the original's bare `i`: cat2id_dict[key] == i.
            id_list.append(self.cat2id_dict[key])
            if val == "":
                weight_list.append(0)
            else:
                max_v = float(self.val_max_dict[key])
                # Scale by the column max only (min is ignored).  Guard
                # against max == 0 (possible per __init__ init) to avoid
                # ZeroDivisionError; such a column degenerates to weight 0.
                weight_list.append(float(val) / max_v if max_v else 0.0)
        for i, cat_str in enumerate(cats):
            key = "cat_{}".format(i + 1) + "_" + cat_str
            if key in self.cat2id_dict:
                id_list.append(self.cat2id_dict[key])
            else:
                id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
            weight_list.append(1.0)
        return id_list, weight_list


def mkdir_path(file_path):
    """Create directory file_path (including parents) if it does not exist."""
    if not os.path.exists(file_path):
        os.makedirs(file_path)


def statsdata(file_path, dict_output_path, criteo_stats_dict):
    """Scan the raw train file once, collect statistics, and pickle them.

    Malformed lines (not exactly 40 tab-separated fields) are skipped and
    their 1-based line numbers recorded.
    """
    with open(file_path, encoding="utf-8") as file_in:
        errorline_list = []
        count = 0
        for line in file_in:
            count += 1
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                errorline_list.append(count)
                print("line: {}".format(line))
                continue
            if count % 1000000 == 0:
                # "w" is the unit wan (10,000), hence count // 10000.
                print("Have handled {}w lines.".format(count // 10000))
            values = items[1:14]
            cats = items[14:]

            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))
            criteo_stats_dict.stats_vals(values)
            criteo_stats_dict.stats_cats(cats)
    criteo_stats_dict.save_dict(dict_output_path)


def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000,
                                  line_per_sample=1000,
                                  test_size=0.1, seed=2020):
    """Shuffle-split the raw data into train/test MindRecord files.

    Lines are packed into samples of line_per_sample rows each; a sample is
    routed to the test set when the index of its last line falls into the
    randomly chosen test subset (test_size fraction of all line indices).
    Accumulated samples are flushed to disk every part_rows samples.
    """
    test_size = int(TRAIN_LINE_COUNT * test_size)
    all_indices = list(range(TRAIN_LINE_COUNT))
    np.random.seed(seed)
    np.random.shuffle(all_indices)
    print("all_indices.size:{}".format(len(all_indices)))
    test_indices_set = set(all_indices[:test_size])
    print("test_indices_set.size:{}".format(len(test_indices_set)))
    print("-----------------------" * 10 + "\n" * 2)

    train_data_list = []
    test_data_list = []
    ids_list = []
    wts_list = []
    label_list = []

    writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21)
    writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3)

    schema = {"label": {"type": "float32", "shape": [-1]}, "feat_vals": {"type": "float32", "shape": [-1]},
              "feat_ids": {"type": "int32", "shape": [-1]}}
    writer_train.add_schema(schema, "CRITEO_TRAIN")
    writer_test.add_schema(schema, "CRITEO_TEST")

    with open(input_file_path, encoding="utf-8") as file_in:
        items_error_size_lineCount = []
        count = 0
        train_part_number = 0
        test_part_number = 0
        for i, line in enumerate(file_in):
            count += 1
            if count % 1000000 == 0:
                # Fixed message typo ("Have handle" -> "Have handled").
                print("Have handled {}w lines.".format(count // 10000))
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                items_error_size_lineCount.append(i)
                continue
            label = float(items[0])
            values = items[1:14]
            cats = items[14:]

            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))

            ids, wts = criteo_stats_dict.map_cat2id(values, cats)

            ids_list.extend(ids)
            wts_list.extend(wts)
            label_list.append(label)

            if count % line_per_sample == 0:
                # One sample = line_per_sample concatenated rows.
                sample = {"feat_ids": np.array(ids_list, dtype=np.int32),
                          "feat_vals": np.array(wts_list, dtype=np.float32),
                          "label": np.array(label_list, dtype=np.float32)}
                if i not in test_indices_set:
                    train_data_list.append(sample)
                else:
                    test_data_list.append(sample)

                # Flush full parts to disk to bound memory usage.
                if train_data_list and len(train_data_list) % part_rows == 0:
                    writer_train.write_raw_data(train_data_list)
                    train_data_list.clear()
                    train_part_number += 1

                if test_data_list and len(test_data_list) % part_rows == 0:
                    writer_test.write_raw_data(test_data_list)
                    test_data_list.clear()
                    test_part_number += 1

                ids_list.clear()
                wts_list.clear()
                label_list.clear()

    # Write any remaining partial parts and finalize both files.
    if train_data_list:
        writer_train.write_raw_data(train_data_list)
    if test_data_list:
        writer_test.write_raw_data(test_data_list)
    writer_train.commit()
    writer_test.commit()

    print("-------------" * 10)
    print("items_error_size_lineCount.size(): {}.".format(len(items_error_size_lineCount)))
    print("-------------" * 10)
    np.save("items_error_size_lineCount.npy", items_error_size_lineCount)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="criteo data")
    parser.add_argument("--data_path", type=str, default="./criteo_data/")

    args, _ = parser.parse_known_args()
    data_path = args.data_path

    download_data_path = data_path + "origin_data/"
    mkdir_path(download_data_path)

    os.system(
        "wget -P {} -c https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz --no-check-certificate".format(
            download_data_path))
    # BUGFIX: extract into download_data_path (-C); the original extracted
    # into the CWD while the code below reads origin_data/train.txt.
    os.system("tar -zxvf {0}dac.tar.gz -C {0}".format(download_data_path))

    criteo_stats = CriteoStatsDict()
    data_file_path = data_path + "origin_data/train.txt"
    stats_output_path = data_path + "stats_dict/"
    mkdir_path(stats_output_path)
    statsdata(data_file_path, stats_output_path, criteo_stats)

    criteo_stats.load_dict(dict_path=stats_output_path, prefix="")
    criteo_stats.get_cat2id(threshold=100)

    in_file_path = data_path + "origin_data/train.txt"
    output_path = data_path + "mindrecord/"
    mkdir_path(output_path)
    random_split_trans2mindrecord(in_file_path, output_path, criteo_stats, part_rows=2000000, line_per_sample=1000,
                                  test_size=0.1, seed=2020)