forked from mindspore-Ecosystem/mindspore
add DeepFM
This commit is contained in:
parent
ded9608f6d
commit
92c1b2bd31
|
@ -0,0 +1,132 @@
|
|||
# DeepFM Description
|
||||
|
||||
This is an example of training DeepFM with Criteo dataset in MindSpore.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1703.04247.pdf) Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He
|
||||
|
||||
|
||||
# Model architecture
|
||||
|
||||
The overall network architecture of DeepFM is show below:
|
||||
|
||||
[Link](https://arxiv.org/pdf/1703.04247.pdf)
|
||||
|
||||
|
||||
# Requirements
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
- Download the criteo dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path.
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
|
||||
# Script description
|
||||
|
||||
## Script and sample code
|
||||
|
||||
```python
|
||||
├── deepfm
|
||||
├── README.md
|
||||
├── scripts
|
||||
│ ├──run_train.sh
|
||||
│ ├──run_eval.sh
|
||||
├── src
|
||||
│ ├──config.py
|
||||
│ ├──dataset.py
|
||||
│ ├──callback.py
|
||||
│ ├──deepfm.py
|
||||
├── train.py
|
||||
├── eval.py
|
||||
```
|
||||
|
||||
## Training process
|
||||
|
||||
### Usage
|
||||
|
||||
- sh run_train.sh [DEVICE_NUM] [DATASET_PATH] [MINDSPORE_HCCL_CONFIG_PAHT]
|
||||
- python train.py --dataset_path [DATASET_PATH]
|
||||
|
||||
### Launch
|
||||
|
||||
```
|
||||
# distribute training example
|
||||
sh scripts/run_distribute_train.sh 8 /opt/dataset/criteo /opt/mindspore_hccl_file.json
|
||||
# standalone training example
|
||||
sh scripts/run_standalone_train.sh 0 /opt/dataset/criteo
|
||||
or
|
||||
python train.py --dataset_path /opt/dataset/criteo > output.log 2>&1 &
|
||||
```
|
||||
|
||||
### Result
|
||||
|
||||
Training result will be stored in the example path.
|
||||
Checkpoints will be stored at `./checkpoint` by default,
|
||||
and training log will be redirected to `./output.log` by default,
|
||||
and loss log will be redirected to `./loss.log` by default,
|
||||
and eval log will be redirected to `./auc.log` by default.
|
||||
|
||||
|
||||
## Eval process
|
||||
|
||||
### Usage
|
||||
|
||||
- sh run_eval.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
|
||||
### Launch
|
||||
|
||||
```
|
||||
# infer example
|
||||
sh scripts/run_eval.sh 0 ~/criteo/eval/ ~/train/deepfm-15_41257.ckpt
|
||||
```
|
||||
|
||||
> checkpoint can be produced in training process.
|
||||
|
||||
### Result
|
||||
|
||||
Inference result will be stored in the example path, you can find result like the followings in `auc.log`.
|
||||
|
||||
```
|
||||
2020-05-27 20:51:35 AUC: 0.80577889065281, eval time: 35.55999s.
|
||||
```
|
||||
|
||||
# Model description
|
||||
|
||||
## Performance
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters | DeepFM |
|
||||
| -------------------------- | ------------------------------------------------------|
|
||||
| Model Version | |
|
||||
| Resource | Ascend 910, cpu:2.60GHz 96cores, memory:1.5T |
|
||||
| uploaded Date | 05/27/2020 |
|
||||
| MindSpore Version | 0.2.0 |
|
||||
| Dataset | Criteo |
|
||||
| Training Parameters | src/config.py |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | SoftmaxCrossEntropyWithLogits |
|
||||
| outputs | |
|
||||
| Loss | 0.4234 |
|
||||
| Accuracy | AUC[0.8055] |
|
||||
| Total time | 91 min |
|
||||
| Params (M) | |
|
||||
| Checkpoint for Fine tuning | |
|
||||
| Model for inference | |
|
||||
|
||||
#### Inference Performance
|
||||
|
||||
| Parameters | | |
|
||||
| -------------------------- | ----------------------------- | ------------------------- |
|
||||
| Model Version | | |
|
||||
| Resource | Ascend 910 | Ascend 310 |
|
||||
| uploaded Date | 05/27/2020 | 05/27/2020 |
|
||||
| MindSpore Version | 0.2.0 | 0.2.0 |
|
||||
| Dataset | Criteo | |
|
||||
| batch_size | 1000 | |
|
||||
| outputs | | |
|
||||
| Accuracy | AUC[0.8055] | |
|
||||
| Speed | | |
|
||||
| Total time | 35.559s | |
|
||||
| Model for inference | | |
|
||||
|
||||
# ModelZoo Homepage
|
||||
[Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo)
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the License);
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# httpwww.apache.orglicensesLICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an AS IS BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_criteo."""
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.deepfm import ModelBuilder, AUCMetric
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
from src.dataset import create_dataset
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parser = argparse.ArgumentParser(description='CTR Prediction')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
|
||||
args_opt, _ = parser.parse_known_args()
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
|
||||
|
||||
|
||||
def add_write(file_path, print_str):
|
||||
with open(file_path, 'a+', encoding='utf-8') as file_out:
|
||||
file_out.write(print_str + '\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_config = DataConfig()
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
|
||||
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
|
||||
epochs=1, batch_size=train_config.batch_size)
|
||||
model_builder = ModelBuilder(ModelConfig, TrainConfig)
|
||||
train_net, eval_net = model_builder.get_train_eval_net()
|
||||
train_net.set_train()
|
||||
eval_net.set_train(False)
|
||||
auc_metric = AUCMetric()
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(eval_net, param_dict)
|
||||
|
||||
start = time.time()
|
||||
res = model.eval(ds_eval)
|
||||
eval_time = time.time() - start
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
|
||||
print(out_str)
|
||||
add_write('./auc.log', str(out_str))
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "Please run the script as: "
|
||||
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PAHT"
|
||||
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in logx/output.log"
|
||||
|
||||
|
||||
export RANK_SIZE=$1
|
||||
DATA_URL=$2
|
||||
export MINDSPORE_HCCL_CONFIG_PAHT=$3
|
||||
|
||||
for ((i=0; i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf log$i
|
||||
mkdir ./log$i
|
||||
cp *.py ./log$i
|
||||
cp -r src ./log$i
|
||||
cd ./log$i || exit
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python -u train.py \
|
||||
--dataset_path=$DATA_URL \
|
||||
--ckpt_path="checkpoint" \
|
||||
--eval_file_name='auc.log' \
|
||||
--loss_file_name='loss.log' \
|
||||
--do_eval=True > output.log 2>&1 &
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "Please run the script as: "
|
||||
echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH"
|
||||
echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in ms_log/eval_output.log"
|
||||
|
||||
export DEVICE_ID=$1
|
||||
DATA_URL=$2
|
||||
CHECKPOINT_PATH=$3
|
||||
|
||||
mkdir -p ms_log
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
|
||||
python -u eval.py \
|
||||
--dataset_path=$DATA_URL \
|
||||
--checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 &
|
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "Please run the script as: "
|
||||
echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH"
|
||||
echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in ms_log/output.log"
|
||||
|
||||
export DEVICE_ID=$1
|
||||
DATA_URL=$2
|
||||
|
||||
mkdir -p ms_log
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
|
||||
python -u train.py \
|
||||
--dataset_path=$DATA_URL \
|
||||
--ckpt_path="checkpoint" \
|
||||
--eval_file_name='auc.log' \
|
||||
--loss_file_name='loss.log' \
|
||||
--do_eval=True > ms_log/output.log 2>&1 &
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the License);
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# httpwww.apache.orglicensesLICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an AS IS BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the License);
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# httpwww.apache.orglicensesLICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an AS IS BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Defined callback for DeepFM.
|
||||
"""
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
||||
def add_write(file_path, out_str):
|
||||
with open(file_path, 'a+', encoding='utf-8') as file_out:
|
||||
file_out.write(out_str + '\n')
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss is NAN or INF terminating training.
|
||||
Note
|
||||
If per_print_times is 0 do not print loss.
|
||||
"""
|
||||
def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
|
||||
super(EvalCallBack, self).__init__()
|
||||
self.model = model
|
||||
self.eval_dataset = eval_dataset
|
||||
self.aucMetric = auc_metric
|
||||
self.aucMetric.clear()
|
||||
self.eval_file_path = eval_file_path
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
start_time = time.time()
|
||||
out = self.model.eval(self.eval_dataset)
|
||||
eval_time = int(time.time() - start_time)
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
out_str = "{} EvalCallBack metric{}; eval_time{}s".format(
|
||||
time_str, out.values(), eval_time)
|
||||
print(out_str)
|
||||
add_write(self.eval_file_path, out_str)
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss is NAN or INF terminating training.
|
||||
Note
|
||||
If per_print_times is 0 do not print loss.
|
||||
Args
|
||||
loss_file_path (str) The file absolute path, to save as loss_file;
|
||||
per_print_times (int) Print loss every times. Default 1.
|
||||
"""
|
||||
def __init__(self, loss_file_path, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self.loss_file_path = loss_file_path
|
||||
self._per_print_times = per_print_times
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs.asnumpy()
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
cur_num = cb_params.cur_step_num
|
||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
||||
with open(self.loss_file_path, "a+") as loss_file:
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
|
||||
time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
|
||||
print("epoch: {} step: {}, loss is {}\n".format(
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch, loss))
|
||||
|
||||
|
||||
class TimeMonitor(Callback):
|
||||
"""
|
||||
Time monitor for calculating cost of each epoch.
|
||||
Args
|
||||
data_size (int) step size of an epoch.
|
||||
"""
|
||||
def __init__(self, data_size):
|
||||
super(TimeMonitor, self).__init__()
|
||||
self.data_size = data_size
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
epoch_mseconds = (time.time() - self.epoch_time) * 1000
|
||||
per_step_mseconds = epoch_mseconds / self.data_size
|
||||
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
print(f"step time {step_mseconds}", flush=True)
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
|
||||
|
||||
class DataConfig:
|
||||
"""
|
||||
Define parameters of dataset.
|
||||
"""
|
||||
data_vocab_size = 184965
|
||||
train_num_of_parts = 21
|
||||
test_num_of_parts = 3
|
||||
batch_size = 1000
|
||||
data_field_size = 39
|
||||
# dataset format, 1: mindrecord, 2: tfrecord, 3: h5
|
||||
data_format = 2
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
"""
|
||||
Define parameters of model.
|
||||
"""
|
||||
batch_size = DataConfig.batch_size
|
||||
data_field_size = DataConfig.data_field_size
|
||||
data_vocab_size = DataConfig.data_vocab_size
|
||||
data_emb_dim = 80
|
||||
deep_layer_args = [[400, 400, 512], "relu"]
|
||||
init_args = [-0.01, 0.01]
|
||||
weight_bias_init = ['normal', 'normal']
|
||||
keep_prob = 0.9
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
"""
|
||||
Define parameters of training.
|
||||
"""
|
||||
batch_size = DataConfig.batch_size
|
||||
l2_coef = 1e-6
|
||||
learning_rate = 1e-5
|
||||
epsilon = 1e-8
|
||||
loss_scale = 1024.0
|
||||
train_epochs = 15
|
||||
save_checkpoint = True
|
||||
ckpt_file_name_prefix = "deepfm"
|
||||
save_checkpoint_steps = 1
|
||||
keep_checkpoint_max = 15
|
||||
eval_callback = True
|
||||
loss_callback = True
|
|
@ -0,0 +1,299 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Create train or eval dataset.
|
||||
"""
|
||||
import os
|
||||
import math
|
||||
from enum import Enum
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from .config import DataConfig
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
"""
|
||||
Enumerate supported dataset format.
|
||||
"""
|
||||
MINDRECORD = 1
|
||||
TFRECORD = 2
|
||||
H5 = 3
|
||||
|
||||
|
||||
class H5Dataset():
|
||||
"""
|
||||
Create dataset with H5 format.
|
||||
|
||||
Args:
|
||||
data_path (str): Dataset directory.
|
||||
train_mode (bool): Whether dataset is used for train or eval (default=True).
|
||||
train_num_of_parts (int): The number of train data file (default=21).
|
||||
test_num_of_parts (int): The number of test data file (default=3).
|
||||
"""
|
||||
max_length = 39
|
||||
|
||||
def __init__(self, data_path, train_mode=True,
|
||||
train_num_of_parts=DataConfig.train_num_of_parts,
|
||||
test_num_of_parts=DataConfig.test_num_of_parts):
|
||||
self._hdf_data_dir = data_path
|
||||
self._is_training = train_mode
|
||||
if self._is_training:
|
||||
self._file_prefix = 'train'
|
||||
self._num_of_parts = train_num_of_parts
|
||||
else:
|
||||
self._file_prefix = 'test'
|
||||
self._num_of_parts = test_num_of_parts
|
||||
self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, self._num_of_parts)
|
||||
print("data_size: {}".format(self.data_size))
|
||||
|
||||
def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts):
|
||||
size = 0
|
||||
for part in range(num_of_parts):
|
||||
_y = pd.read_hdf(os.path.join(hdf_data_dir, f'{file_prefix}_output_part_{str(part)}.h5'))
|
||||
size += _y.shape[0]
|
||||
return size
|
||||
|
||||
def _iterate_hdf_files_(self, num_of_parts=None,
|
||||
shuffle_block=False):
|
||||
"""
|
||||
iterate among hdf files(blocks). when the whole data set is finished, the iterator restarts
|
||||
from the beginning, thus the data stream will never stop
|
||||
:param train_mode: True or false,false is eval_mode,
|
||||
this file iterator will go through the train set
|
||||
:param num_of_parts: number of files
|
||||
:param shuffle_block: shuffle block files at every round
|
||||
:return: input_hdf_file_name, output_hdf_file_name, finish_flag
|
||||
"""
|
||||
parts = np.arange(num_of_parts)
|
||||
while True:
|
||||
if shuffle_block:
|
||||
for _ in range(int(shuffle_block)):
|
||||
np.random.shuffle(parts)
|
||||
for i, p in enumerate(parts):
|
||||
yield os.path.join(self._hdf_data_dir, f'{self._file_prefix}_input_part_{str(p)}.h5'), \
|
||||
os.path.join(self._hdf_data_dir, f'{self._file_prefix}_output_part_{str(p)}.h5'), \
|
||||
i + 1 == len(parts)
|
||||
|
||||
def _generator(self, X, y, batch_size, shuffle=True):
|
||||
"""
|
||||
should be accessed only in private
|
||||
:param X:
|
||||
:param y:
|
||||
:param batch_size:
|
||||
:param shuffle:
|
||||
:return:
|
||||
"""
|
||||
number_of_batches = np.ceil(1. * X.shape[0] / batch_size)
|
||||
counter = 0
|
||||
finished = False
|
||||
sample_index = np.arange(X.shape[0])
|
||||
if shuffle:
|
||||
for _ in range(int(shuffle)):
|
||||
np.random.shuffle(sample_index)
|
||||
assert X.shape[0] > 0
|
||||
while True:
|
||||
batch_index = sample_index[batch_size * counter: batch_size * (counter + 1)]
|
||||
X_batch = X[batch_index]
|
||||
y_batch = y[batch_index]
|
||||
counter += 1
|
||||
yield X_batch, y_batch, finished
|
||||
if counter == number_of_batches:
|
||||
counter = 0
|
||||
finished = True
|
||||
|
||||
def batch_generator(self, batch_size=1000,
|
||||
random_sample=False, shuffle_block=False):
|
||||
"""
|
||||
:param train_mode: True or false,false is eval_mode,
|
||||
:param batch_size
|
||||
:param num_of_parts: number of files
|
||||
:param random_sample: if True, will shuffle
|
||||
:param shuffle_block: shuffle file blocks at every round
|
||||
:return:
|
||||
"""
|
||||
|
||||
for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts,
|
||||
shuffle_block):
|
||||
start = stop = None
|
||||
X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values
|
||||
y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values
|
||||
data_gen = self._generator(X_all, y_all, batch_size,
|
||||
shuffle=random_sample)
|
||||
finished = False
|
||||
|
||||
while not finished:
|
||||
X, y, finished = data_gen.__next__()
|
||||
X_id = X[:, 0:self.max_length]
|
||||
X_va = X[:, self.max_length:]
|
||||
yield np.array(X_id.astype(dtype=np.int32)), \
|
||||
np.array(X_va.astype(dtype=np.float32)), \
|
||||
np.array(y.astype(dtype=np.float32))
|
||||
|
||||
|
||||
def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000):
|
||||
"""
|
||||
Get dataset with h5 format.
|
||||
|
||||
Args:
|
||||
directory (str): Dataset directory.
|
||||
train_mode (bool): Whether dataset is use for train or eval (default=True).
|
||||
epochs (int): Dataset epoch size (default=1).
|
||||
batch_size (int): Dataset batch size (default=1000)
|
||||
|
||||
Returns:
|
||||
Dataset.
|
||||
"""
|
||||
data_para = {'batch_size': batch_size}
|
||||
if train_mode:
|
||||
data_para['random_sample'] = True
|
||||
data_para['shuffle_block'] = True
|
||||
|
||||
h5_dataset = H5Dataset(data_path=directory, train_mode=train_mode)
|
||||
numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)
|
||||
|
||||
def _iter_h5_data():
|
||||
train_eval_gen = h5_dataset.batch_generator(**data_para)
|
||||
for _ in range(0, numbers_of_batch, 1):
|
||||
yield train_eval_gen.__next__()
|
||||
|
||||
ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"])
|
||||
ds.set_dataset_size(numbers_of_batch)
|
||||
ds = ds.repeat(epochs)
|
||||
return ds
|
||||
|
||||
|
||||
def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
|
||||
line_per_sample=1000, rank_size=None, rank_id=None):
|
||||
"""
|
||||
Get dataset with mindrecord format.
|
||||
|
||||
Args:
|
||||
directory (str): Dataset directory.
|
||||
train_mode (bool): Whether dataset is use for train or eval (default=True).
|
||||
epochs (int): Dataset epoch size (default=1).
|
||||
batch_size (int): Dataset batch size (default=1000).
|
||||
line_per_sample (int): The number of sample per line (default=1000).
|
||||
rank_size (int): The number of device, not necessary for single device (default=None).
|
||||
rank_id (int): Id of device, not necessary for single device (default=None).
|
||||
|
||||
Returns:
|
||||
Dataset.
|
||||
"""
|
||||
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
|
||||
file_suffix_name = '00' if train_mode else '0'
|
||||
shuffle = train_mode
|
||||
|
||||
if rank_size is not None and rank_id is not None:
|
||||
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
|
||||
columns_list=['feat_ids', 'feat_vals', 'label'],
|
||||
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
|
||||
num_parallel_workers=8)
|
||||
else:
|
||||
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
|
||||
columns_list=['feat_ids', 'feat_vals', 'label'],
|
||||
shuffle=shuffle, num_parallel_workers=8)
|
||||
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
|
||||
ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
|
||||
np.array(y).flatten().reshape(batch_size, 39),
|
||||
np.array(z).flatten().reshape(batch_size, 1))),
|
||||
input_columns=['feat_ids', 'feat_vals', 'label'],
|
||||
columns_order=['feat_ids', 'feat_vals', 'label'],
|
||||
num_parallel_workers=8)
|
||||
ds = ds.repeat(epochs)
|
||||
return ds
|
||||
|
||||
|
||||
def _get_tf_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
|
||||
line_per_sample=1000, rank_size=None, rank_id=None):
|
||||
"""
|
||||
Get dataset with tfrecord format.
|
||||
|
||||
Args:
|
||||
directory (str): Dataset directory.
|
||||
train_mode (bool): Whether dataset is use for train or eval (default=True).
|
||||
epochs (int): Dataset epoch size (default=1).
|
||||
batch_size (int): Dataset batch size (default=1000).
|
||||
line_per_sample (int): The number of sample per line (default=1000).
|
||||
rank_size (int): The number of device, not necessary for single device (default=None).
|
||||
rank_id (int): Id of device, not necessary for single device (default=None).
|
||||
|
||||
Returns:
|
||||
Dataset.
|
||||
"""
|
||||
dataset_files = []
|
||||
file_prefixt_name = 'train' if train_mode else 'test'
|
||||
shuffle = train_mode
|
||||
for (dir_path, _, filenames) in os.walk(directory):
|
||||
for filename in filenames:
|
||||
if file_prefixt_name in filename and 'tfrecord' in filename:
|
||||
dataset_files.append(os.path.join(dir_path, filename))
|
||||
schema = de.Schema()
|
||||
schema.add_column('feat_ids', de_type=mstype.int32)
|
||||
schema.add_column('feat_vals', de_type=mstype.float32)
|
||||
schema.add_column('label', de_type=mstype.float32)
|
||||
if rank_size is not None and rank_id is not None:
|
||||
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
|
||||
schema=schema, num_parallel_workers=8,
|
||||
num_shards=rank_size, shard_id=rank_id,
|
||||
shard_equal_rows=True)
|
||||
else:
|
||||
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
|
||||
schema=schema, num_parallel_workers=8)
|
||||
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
|
||||
ds = ds.map(operations=(lambda x, y, z: (
|
||||
np.array(x).flatten().reshape(batch_size, 39),
|
||||
np.array(y).flatten().reshape(batch_size, 39),
|
||||
np.array(z).flatten().reshape(batch_size, 1))),
|
||||
input_columns=['feat_ids', 'feat_vals', 'label'],
|
||||
columns_order=['feat_ids', 'feat_vals', 'label'],
|
||||
num_parallel_workers=8)
|
||||
ds = ds.repeat(epochs)
|
||||
return ds
|
||||
|
||||
|
||||
def create_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
|
||||
data_type=DataType.TFRECORD, line_per_sample=1000,
|
||||
rank_size=None, rank_id=None):
|
||||
"""
|
||||
Get dataset.
|
||||
|
||||
Args:
|
||||
directory (str): Dataset directory.
|
||||
train_mode (bool): Whether dataset is use for train or eval (default=True).
|
||||
epochs (int): Dataset epoch size (default=1).
|
||||
batch_size (int): Dataset batch size (default=1000).
|
||||
data_type (DataType): The type of dataset which is one of H5, TFRECORE, MINDRECORD (default=TFRECORD).
|
||||
line_per_sample (int): The number of sample per line (default=1000).
|
||||
rank_size (int): The number of device, not necessary for single device (default=None).
|
||||
rank_id (int): Id of device, not necessary for single device (default=None).
|
||||
|
||||
Returns:
|
||||
Dataset.
|
||||
"""
|
||||
if data_type == DataType.MINDRECORD:
|
||||
return _get_mindrecord_dataset(directory, train_mode, epochs,
|
||||
batch_size, line_per_sample,
|
||||
rank_size, rank_id)
|
||||
if data_type == DataType.TFRECORD:
|
||||
return _get_tf_dataset(directory, train_mode, epochs, batch_size,
|
||||
line_per_sample, rank_size=rank_size, rank_id=rank_id)
|
||||
|
||||
if rank_size is not None and rank_size > 1:
|
||||
raise ValueError('Please use mindrecord dataset.')
|
||||
return _get_h5_dataset(directory, train_mode, epochs, batch_size)
|
|
@ -0,0 +1,370 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_training """
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import roc_auc_score
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Dropout
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.nn.metrics import Metric
|
||||
from mindspore import nn, ParameterTuple, Parameter
|
||||
from mindspore.common.initializer import Uniform, initializer, Normal
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
|
||||
from .callback import EvalCallBack, LossCallBack
|
||||
|
||||
|
||||
np_type = np.float32
|
||||
ms_type = mstype.float32
|
||||
|
||||
|
||||
class AUCMetric(Metric):
|
||||
"""AUC metric for DeepFM model."""
|
||||
def __init__(self):
|
||||
super(AUCMetric, self).__init__()
|
||||
self.pred_probs = []
|
||||
self.true_labels = []
|
||||
|
||||
def clear(self):
|
||||
"""Clear the internal evaluation result."""
|
||||
self.pred_probs = []
|
||||
self.true_labels = []
|
||||
|
||||
def update(self, *inputs):
|
||||
batch_predict = inputs[1].asnumpy()
|
||||
batch_label = inputs[2].asnumpy()
|
||||
self.pred_probs.extend(batch_predict.flatten().tolist())
|
||||
self.true_labels.extend(batch_label.flatten().tolist())
|
||||
|
||||
def eval(self):
|
||||
if len(self.true_labels) != len(self.pred_probs):
|
||||
raise RuntimeError('true_labels.size() is not equal to pred_probs.size()')
|
||||
auc = roc_auc_score(self.true_labels, self.pred_probs)
|
||||
return auc
|
||||
|
||||
|
||||
def init_method(method, shape, name, max_val=0.01):
|
||||
"""
|
||||
The method of init parameters.
|
||||
|
||||
Args:
|
||||
method (str): The method uses to initialize parameter.
|
||||
shape (list): The shape of parameter.
|
||||
name (str): The name of parameter.
|
||||
max_val (float): Max value in parameter when uses 'random' or 'uniform' to initialize parameter.
|
||||
|
||||
Returns:
|
||||
Parameter.
|
||||
"""
|
||||
if method in ['random', 'uniform']:
|
||||
params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
|
||||
elif method == "one":
|
||||
params = Parameter(initializer("ones", shape, ms_type), name=name)
|
||||
elif method == 'zero':
|
||||
params = Parameter(initializer("zeros", shape, ms_type), name=name)
|
||||
elif method == "normal":
|
||||
params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
|
||||
return params
|
||||
|
||||
|
||||
def init_var_dict(init_args, values):
|
||||
"""
|
||||
Init parameter.
|
||||
|
||||
Args:
|
||||
init_args (list): Define max and min value of parameters.
|
||||
values (list): Define name, shape and init method of parameters.
|
||||
|
||||
Returns:
|
||||
dict, a dict ot Parameter.
|
||||
"""
|
||||
var_map = {}
|
||||
_, _max_val = init_args
|
||||
for key, shape, init_flag in values:
|
||||
if key not in var_map.keys():
|
||||
if init_flag in ['random', 'uniform']:
|
||||
var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key)
|
||||
elif init_flag == "one":
|
||||
var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
|
||||
elif init_flag == "zero":
|
||||
var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
|
||||
elif init_flag == 'normal':
|
||||
var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key)
|
||||
return var_map
|
||||
|
||||
|
||||
class DenseLayer(nn.Cell):
|
||||
"""
|
||||
Dense Layer for Deep Layer of DeepFM Model;
|
||||
Containing: activation, matmul, bias_add;
|
||||
Args:
|
||||
input_dim (int): the shape of weight at 0-aixs;
|
||||
output_dim (int): the shape of weight at 1-aixs, and shape of bias
|
||||
weight_bias_init (list): weight and bias init method, "random", "uniform", "one", "zero", "normal";
|
||||
act_str (str): activation function method, "relu", "sigmoid", "tanh";
|
||||
keep_prob (float): Dropout Layer keep_prob_rate;
|
||||
scale_coef (float): input scale coefficient;
|
||||
"""
|
||||
def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0):
|
||||
super(DenseLayer, self).__init__()
|
||||
weight_init, bias_init = weight_bias_init
|
||||
self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
|
||||
self.bias = init_method(bias_init, [output_dim], name="bias")
|
||||
self.act_func = self._init_activation(act_str)
|
||||
self.matmul = P.MatMul(transpose_b=False)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.cast = P.Cast()
|
||||
self.dropout = Dropout(keep_prob=keep_prob)
|
||||
self.mul = P.Mul()
|
||||
self.realDiv = P.RealDiv()
|
||||
self.scale_coef = scale_coef
|
||||
|
||||
def _init_activation(self, act_str):
|
||||
act_str = act_str.lower()
|
||||
if act_str == "relu":
|
||||
act_func = P.ReLU()
|
||||
elif act_str == "sigmoid":
|
||||
act_func = P.Sigmoid()
|
||||
elif act_str == "tanh":
|
||||
act_func = P.Tanh()
|
||||
return act_func
|
||||
|
||||
def construct(self, x):
|
||||
x = self.act_func(x)
|
||||
if self.training:
|
||||
x = self.dropout(x)
|
||||
x = self.mul(x, self.scale_coef)
|
||||
x = self.cast(x, mstype.float16)
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
wx = self.matmul(x, weight)
|
||||
wx = self.cast(wx, mstype.float32)
|
||||
wx = self.realDiv(wx, self.scale_coef)
|
||||
output = self.bias_add(wx, self.bias)
|
||||
return output
|
||||
|
||||
|
||||
class DeepFMModel(nn.Cell):
|
||||
"""
|
||||
From paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction"
|
||||
|
||||
Args:
|
||||
batch_size (int): smaple_number of per step in training; (int, batch_size=128)
|
||||
filed_size (int): input filed number, or called id_feature number; (int, filed_size=39)
|
||||
vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000)
|
||||
emb_dim (int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100)
|
||||
deep_layer_args (list): Deep Layer args, layer_dim_list, layer_activator;
|
||||
(int, deep_layer_args=[[100, 100, 100], "relu"])
|
||||
init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds])
|
||||
weight_bias_init (list): weight, bias init method for deep layers;
|
||||
(list[str], weight_bias_init=['random', 'zero'])
|
||||
keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(DeepFMModel, self).__init__()
|
||||
|
||||
self.batch_size = config.batch_size
|
||||
self.field_size = config.data_field_size
|
||||
self.vocab_size = config.data_vocab_size
|
||||
self.emb_dim = config.data_emb_dim
|
||||
self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args
|
||||
self.init_args = config.init_args
|
||||
self.weight_bias_init = config.weight_bias_init
|
||||
self.keep_prob = config.keep_prob
|
||||
init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
|
||||
('V_l2', [self.vocab_size, self.emb_dim], 'normal'),
|
||||
('b', [1], 'normal')]
|
||||
var_map = init_var_dict(self.init_args, init_acts)
|
||||
self.fm_w = var_map["W_l2"]
|
||||
self.fm_b = var_map["b"]
|
||||
self.embedding_table = var_map["V_l2"]
|
||||
# Deep Layers
|
||||
self.deep_input_dims = self.field_size * self.emb_dim + 1
|
||||
self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
|
||||
self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
|
||||
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
|
||||
self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
|
||||
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
|
||||
self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
|
||||
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
|
||||
self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
|
||||
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
|
||||
# FM, linear Layers
|
||||
self.Gatherv2 = P.GatherV2()
|
||||
self.Mul = P.Mul()
|
||||
self.ReduceSum = P.ReduceSum(keep_dims=False)
|
||||
self.Reshape = P.Reshape()
|
||||
self.Square = P.Square()
|
||||
self.Shape = P.Shape()
|
||||
self.Tile = P.Tile()
|
||||
self.Concat = P.Concat(axis=1)
|
||||
self.Cast = P.Cast()
|
||||
|
||||
def construct(self, id_hldr, wt_hldr):
|
||||
"""
|
||||
Args:
|
||||
id_hldr: batch ids; [bs, field_size]
|
||||
wt_hldr: batch weights; [bs, field_size]
|
||||
"""
|
||||
|
||||
mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
||||
# Linear layer
|
||||
fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0)
|
||||
wx = self.Mul(fm_id_weight, mask)
|
||||
linear_out = self.ReduceSum(wx, 1)
|
||||
# FM layer
|
||||
fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0)
|
||||
vx = self.Mul(fm_id_embs, mask)
|
||||
v1 = self.ReduceSum(vx, 1)
|
||||
v1 = self.Square(v1)
|
||||
v2 = self.Square(vx)
|
||||
v2 = self.ReduceSum(v2, 1)
|
||||
fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
|
||||
fm_out = self.Reshape(fm_out, (-1, 1))
|
||||
# Deep layer
|
||||
b = self.Reshape(self.fm_b, (1, 1))
|
||||
b = self.Tile(b, (self.batch_size, 1))
|
||||
deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
|
||||
deep_in = self.Concat((deep_in, b))
|
||||
deep_in = self.dense_layer_1(deep_in)
|
||||
deep_in = self.dense_layer_2(deep_in)
|
||||
deep_in = self.dense_layer_3(deep_in)
|
||||
deep_out = self.dense_layer_4(deep_in)
|
||||
out = linear_out + fm_out + deep_out
|
||||
return out, fm_id_weight, fm_id_embs
|
||||
|
||||
|
||||
class NetWithLossClass(nn.Cell):
|
||||
"""
|
||||
NetWithLossClass definition.
|
||||
"""
|
||||
def __init__(self, network, l2_coef=1e-6):
|
||||
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||
self.loss = P.SigmoidCrossEntropyWithLogits()
|
||||
self.network = network
|
||||
self.l2_coef = l2_coef
|
||||
self.Square = P.Square()
|
||||
self.ReduceMean_false = P.ReduceMean(keep_dims=False)
|
||||
self.ReduceSum_false = P.ReduceSum(keep_dims=False)
|
||||
|
||||
def construct(self, batch_ids, batch_wts, label):
|
||||
predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts)
|
||||
log_loss = self.loss(predict, label)
|
||||
mean_log_loss = self.ReduceMean_false(log_loss)
|
||||
l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight))
|
||||
l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs))
|
||||
l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5
|
||||
loss = mean_log_loss + l2_loss_all
|
||||
return loss
|
||||
|
||||
|
||||
class TrainStepWrap(nn.Cell):
|
||||
"""
|
||||
TrainStepWrap definition
|
||||
"""
|
||||
def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
|
||||
super(TrainStepWrap, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_train()
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = loss_scale
|
||||
|
||||
def construct(self, batch_ids, batch_wts, label):
|
||||
weights = self.weights
|
||||
loss = self.network(batch_ids, batch_wts, label)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
|
||||
grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class PredictWithSigmoid(nn.Cell):
|
||||
"""
|
||||
Eval model with sigmoid.
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.sigmoid = P.Sigmoid()
|
||||
|
||||
def construct(self, batch_ids, batch_wts, labels):
|
||||
logits, _, _, = self.network(batch_ids, batch_wts)
|
||||
pred_probs = self.sigmoid(logits)
|
||||
|
||||
return logits, pred_probs, labels
|
||||
|
||||
|
||||
class ModelBuilder:
|
||||
"""
|
||||
Model builder for DeepFM.
|
||||
|
||||
Args:
|
||||
model_config (ModelConfig): Model configuration.
|
||||
train_config (TrainConfig): Train configuration.
|
||||
"""
|
||||
def __init__(self, model_config, train_config):
|
||||
self.model_config = model_config
|
||||
self.train_config = train_config
|
||||
|
||||
def get_callback_list(self, model=None, eval_dataset=None):
|
||||
"""
|
||||
Get callbacks which contains checkpoint callback, eval callback and loss callback.
|
||||
|
||||
Args:
|
||||
model (Cell): The network is added callback (default=None).
|
||||
eval_dataset (Dataset): Dataset for eval (default=None).
|
||||
"""
|
||||
callback_list = []
|
||||
if self.train_config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=self.train_config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix,
|
||||
directory=self.train_config.output_path,
|
||||
config=config_ck)
|
||||
callback_list.append(ckpt_cb)
|
||||
if self.train_config.eval_callback:
|
||||
if model is None:
|
||||
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format(
|
||||
self.train_config.eval_callback, model))
|
||||
if eval_dataset is None:
|
||||
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() "
|
||||
"args eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset))
|
||||
auc_metric = AUCMetric()
|
||||
eval_callback = EvalCallBack(model, eval_dataset, auc_metric,
|
||||
eval_file_path=os.path.join(self.train_config.output_path,
|
||||
self.train_config.eval_file_name))
|
||||
callback_list.append(eval_callback)
|
||||
if self.train_config.loss_callback:
|
||||
loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path,
|
||||
self.train_config.loss_file_name))
|
||||
callback_list.append(loss_callback)
|
||||
if callback_list:
|
||||
return callback_list
|
||||
return None
|
||||
|
||||
def get_train_eval_net(self):
|
||||
deepfm_net = DeepFMModel(self.model_config)
|
||||
loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef)
|
||||
train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate,
|
||||
eps=self.train_config.epsilon,
|
||||
loss_scale=self.train_config.loss_scale)
|
||||
eval_net = PredictWithSigmoid(deepfm_net)
|
||||
return train_net, eval_net
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_criteo."""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from mindspore import context, ParallelMode
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
|
||||
from src.deepfm import ModelBuilder, AUCMetric
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
from src.dataset import create_dataset, DataType
|
||||
from src.callback import EvalCallBack, LossCallBack
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parser = argparse.ArgumentParser(description='CTR Prediction')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
|
||||
parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path')
|
||||
parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path')
|
||||
parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.')
|
||||
|
||||
args_opt, _ = parser.parse_known_args()
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_config = DataConfig()
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
if rank_size > 1:
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
|
||||
init()
|
||||
rank_id = int(os.environ.get('RANK_ID'))
|
||||
else:
|
||||
rank_size = None
|
||||
rank_id = None
|
||||
|
||||
ds_train = create_dataset(args_opt.dataset_path,
|
||||
train_mode=True,
|
||||
epochs=train_config.train_epochs,
|
||||
batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format),
|
||||
rank_size=rank_size,
|
||||
rank_id=rank_id)
|
||||
|
||||
model_builder = ModelBuilder(ModelConfig, TrainConfig)
|
||||
train_net, eval_net = model_builder.get_train_eval_net()
|
||||
auc_metric = AUCMetric()
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
|
||||
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
|
||||
callback_list = [time_callback, loss_callback]
|
||||
|
||||
if train_config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=train_config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
|
||||
directory=args_opt.ckpt_path,
|
||||
config=config_ck)
|
||||
callback_list.append(ckpt_cb)
|
||||
|
||||
if args_opt.do_eval:
|
||||
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
|
||||
epochs=train_config.train_epochs,
|
||||
batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format))
|
||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
|
||||
eval_file_path=args_opt.eval_file_name)
|
||||
callback_list.append(eval_callback)
|
||||
model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
|
Loading…
Reference in New Issue