diff --git a/example/deepfm_criteo/README.md b/example/deepfm_criteo/README.md
new file mode 100644
index 0000000000..47809a54c0
--- /dev/null
+++ b/example/deepfm_criteo/README.md
@@ -0,0 +1,132 @@
+# DeepFM Description
+
+This is an example of training DeepFM on the Criteo dataset with MindSpore.
+
+[Paper](https://arxiv.org/pdf/1703.04247.pdf): Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction"
+
+
+# Model architecture
+
+The overall network architecture of DeepFM is shown below:
+
+[Link](https://arxiv.org/pdf/1703.04247.pdf)
+
+
+# Requirements
+- Install [MindSpore](https://www.mindspore.cn/install/en).
+- Download the Criteo dataset and convert it to the MindRecord, TFRecord, or H5 format expected by `src/dataset.py`, then move the files to a specified path.
+- For more information, please check the resources below:
+  - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
+  - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
+
+# Script description
+
+## Script and sample code
+
+```
+├── deepfm
+  ├── README.md
+  ├── scripts
+  │   ├── run_distribute_train.sh
+  │   ├── run_standalone_train.sh
+  │   ├── run_eval.sh
+  ├── src
+  │   ├── config.py
+  │   ├── dataset.py
+  │   ├── callback.py
+  │   ├── deepfm.py
+  ├── train.py
+  ├── eval.py
+```
+
+## Training process
+
+### Usage
+
+- sh scripts/run_distribute_train.sh [DEVICE_NUM] [DATASET_PATH] [MINDSPORE_HCCL_CONFIG_PATH]
+- sh scripts/run_standalone_train.sh [DEVICE_ID] [DATASET_PATH]
+- python train.py --dataset_path [DATASET_PATH]
+
+### Launch
+
+```
+# distribute training example
+  sh scripts/run_distribute_train.sh 8 /opt/dataset/criteo /opt/mindspore_hccl_file.json
+# standalone training example
+  sh scripts/run_standalone_train.sh 0 /opt/dataset/criteo
+  or
+  python train.py --dataset_path /opt/dataset/criteo > output.log 2>&1 &
+```
+
+### Result
+
+Training results are stored in the example path. By default, checkpoints are saved under `./checkpoint`,
+the training log is redirected to `./output.log`, the loss log to `./loss.log`,
+and the eval log to `./auc.log`.
+
+
+## Eval process
+
+### Usage
+
+- sh scripts/run_eval.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]
+- python eval.py --dataset_path [DATASET_PATH] --checkpoint_path [CHECKPOINT_PATH]
+
+### Launch
+
+```
+# infer example
+  sh scripts/run_eval.sh 0 ~/criteo/eval/ ~/train/deepfm-15_41257.ckpt
+```
+
+> Note: the checkpoint is produced during the training process.
+
+### Result
+
+The inference result is stored in the example path; you can find results like the following in `auc.log`.
+
+```
+2020-05-27 20:51:35 AUC: 0.80577889065281, eval time: 35.55999s.
+```
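+
+The reported score is the ROC AUC over the whole eval set: `AUCMetric` in `src/deepfm.py` accumulates the sigmoid outputs and labels batch by batch and hands them to scikit-learn. A minimal sketch of the equivalent computation (the arrays here are toy stand-ins, not from the source):
+
+```python
+import numpy as np
+from sklearn.metrics import roc_auc_score
+
+pred_probs = np.array([0.9, 0.2, 0.7, 0.4])    # sigmoid outputs over eval batches
+true_labels = np.array([1, 0, 1, 0])           # matching 0/1 click labels
+print(roc_auc_score(true_labels, pred_probs))  # 1.0 on this toy data
+```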
+
+# Model description
+
+## Performance
+
+### Training Performance
+
+| Parameters                 | DeepFM                                                 |
+| -------------------------- | ------------------------------------------------------ |
+| Model Version              |                                                        |
+| Resource                   | Ascend 910; CPU 2.60GHz, 96 cores; memory 1.5T         |
+| Uploaded Date              | 05/27/2020                                             |
+| MindSpore Version          | 0.2.0                                                  |
+| Dataset                    | Criteo                                                 |
+| Training Parameters        | src/config.py                                          |
+| Optimizer                  | Adam                                                   |
+| Loss Function              | SigmoidCrossEntropyWithLogits                          |
+| Outputs                    |                                                        |
+| Loss                       | 0.4234                                                 |
+| Accuracy                   | AUC[0.8055]                                            |
+| Total time                 | 91 min                                                 |
+| Params (M)                 |                                                        |
+| Checkpoint for Fine tuning |                                                        |
+| Model for inference        |                                                        |
+
+### Inference Performance
+
+| Parameters          |                               |                           |
+| ------------------- | ----------------------------- | ------------------------- |
+| Model Version       |                               |                           |
+| Resource            | Ascend 910                    | Ascend 310                |
+| Uploaded Date       | 05/27/2020                    | 05/27/2020                |
+| MindSpore Version   | 0.2.0                         | 0.2.0                     |
+| Dataset             | Criteo                        |                           |
+| batch_size          | 1000                          |                           |
+| Outputs             |                               |                           |
+| Accuracy            | AUC[0.8055]                   |                           |
+| Speed               |                               |                           |
+| Total time          | 35.559s                       |                           |
+| Model for inference |                               |                           |
+
+# ModelZoo Homepage
+ [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo)
diff --git a/example/deepfm_criteo/__init__.py b/example/deepfm_criteo/__init__.py
new file mode 100644
index 0000000000..301ef9dcb7
--- /dev/null
+++ b/example/deepfm_criteo/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/example/deepfm_criteo/eval.py b/example/deepfm_criteo/eval.py
new file mode 100644
index 0000000000..0452f73d23
--- /dev/null
+++ b/example/deepfm_criteo/eval.py
@@ -0,0 +1,66 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""eval_criteo."""
+import os
+import sys
+import time
+import argparse
+
+from mindspore import context
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.deepfm import ModelBuilder, AUCMetric
+from src.config import DataConfig, ModelConfig, TrainConfig
+from src.dataset import create_dataset, DataType
+
+parser = argparse.ArgumentParser(description='CTR Prediction')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+
+args_opt, _ = parser.parse_known_args()
+device_id = int(os.getenv('DEVICE_ID', '0'))
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
+
+
+def add_write(file_path, print_str):
+    with open(file_path, 'a+', encoding='utf-8') as file_out:
+        file_out.write(print_str + '\n')
+
+
+if __name__ == '__main__':
+    data_config = DataConfig()
+    model_config = ModelConfig()
+    train_config = TrainConfig()
+
+    # Use the data format configured in DataConfig so eval reads the same
+    # backend as training.
+    ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
+                             epochs=1, batch_size=train_config.batch_size,
+                             data_type=DataType(data_config.data_format))
+    model_builder = ModelBuilder(model_config, train_config)
+    train_net, eval_net = model_builder.get_train_eval_net()
+    train_net.set_train()
+    eval_net.set_train(False)
+    auc_metric = AUCMetric()
+    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
+    param_dict = load_checkpoint(args_opt.checkpoint_path)
+    load_param_into_net(eval_net, param_dict)
+
+    start = time.time()
+    res = model.eval(ds_eval)
+    eval_time = time.time() - start
+    time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+    out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
+    print(out_str)
+    add_write('./auc.log', str(out_str))
diff --git a/example/deepfm_criteo/scripts/run_distribute_train.sh b/example/deepfm_criteo/scripts/run_distribute_train.sh
new file mode 100644
index 0000000000..a19c857936
--- /dev/null
+++ b/example/deepfm_criteo/scripts/run_distribute_train.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "Please run the script as: "
+echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PATH"
+echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
+echo "After running the script, the network runs in the background. The log will be generated in logx/output.log"
+
+
+export RANK_SIZE=$1
+DATA_URL=$2
+export MINDSPORE_HCCL_CONFIG_PATH=$3
+
+for ((i=0; i<RANK_SIZE; i++))
+do
+    export DEVICE_ID=$i
+    export RANK_ID=$i
+    rm -rf log$i
+    mkdir ./log$i
+    cp *.py ./log$i
+    cp -r src ./log$i
+    cd ./log$i || exit
+    echo "start training for rank $i, device $DEVICE_ID"
+    env > env.log
+    python -u train.py \
+        --dataset_path=$DATA_URL \
+        --ckpt_path="checkpoint" \
+        --eval_file_name='auc.log' \
+        --loss_file_name='loss.log' \
+        --do_eval=True > output.log 2>&1 &
+    cd ../
+done
diff --git a/example/deepfm_criteo/scripts/run_eval.sh b/example/deepfm_criteo/scripts/run_eval.sh
new file mode 100644
index 0000000000..aa5765da31
--- /dev/null
+++ b/example/deepfm_criteo/scripts/run_eval.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "Please run the script as: "
+echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH"
+echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path"
+echo "After running the script, the network runs in the background. The log will be generated in ms_log/eval_output.log"
+
+export DEVICE_ID=$1
+DATA_URL=$2
+CHECKPOINT_PATH=$3
+
+mkdir -p ms_log
+CUR_DIR=`pwd`
+export GLOG_log_dir=${CUR_DIR}/ms_log
+export GLOG_logtostderr=0
+
+python -u eval.py \
+    --dataset_path=$DATA_URL \
+    --checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 &
\ No newline at end of file
diff --git a/example/deepfm_criteo/scripts/run_standalone_train.sh b/example/deepfm_criteo/scripts/run_standalone_train.sh
new file mode 100644
index 0000000000..fa22b82d3d
--- /dev/null
+++ b/example/deepfm_criteo/scripts/run_standalone_train.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "Please run the script as: "
+echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH"
+echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path"
+echo "After running the script, the network runs in the background. The log will be generated in ms_log/output.log"
+
+export DEVICE_ID=$1
+DATA_URL=$2
+
+mkdir -p ms_log
+CUR_DIR=`pwd`
+export GLOG_log_dir=${CUR_DIR}/ms_log
+export GLOG_logtostderr=0
+
+python -u train.py \
+    --dataset_path=$DATA_URL \
+    --ckpt_path="checkpoint" \
+    --eval_file_name='auc.log' \
+    --loss_file_name='loss.log' \
+    --do_eval=True > ms_log/output.log 2>&1 &
diff --git a/example/deepfm_criteo/src/__init__.py b/example/deepfm_criteo/src/__init__.py
new file mode 100644
index 0000000000..301ef9dcb7
--- /dev/null
+++ b/example/deepfm_criteo/src/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/example/deepfm_criteo/src/callback.py b/example/deepfm_criteo/src/callback.py
new file mode 100644
index 0000000000..fce7b29fa6
--- /dev/null
+++ b/example/deepfm_criteo/src/callback.py
@@ -0,0 +1,107 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Defined callbacks for DeepFM.
+"""
+import time
+from mindspore.train.callback import Callback
+
+
+def add_write(file_path, out_str):
+    with open(file_path, 'a+', encoding='utf-8') as file_out:
+        file_out.write(out_str + '\n')
+
+
+class EvalCallBack(Callback):
+    """
+    Run evaluation at the end of each training epoch and append the metric
+    result to eval_file_path.
+    """
+ """ + def __init__(self, model, eval_dataset, auc_metric, eval_file_path): + super(EvalCallBack, self).__init__() + self.model = model + self.eval_dataset = eval_dataset + self.aucMetric = auc_metric + self.aucMetric.clear() + self.eval_file_path = eval_file_path + + def epoch_end(self, run_context): + start_time = time.time() + out = self.model.eval(self.eval_dataset) + eval_time = int(time.time() - start_time) + time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + out_str = "{} EvalCallBack metric{}; eval_time{}s".format( + time_str, out.values(), eval_time) + print(out_str) + add_write(self.eval_file_path, out_str) + + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss is NAN or INF terminating training. + Note + If per_print_times is 0 do not print loss. + Args + loss_file_path (str) The file absolute path, to save as loss_file; + per_print_times (int) Print loss every times. Default 1. + """ + def __init__(self, loss_file_path, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self.loss_file_path = loss_file_path + self._per_print_times = per_print_times + + def step_end(self, run_context): + cb_params = run_context.original_args() + loss = cb_params.net_outputs.asnumpy() + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + cur_num = cb_params.cur_step_num + if self._per_print_times != 0 and cur_num % self._per_print_times == 0: + with open(self.loss_file_path, "a+") as loss_file: + time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + loss_file.write("{} epoch: {} step: {}, loss is {}\n".format( + time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss)) + print("epoch: {} step: {}, loss is {}\n".format( + cb_params.cur_epoch_num, cur_step_in_epoch, loss)) + + +class TimeMonitor(Callback): + """ + Time monitor for calculating cost of each epoch. + Args + data_size (int) step size of an epoch. + """ + def __init__(self, data_size): + super(TimeMonitor, self).__init__() + self.data_size = data_size + + def epoch_begin(self, run_context): + self.epoch_time = time.time() + + def epoch_end(self, run_context): + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / self.data_size + print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + step_mseconds = (time.time() - self.step_time) * 1000 + print(f"step time {step_mseconds}", flush=True) diff --git a/example/deepfm_criteo/src/config.py b/example/deepfm_criteo/src/config.py new file mode 100644 index 0000000000..14a6daefb7 --- /dev/null +++ b/example/deepfm_criteo/src/config.py @@ -0,0 +1,62 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""
+Network config settings, used in train.py and eval.py.
+"""
+
+
+class DataConfig:
+    """
+    Define parameters of dataset.
+    """
+    data_vocab_size = 184965
+    train_num_of_parts = 21
+    test_num_of_parts = 3
+    batch_size = 1000
+    data_field_size = 39
+    # dataset format, 1: mindrecord, 2: tfrecord, 3: h5
+    data_format = 2
+
+
+class ModelConfig:
+    """
+    Define parameters of model.
+    """
+    batch_size = DataConfig.batch_size
+    data_field_size = DataConfig.data_field_size
+    data_vocab_size = DataConfig.data_vocab_size
+    data_emb_dim = 80
+    deep_layer_args = [[400, 400, 512], "relu"]
+    init_args = [-0.01, 0.01]
+    weight_bias_init = ['normal', 'normal']
+    keep_prob = 0.9
+
+
+class TrainConfig:
+    """
+    Define parameters of training.
+    """
+    batch_size = DataConfig.batch_size
+    l2_coef = 1e-6
+    learning_rate = 1e-5
+    epsilon = 1e-8
+    loss_scale = 1024.0
+    train_epochs = 15
+    save_checkpoint = True
+    ckpt_file_name_prefix = "deepfm"
+    save_checkpoint_steps = 1
+    keep_checkpoint_max = 15
+    eval_callback = True
+    loss_callback = True
+    # Referenced by ModelBuilder.get_callback_list() in src/deepfm.py;
+    # the values mirror the defaults used in train.py and the README.
+    output_path = "./checkpoint/"
+    eval_file_name = "auc.log"
+    loss_file_name = "loss.log"
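+
+
+# Usage sketch (mirrors train.py; shown as a comment for reference only, it is
+# not executed as part of the configs):
+#
+#     from src.config import DataConfig, TrainConfig
+#     from src.dataset import create_dataset, DataType
+#
+#     ds = create_dataset(dataset_path, train_mode=True,
+#                         epochs=TrainConfig.train_epochs,
+#                         batch_size=TrainConfig.batch_size,
+#                         data_type=DataType(DataConfig.data_format))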
+ """ + max_length = 39 + + def __init__(self, data_path, train_mode=True, + train_num_of_parts=DataConfig.train_num_of_parts, + test_num_of_parts=DataConfig.test_num_of_parts): + self._hdf_data_dir = data_path + self._is_training = train_mode + if self._is_training: + self._file_prefix = 'train' + self._num_of_parts = train_num_of_parts + else: + self._file_prefix = 'test' + self._num_of_parts = test_num_of_parts + self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, self._num_of_parts) + print("data_size: {}".format(self.data_size)) + + def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts): + size = 0 + for part in range(num_of_parts): + _y = pd.read_hdf(os.path.join(hdf_data_dir, f'{file_prefix}_output_part_{str(part)}.h5')) + size += _y.shape[0] + return size + + def _iterate_hdf_files_(self, num_of_parts=None, + shuffle_block=False): + """ + iterate among hdf files(blocks). when the whole data set is finished, the iterator restarts + from the beginning, thus the data stream will never stop + :param train_mode: True or false,false is eval_mode, + this file iterator will go through the train set + :param num_of_parts: number of files + :param shuffle_block: shuffle block files at every round + :return: input_hdf_file_name, output_hdf_file_name, finish_flag + """ + parts = np.arange(num_of_parts) + while True: + if shuffle_block: + for _ in range(int(shuffle_block)): + np.random.shuffle(parts) + for i, p in enumerate(parts): + yield os.path.join(self._hdf_data_dir, f'{self._file_prefix}_input_part_{str(p)}.h5'), \ + os.path.join(self._hdf_data_dir, f'{self._file_prefix}_output_part_{str(p)}.h5'), \ + i + 1 == len(parts) + + def _generator(self, X, y, batch_size, shuffle=True): + """ + should be accessed only in private + :param X: + :param y: + :param batch_size: + :param shuffle: + :return: + """ + number_of_batches = np.ceil(1. * X.shape[0] / batch_size) + counter = 0 + finished = False + sample_index = np.arange(X.shape[0]) + if shuffle: + for _ in range(int(shuffle)): + np.random.shuffle(sample_index) + assert X.shape[0] > 0 + while True: + batch_index = sample_index[batch_size * counter: batch_size * (counter + 1)] + X_batch = X[batch_index] + y_batch = y[batch_index] + counter += 1 + yield X_batch, y_batch, finished + if counter == number_of_batches: + counter = 0 + finished = True + + def batch_generator(self, batch_size=1000, + random_sample=False, shuffle_block=False): + """ + :param train_mode: True or false,false is eval_mode, + :param batch_size + :param num_of_parts: number of files + :param random_sample: if True, will shuffle + :param shuffle_block: shuffle file blocks at every round + :return: + """ + + for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts, + shuffle_block): + start = stop = None + X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values + y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values + data_gen = self._generator(X_all, y_all, batch_size, + shuffle=random_sample) + finished = False + + while not finished: + X, y, finished = data_gen.__next__() + X_id = X[:, 0:self.max_length] + X_va = X[:, self.max_length:] + yield np.array(X_id.astype(dtype=np.int32)), \ + np.array(X_va.astype(dtype=np.float32)), \ + np.array(y.astype(dtype=np.float32)) + + +def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000): + """ + Get dataset with h5 format. + + Args: + directory (str): Dataset directory. + train_mode (bool): Whether dataset is use for train or eval (default=True). 
+
+
+def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000):
+    """
+    Get dataset with h5 format.
+
+    Args:
+        directory (str): Dataset directory.
+        train_mode (bool): Whether the dataset is used for train or eval (default=True).
+        epochs (int): Dataset epoch size (default=1).
+        batch_size (int): Dataset batch size (default=1000).
+
+    Returns:
+        Dataset.
+    """
+    data_para = {'batch_size': batch_size}
+    if train_mode:
+        data_para['random_sample'] = True
+        data_para['shuffle_block'] = True
+
+    h5_dataset = H5Dataset(data_path=directory, train_mode=train_mode)
+    numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)
+
+    def _iter_h5_data():
+        train_eval_gen = h5_dataset.batch_generator(**data_para)
+        for _ in range(0, numbers_of_batch, 1):
+            yield train_eval_gen.__next__()
+
+    ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"])
+    ds.set_dataset_size(numbers_of_batch)
+    ds = ds.repeat(epochs)
+    return ds
+
+
+def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
+                            line_per_sample=1000, rank_size=None, rank_id=None):
+    """
+    Get dataset with mindrecord format.
+
+    Args:
+        directory (str): Dataset directory.
+        train_mode (bool): Whether the dataset is used for train or eval (default=True).
+        epochs (int): Dataset epoch size (default=1).
+        batch_size (int): Dataset batch size (default=1000).
+        line_per_sample (int): The number of samples per record line (default=1000).
+        rank_size (int): The number of devices; not needed for a single device (default=None).
+        rank_id (int): ID of the device; not needed for a single device (default=None).
+
+    Returns:
+        Dataset.
+    """
+    file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
+    file_suffix_name = '00' if train_mode else '0'
+    shuffle = train_mode
+
+    if rank_size is not None and rank_id is not None:
+        ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
+                            columns_list=['feat_ids', 'feat_vals', 'label'],
+                            num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
+                            num_parallel_workers=8)
+    else:
+        ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
+                            columns_list=['feat_ids', 'feat_vals', 'label'],
+                            shuffle=shuffle, num_parallel_workers=8)
+    # Each record packs line_per_sample samples, so batching
+    # batch_size / line_per_sample records and flattening yields exactly
+    # batch_size samples per step.
+    ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
+    ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
+                                             np.array(y).flatten().reshape(batch_size, 39),
+                                             np.array(z).flatten().reshape(batch_size, 1))),
+                input_columns=['feat_ids', 'feat_vals', 'label'],
+                columns_order=['feat_ids', 'feat_vals', 'label'],
+                num_parallel_workers=8)
+    ds = ds.repeat(epochs)
+    return ds
+ """ + dataset_files = [] + file_prefixt_name = 'train' if train_mode else 'test' + shuffle = train_mode + for (dir_path, _, filenames) in os.walk(directory): + for filename in filenames: + if file_prefixt_name in filename and 'tfrecord' in filename: + dataset_files.append(os.path.join(dir_path, filename)) + schema = de.Schema() + schema.add_column('feat_ids', de_type=mstype.int32) + schema.add_column('feat_vals', de_type=mstype.float32) + schema.add_column('label', de_type=mstype.float32) + if rank_size is not None and rank_id is not None: + ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, + schema=schema, num_parallel_workers=8, + num_shards=rank_size, shard_id=rank_id, + shard_equal_rows=True) + else: + ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, + schema=schema, num_parallel_workers=8) + ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True) + ds = ds.map(operations=(lambda x, y, z: ( + np.array(x).flatten().reshape(batch_size, 39), + np.array(y).flatten().reshape(batch_size, 39), + np.array(z).flatten().reshape(batch_size, 1))), + input_columns=['feat_ids', 'feat_vals', 'label'], + columns_order=['feat_ids', 'feat_vals', 'label'], + num_parallel_workers=8) + ds = ds.repeat(epochs) + return ds + + +def create_dataset(directory, train_mode=True, epochs=1, batch_size=1000, + data_type=DataType.TFRECORD, line_per_sample=1000, + rank_size=None, rank_id=None): + """ + Get dataset. + + Args: + directory (str): Dataset directory. + train_mode (bool): Whether dataset is use for train or eval (default=True). + epochs (int): Dataset epoch size (default=1). + batch_size (int): Dataset batch size (default=1000). + data_type (DataType): The type of dataset which is one of H5, TFRECORE, MINDRECORD (default=TFRECORD). + line_per_sample (int): The number of sample per line (default=1000). + rank_size (int): The number of device, not necessary for single device (default=None). + rank_id (int): Id of device, not necessary for single device (default=None). + + Returns: + Dataset. + """ + if data_type == DataType.MINDRECORD: + return _get_mindrecord_dataset(directory, train_mode, epochs, + batch_size, line_per_sample, + rank_size, rank_id) + if data_type == DataType.TFRECORD: + return _get_tf_dataset(directory, train_mode, epochs, batch_size, + line_per_sample, rank_size=rank_size, rank_id=rank_id) + + if rank_size is not None and rank_size > 1: + raise ValueError('Please use mindrecord dataset.') + return _get_h5_dataset(directory, train_mode, epochs, batch_size) diff --git a/example/deepfm_criteo/src/deepfm.py b/example/deepfm_criteo/src/deepfm.py new file mode 100644 index 0000000000..0fbe3afa49 --- /dev/null +++ b/example/deepfm_criteo/src/deepfm.py @@ -0,0 +1,370 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/example/deepfm_criteo/src/deepfm.py b/example/deepfm_criteo/src/deepfm.py
new file mode 100644
index 0000000000..0fbe3afa49
--- /dev/null
+++ b/example/deepfm_criteo/src/deepfm.py
@@ -0,0 +1,370 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""DeepFM model definition: network, loss, training wrapper and builder."""
+import os
+
+import numpy as np
+from sklearn.metrics import roc_auc_score
+import mindspore.common.dtype as mstype
+from mindspore.ops import functional as F
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.nn import Dropout
+from mindspore.nn.optim import Adam
+from mindspore.nn.metrics import Metric
+from mindspore import nn, ParameterTuple, Parameter
+from mindspore.common.initializer import Uniform, initializer, Normal
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+
+from .callback import EvalCallBack, LossCallBack
+
+
+np_type = np.float32
+ms_type = mstype.float32
+
+
+class AUCMetric(Metric):
+    """AUC metric for the DeepFM model."""
+    def __init__(self):
+        super(AUCMetric, self).__init__()
+        self.pred_probs = []
+        self.true_labels = []
+
+    def clear(self):
+        """Clear the internal evaluation result."""
+        self.pred_probs = []
+        self.true_labels = []
+
+    def update(self, *inputs):
+        batch_predict = inputs[1].asnumpy()
+        batch_label = inputs[2].asnumpy()
+        self.pred_probs.extend(batch_predict.flatten().tolist())
+        self.true_labels.extend(batch_label.flatten().tolist())
+
+    def eval(self):
+        if len(self.true_labels) != len(self.pred_probs):
+            raise RuntimeError('true_labels.size() is not equal to pred_probs.size()')
+        auc = roc_auc_score(self.true_labels, self.pred_probs)
+        return auc
+
+
+def init_method(method, shape, name, max_val=0.01):
+    """
+    Initialize a parameter with the given method.
+
+    Args:
+        method (str): The method used to initialize the parameter.
+        shape (list): The shape of the parameter.
+        name (str): The name of the parameter.
+        max_val (float): Bound used by 'random'/'uniform' init, and sigma of 'normal' init.
+
+    Returns:
+        Parameter.
+    """
+    if method in ['random', 'uniform']:
+        params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
+    elif method == "one":
+        params = Parameter(initializer("ones", shape, ms_type), name=name)
+    elif method == 'zero':
+        params = Parameter(initializer("zeros", shape, ms_type), name=name)
+    elif method == "normal":
+        params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
+    else:
+        raise ValueError("Unsupported init method: {}".format(method))
+    return params
+ """ + var_map = {} + _, _max_val = init_args + for key, shape, init_flag in values: + if key not in var_map.keys(): + if init_flag in ['random', 'uniform']: + var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key) + elif init_flag == "one": + var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key) + elif init_flag == "zero": + var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key) + elif init_flag == 'normal': + var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key) + return var_map + + +class DenseLayer(nn.Cell): + """ + Dense Layer for Deep Layer of DeepFM Model; + Containing: activation, matmul, bias_add; + Args: + input_dim (int): the shape of weight at 0-aixs; + output_dim (int): the shape of weight at 1-aixs, and shape of bias + weight_bias_init (list): weight and bias init method, "random", "uniform", "one", "zero", "normal"; + act_str (str): activation function method, "relu", "sigmoid", "tanh"; + keep_prob (float): Dropout Layer keep_prob_rate; + scale_coef (float): input scale coefficient; + """ + def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0): + super(DenseLayer, self).__init__() + weight_init, bias_init = weight_bias_init + self.weight = init_method(weight_init, [input_dim, output_dim], name="weight") + self.bias = init_method(bias_init, [output_dim], name="bias") + self.act_func = self._init_activation(act_str) + self.matmul = P.MatMul(transpose_b=False) + self.bias_add = P.BiasAdd() + self.cast = P.Cast() + self.dropout = Dropout(keep_prob=keep_prob) + self.mul = P.Mul() + self.realDiv = P.RealDiv() + self.scale_coef = scale_coef + + def _init_activation(self, act_str): + act_str = act_str.lower() + if act_str == "relu": + act_func = P.ReLU() + elif act_str == "sigmoid": + act_func = P.Sigmoid() + elif act_str == "tanh": + act_func = P.Tanh() + return act_func + + def construct(self, x): + x = self.act_func(x) + if self.training: + x = self.dropout(x) + x = self.mul(x, self.scale_coef) + x = self.cast(x, mstype.float16) + weight = self.cast(self.weight, mstype.float16) + wx = self.matmul(x, weight) + wx = self.cast(wx, mstype.float32) + wx = self.realDiv(wx, self.scale_coef) + output = self.bias_add(wx, self.bias) + return output + + +class DeepFMModel(nn.Cell): + """ + From paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction" + + Args: + batch_size (int): smaple_number of per step in training; (int, batch_size=128) + filed_size (int): input filed number, or called id_feature number; (int, filed_size=39) + vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000) + emb_dim (int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100) + deep_layer_args (list): Deep Layer args, layer_dim_list, layer_activator; + (int, deep_layer_args=[[100, 100, 100], "relu"]) + init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds]) + weight_bias_init (list): weight, bias init method for deep layers; + (list[str], weight_bias_init=['random', 'zero']) + keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8) + """ + def __init__(self, config): + super(DeepFMModel, self).__init__() + + self.batch_size = config.batch_size + self.field_size = config.data_field_size + self.vocab_size = config.data_vocab_size + self.emb_dim = config.data_emb_dim + self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args + 
+
+
+class DeepFMModel(nn.Cell):
+    """
+    From the paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction".
+
+    Args:
+        batch_size (int): number of samples per step in training; (int, batch_size=128)
+        field_size (int): input field number, also called the id_feature number; (int, field_size=39)
+        vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000)
+        emb_dim (int): id embedding vector dim, each id is mapped to an embedding vector; (int, emb_dim=100)
+        deep_layer_args (list): deep layer args, layer_dim_list, layer_activator;
+                                (int, deep_layer_args=[[100, 100, 100], "relu"])
+        init_args (list): init args for Parameter init; (list, init_args=[min, max])
+        weight_bias_init (list): weight, bias init method for deep layers;
+                                 (list[str], weight_bias_init=['random', 'zero'])
+        keep_prob (float): if dropout_flag is True, keep_prob rate of connections to keep; (float, keep_prob=0.8)
+    """
+    def __init__(self, config):
+        super(DeepFMModel, self).__init__()
+
+        self.batch_size = config.batch_size
+        self.field_size = config.data_field_size
+        self.vocab_size = config.data_vocab_size
+        self.emb_dim = config.data_emb_dim
+        self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args
+        self.init_args = config.init_args
+        self.weight_bias_init = config.weight_bias_init
+        self.keep_prob = config.keep_prob
+        init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
+                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal'),
+                     ('b', [1], 'normal')]
+        var_map = init_var_dict(self.init_args, init_acts)
+        self.fm_w = var_map["W_l2"]
+        self.fm_b = var_map["b"]
+        self.embedding_table = var_map["V_l2"]
+        # Deep Layers
+        self.deep_input_dims = self.field_size * self.emb_dim + 1
+        self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
+        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
+                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
+        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
+                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
+        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
+                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
+        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
+                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
+        # FM, linear Layers
+        self.Gatherv2 = P.GatherV2()
+        self.Mul = P.Mul()
+        self.ReduceSum = P.ReduceSum(keep_dims=False)
+        self.Reshape = P.Reshape()
+        self.Square = P.Square()
+        self.Shape = P.Shape()
+        self.Tile = P.Tile()
+        self.Concat = P.Concat(axis=1)
+        self.Cast = P.Cast()
+
+    def construct(self, id_hldr, wt_hldr):
+        """
+        Args:
+            id_hldr: batch ids;     [bs, field_size]
+            wt_hldr: batch weights; [bs, field_size]
+        """
+        mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1))
+        # Linear layer
+        fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0)
+        wx = self.Mul(fm_id_weight, mask)
+        linear_out = self.ReduceSum(wx, 1)
+        # FM layer: 0.5 * ((sum_i v_i x_i)^2 - sum_i (v_i x_i)^2)
+        fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0)
+        vx = self.Mul(fm_id_embs, mask)
+        v1 = self.ReduceSum(vx, 1)
+        v1 = self.Square(v1)
+        v2 = self.Square(vx)
+        v2 = self.ReduceSum(v2, 1)
+        fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
+        fm_out = self.Reshape(fm_out, (-1, 1))
+        # Deep layer
+        b = self.Reshape(self.fm_b, (1, 1))
+        b = self.Tile(b, (self.batch_size, 1))
+        deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
+        deep_in = self.Concat((deep_in, b))
+        deep_in = self.dense_layer_1(deep_in)
+        deep_in = self.dense_layer_2(deep_in)
+        deep_in = self.dense_layer_3(deep_in)
+        deep_out = self.dense_layer_4(deep_in)
+        out = linear_out + fm_out + deep_out
+        return out, fm_id_weight, fm_id_embs
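+
+
+def _fm_reference_check():
+    """Reference sketch, not called by the network: checks with NumPy that the
+    fused FM term computed in DeepFMModel.construct,
+    0.5 * sum_k((sum_i v_ik x_i)^2 - sum_i (v_ik x_i)^2),
+    equals the pairwise form sum_{i<j} <v_i x_i, v_j x_j> from the FM paper.
+    """
+    vx = np.random.randn(39, 80)  # one sample: per-field embedding * feature value
+    fused = 0.5 * (np.square(vx.sum(0)) - np.square(vx).sum(0)).sum()
+    pairwise = sum(vx[i] @ vx[j] for i in range(39) for j in range(i + 1, 39))
+    assert np.allclose(fused, pairwise)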
+ """ + def __init__(self, network, l2_coef=1e-6): + super(NetWithLossClass, self).__init__(auto_prefix=False) + self.loss = P.SigmoidCrossEntropyWithLogits() + self.network = network + self.l2_coef = l2_coef + self.Square = P.Square() + self.ReduceMean_false = P.ReduceMean(keep_dims=False) + self.ReduceSum_false = P.ReduceSum(keep_dims=False) + + def construct(self, batch_ids, batch_wts, label): + predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts) + log_loss = self.loss(predict, label) + mean_log_loss = self.ReduceMean_false(log_loss) + l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight)) + l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs)) + l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5 + loss = mean_log_loss + l2_loss_all + return loss + + +class TrainStepWrap(nn.Cell): + """ + TrainStepWrap definition + """ + def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0): + super(TrainStepWrap, self).__init__(auto_prefix=False) + self.network = network + self.network.set_train() + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale) + self.hyper_map = C.HyperMap() + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = loss_scale + + def construct(self, batch_ids, batch_wts, label): + weights = self.weights + loss = self.network(batch_ids, batch_wts, label) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) # + grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens) + return F.depend(loss, self.optimizer(grads)) + + +class PredictWithSigmoid(nn.Cell): + """ + Eval model with sigmoid. + """ + def __init__(self, network): + super(PredictWithSigmoid, self).__init__(auto_prefix=False) + self.network = network + self.sigmoid = P.Sigmoid() + + def construct(self, batch_ids, batch_wts, labels): + logits, _, _, = self.network(batch_ids, batch_wts) + pred_probs = self.sigmoid(logits) + + return logits, pred_probs, labels + + +class ModelBuilder: + """ + Model builder for DeepFM. + + Args: + model_config (ModelConfig): Model configuration. + train_config (TrainConfig): Train configuration. + """ + def __init__(self, model_config, train_config): + self.model_config = model_config + self.train_config = train_config + + def get_callback_list(self, model=None, eval_dataset=None): + """ + Get callbacks which contains checkpoint callback, eval callback and loss callback. + + Args: + model (Cell): The network is added callback (default=None). + eval_dataset (Dataset): Dataset for eval (default=None). 
+ """ + callback_list = [] + if self.train_config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps, + keep_checkpoint_max=self.train_config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix, + directory=self.train_config.output_path, + config=config_ck) + callback_list.append(ckpt_cb) + if self.train_config.eval_callback: + if model is None: + raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format( + self.train_config.eval_callback, model)) + if eval_dataset is None: + raise RuntimeError("train_config.eval_callback is {}; get_callback_list() " + "args eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset)) + auc_metric = AUCMetric() + eval_callback = EvalCallBack(model, eval_dataset, auc_metric, + eval_file_path=os.path.join(self.train_config.output_path, + self.train_config.eval_file_name)) + callback_list.append(eval_callback) + if self.train_config.loss_callback: + loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path, + self.train_config.loss_file_name)) + callback_list.append(loss_callback) + if callback_list: + return callback_list + return None + + def get_train_eval_net(self): + deepfm_net = DeepFMModel(self.model_config) + loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef) + train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate, + eps=self.train_config.epsilon, + loss_scale=self.train_config.loss_scale) + eval_net = PredictWithSigmoid(deepfm_net) + return train_net, eval_net diff --git a/example/deepfm_criteo/train.py b/example/deepfm_criteo/train.py new file mode 100644 index 0000000000..228d04c0d3 --- /dev/null +++ b/example/deepfm_criteo/train.py @@ -0,0 +1,91 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""train_criteo."""
+import os
+import sys
+import argparse
+
+from mindspore import context, ParallelMode
+from mindspore.communication.management import init
+from mindspore.train.model import Model
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.deepfm import ModelBuilder, AUCMetric
+from src.config import DataConfig, ModelConfig, TrainConfig
+from src.dataset import create_dataset, DataType
+from src.callback import EvalCallBack, LossCallBack
+
+parser = argparse.ArgumentParser(description='CTR Prediction')
+parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
+parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path')
+parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path')
+# argparse turns any non-empty string (including "False") into True when
+# type=bool, so take a string and compare it explicitly below.
+parser.add_argument('--do_eval', type=str, default='True',
+                    help="Whether to evaluate after each epoch: 'True' or 'False'.")
+
+args_opt, _ = parser.parse_known_args()
+device_id = int(os.getenv('DEVICE_ID', '0'))
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
+
+
+if __name__ == '__main__':
+    data_config = DataConfig()
+    model_config = ModelConfig()
+    train_config = TrainConfig()
+
+    rank_size = int(os.environ.get("RANK_SIZE", 1))
+    if rank_size > 1:
+        context.reset_auto_parallel_context()
+        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
+        init()
+        rank_id = int(os.environ.get('RANK_ID'))
+    else:
+        rank_size = None
+        rank_id = None
+
+    ds_train = create_dataset(args_opt.dataset_path,
+                              train_mode=True,
+                              epochs=train_config.train_epochs,
+                              batch_size=train_config.batch_size,
+                              data_type=DataType(data_config.data_format),
+                              rank_size=rank_size,
+                              rank_id=rank_id)
+
+    model_builder = ModelBuilder(model_config, train_config)
+    train_net, eval_net = model_builder.get_train_eval_net()
+    auc_metric = AUCMetric()
+    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
+
+    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
+    loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
+    callback_list = [time_callback, loss_callback]
+
+    if train_config.save_checkpoint:
+        config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
+                                     keep_checkpoint_max=train_config.keep_checkpoint_max)
+        ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
+                                  directory=args_opt.ckpt_path,
+                                  config=config_ck)
+        callback_list.append(ckpt_cb)
+
+    if args_opt.do_eval == 'True':
+        ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
+                                 epochs=train_config.train_epochs,
+                                 batch_size=train_config.batch_size,
+                                 data_type=DataType(data_config.data_format))
+        eval_callback = EvalCallBack(model, ds_eval, auc_metric,
+                                     eval_file_path=args_opt.eval_file_name)
+        callback_list.append(eval_callback)
+    model.train(train_config.train_epochs, ds_train, callbacks=callback_list)