fat-deepffm commit

parent fc884e44b6
commit 9396f7122a

@@ -0,0 +1,298 @@

# Contents

- [FAT-DeepFFM Description](#fat-deepffm-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Inference Process](#inference-process)
        - [Export MindIR](#export-mindir)
        - [Infer on Ascend310](#infer-on-ascend310)
        - [result](#result)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [FAT-DeepFFM Description](#contents)

Click-through rate (CTR) estimation is a key component of computational advertising and recommendation systems. CTR models often borrow techniques that are widely used in other fields, such as computer vision and natural language processing; the most common of these is the attention mechanism, which picks out the most informative features and filters out irrelevant ones. FAT-DeepFFM combines such an attention mechanism with a deep-learning CTR prediction model.

[Paper](https://arxiv.org/abs/1905.06336): Junlin Zhang, Tongwen Huang, Zhiqi Zhang. FAT-DeepFFM: Field Attentive Deep Field-aware Factorization Machine.

# [Model Architecture](#contents)

FAT-DeepFFM consists of three parts. The FFM component is a field-aware factorization machine that learns pairwise feature interactions for recommendation. The deep component is a feedforward neural network that learns higher-order feature interactions, and the attention component applies a field-attention mechanism to the input feature embeddings before they are passed on to the downstream modules. FAT-DeepFFM can therefore learn low-order and high-order feature interactions from the raw input features simultaneously.
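
The way the sub-modules combine into the final logit can be sketched as follows. This is a simplified NumPy stand-in that mirrors `Fat_DeepFFM.construct` in `src/fat_deepffm.py`; the lambdas are illustrative placeholders for the real layers, and the shapes assume `emb_dim=8` and a batch of 4.

```python
import numpy as np

b, emb_dim = 4, 8
rng = np.random.default_rng(0)
dense_x = rng.normal(size=(b, 13)).astype(np.float32)                 # 13 continuous features
ffm_emb = rng.normal(size=(b, 26 * 26, emb_dim)).astype(np.float32)   # field-aware embeddings

dense_linear = lambda x: x.sum(axis=1, keepdims=True)                 # stand-in: first-order dense term
dense_deep = lambda x: np.tanh(x).sum(axis=1, keepdims=True)          # stand-in: higher-order dense term
sparse_linear = rng.normal(size=(b, 1)).astype(np.float32)            # stand-in: first-order sparse term
attention = lambda e: e * np.abs(e).mean(axis=2, keepdims=True)       # stand-in: field attention re-weighting
ffm = lambda e: e.sum(axis=(1, 2)).reshape(b, 1)                      # stand-in: pairwise FFM interactions
mlp = lambda d, e: d.sum(axis=1, keepdims=True) + e.reshape(b, -1).mean(axis=1, keepdims=True)

logit = dense_linear(dense_x) + dense_deep(dense_x) + sparse_linear \
        + ffm(attention(ffm_emb)) + mlp(dense_x, ffm_emb)             # [b, 1]
ctr = 1.0 / (1.0 + np.exp(-logit))                                    # sigmoid -> click probability
print(logit.shape, ctr.shape)                                         # (4, 1) (4, 1)
```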

# [Dataset](#contents)

- [1] The dataset used in Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction. 2017. The raw record layout is sketched below.
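
Each line of the raw Criteo `train.txt` is tab-separated: a 0/1 click label, 13 integer-valued (dense) fields, and 26 hashed categorical fields; this is exactly how `src/preprocess_data.py` splits each record. A minimal parsing sketch (the sample line is made up for illustration):

```python
line = "1\t5\t10\t" + "\t".join([""] * 11) + "\t" + "\t".join(["68fd1e64"] * 26)
items = line.strip("\n").split("\t")
label = float(items[0])        # click label: 0 or 1
dense_vals = items[1:1 + 13]   # 13 integer features (may be empty strings)
cat_vals = items[1 + 13:]      # 26 categorical features (hashed strings)
assert len(dense_vals) == 13 and len(cat_vals) == 26
```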

# [Environment Requirements](#contents)

- Hardware (Ascend/GPU/CPU)
    - Prepare a hardware environment with an Ascend, GPU, or CPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Quick Start](#contents)

- Download the Dataset

> Please refer to [1] to obtain the download link

```bash
mkdir -p data/ && cd data/
wget DATA_LINK
tar -zxvf dac.tar.gz
```

- Use this script to preprocess the data. This may take about one hour, and the generated MindRecord files are placed under `data/mindrecord`.

```bash
python src/preprocess_data.py --data_path=./data/ --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
```

- running on Ascend

```shell
# run training example
python train.py \
  --dataset_path='data/mindrecord' \
  --ckpt_path='./checkpoint/Fat-DeepFFM' \
  --eval_file_name='./auc.log' \
  --loss_file_name='./loss.log' \
  --device_target='Ascend' \
  --do_eval=True > output.log 2>&1 &

# run distributed training example
bash scripts/run_distribute_train.sh /dataset_path 8 scripts/hccl_8p.json False

# run evaluation example
python eval.py \
  --dataset_path='dataset/mindrecord' \
  --ckpt_path='./checkpoint/Fat-DeepFFM.ckpt' \
  --device_target='Ascend' \
  --device_id=0 > eval_output.log 2>&1 &
OR
bash scripts/run_eval.sh 0 Ascend /dataset_path /ckpt_path
```

For distributed training, an hccl configuration file in JSON format needs to be created in advance.

Please follow the instructions in the link below:

[hccl tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```path
.
└─Fat-deepffm
  ├─README.md
  ├─scripts
  │ ├─run_alone_train.sh      # launch standalone training(1p) in Ascend
  │ ├─run_distribute_train.sh # launch distributed training(8p) in Ascend
  │ └─run_eval.sh             # launch evaluating in Ascend
  ├─src
  │ ├─config.py               # parameter configuration
  │ ├─callback.py             # define callback functions
  │ ├─fat_deepffm.py          # Fat-DeepFFM network
  │ ├─lr_generator.py         # generate learning rate
  │ ├─metrics.py              # verify the model
  │ ├─dataset.py              # create dataset for Fat-DeepFFM
  │ └─preprocess_data.py      # preprocess the Criteo data into MindRecord
  ├─eval.py                   # eval net
  ├─export.py                 # export net
  └─train.py                  # train net
```

## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- train parameters

```help
optional arguments:
  -h, --help            show this help message and exit
  --dataset_path DATASET_PATH
                        Dataset path
  --ckpt_path CKPT_PATH
                        Checkpoint path
  --eval_file_name EVAL_FILE_NAME
                        Auc log file path. Default: "./auc.log"
  --loss_file_name LOSS_FILE_NAME
                        Loss log file path. Default: "./loss.log"
  --do_eval DO_EVAL     Do evaluation or not. Default: True
  --device_target DEVICE_TARGET
                        Ascend or GPU. Default: Ascend
```

- eval parameters

```help
optional arguments:
  -h, --help            show this help message and exit
  --ckpt_path CHECKPOINT_PATH
                        Checkpoint file path
  --dataset_path DATASET_PATH
                        Dataset path
  --device_target DEVICE_TARGET
                        Ascend or GPU. Default: Ascend
```

## [Training Process](#contents)

### Training

- running on Ascend

```shell
python train.py \
  --dataset_path='/data/' \
  --ckpt_path='./checkpoint' \
  --eval_file_name='./auc.log' \
  --loss_file_name='./loss.log' \
  --device_target='Ascend' \
  --do_eval=True > output.log 2>&1 &
```

The python command above runs in the background; you can view the results through the file `output.log`.

After training, you will get some checkpoint files under the `./checkpoint` folder by default. The loss values are saved in the `loss.log` file.

```log
2021-06-19 21:59:10 epoch: 1 step: 5166, loss is 0.46262410283088684
2021-06-19 22:12:13 epoch: 2 step: 5166, loss is 0.4792023301124573
2021-06-19 22:21:03 epoch: 3 step: 5166, loss is 0.4666571617126465
2021-06-19 22:29:54 epoch: 4 step: 5166, loss is 0.44029417634010315
...
```

The model checkpoints are saved in the directory given by `--ckpt_path`.

### Distributed Training

- running on Ascend

```shell
bash scripts/run_distribute_train.sh /dataset_path 8 scripts/hccl_8p.json False
```

The above shell script runs distributed training in the background. You can view the results through the file `Task[X]/output.log`. The loss values are saved in each task's `loss.log` file.

## [Evaluation Process](#contents)

### Evaluation

- evaluation on the dataset when running on Ascend

Before running the command below, please check the checkpoint path used for evaluation.

```shell
python eval.py \
  --dataset_path='/dataset_path' \
  --ckpt_path='/ckpt_path' \
  --device_id=0 \
  --device_target='Ascend' > ms_log/eval_output.log 2>&1 &
OR
bash scripts/run_eval.sh 0 Ascend /dataset_path /ckpt_path
```

The above python command runs in the background. You can view the results through the file `ms_log/eval_output.log`. The accuracy is saved in the `auc.log` file.

```log
{'AUC': 0.8091001899667086}
```

## [Inference Process](#contents)

### [Export MindIR](#contents)

```shell
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```

The `ckpt_file` parameter is required, and `FILE_FORMAT` should be in ["AIR", "MINDIR"].

### Infer on Ascend310

Before performing inference, the MindIR file must be exported by the `export.py` script. We only provide an example of inference using the MINDIR model.

```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_ID]
```

- `NEED_PREPROCESS` indicates whether the data needs to be preprocessed; its value is 'y' or 'n'.
- `DEVICE_ID` is optional; the default value is 0.

### result

The inference result is saved in the current path; you can find a result like the following in the `acc.log` file.

```bash
'AUC': 0.8091001899667086
```

# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | Fat-DeepFFM                                                  |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS Euler2.8 |
| Uploaded Date              | 09/15/2020 (month/day/year)                                  |
| MindSpore Version          | 1.2.0                                                        |
| Dataset                    | Criteo                                                       |
| Training Parameters        | epoch=30, batch_size=1000, lr=1e-4                           |
| Optimizer                  | Adam                                                         |
| Loss Function              | Sigmoid Cross Entropy With Logits                            |
| outputs                    | AUC                                                          |
| Loss                       | 0.45                                                         |
| Speed                      | 1pc: 8.16 ms/step                                            |
| Total time                 | 1pc: 4 hours                                                 |
| Parameters (M)             | 560.34                                                       |
| Checkpoint for Fine tuning | 87.65M (.ckpt file)                                          |
| Scripts                    | [Fat-DeepFFM script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/Fat-DeepFFM) |

### Inference Performance

| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Model Version       | Fat-DeepFFM                 |
| Resource            | Ascend 910; OS Euler2.8     |
| Uploaded Date       | 06/20/2021 (month/day/year) |
| MindSpore Version   | 1.2.0                       |
| Dataset             | Criteo                      |
| batch_size          | 1000                        |
| outputs             | AUC                         |
| AUC                 | 1pc: 80.90%                 |
| Model for inference | 87.65M (.ckpt file)         |

# [Description of Random Situation](#contents)

We set the random seed before training in train.py.

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
""" eval model"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from src.config import ModelConfig
|
||||
from src.dataset import get_mindrecord_dataset
|
||||
from src.fat_deepffm import ModelBuilder
|
||||
from src.metrics import AUCMetric
|
||||
from mindspore import context, Model
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
parser = argparse.ArgumentParser(description='CTR Prediction')
|
||||
parser.add_argument('--dataset_path', type=str, default="/data/FM/mindrecord", help='Dataset path')
|
||||
parser.add_argument('--ckpt_path', type=str, default="/checkpoint/Fat-DeepFFM-24_5166.ckpt", help='Checkpoint path')
|
||||
parser.add_argument('--eval_file_name', type=str, default="./auc.log",
|
||||
help='Auc log file path. Default: "./auc.log"')
|
||||
parser.add_argument('--loss_file_name', type=str, default="./loss.log",
|
||||
help='Loss log file path. Default: "./loss.log"')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
|
||||
help="device target, support Ascend, GPU and CPU.")
|
||||
parser.add_argument('--device_id', type=int, default=0, choices=(0, 1, 2, 3, 4, 5, 6, 7),
|
||||
help="device target, support Ascend, GPU and CPU.")
|
||||
args = parser.parse_args()
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
print("rank_size", rank_size)
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_config = ModelConfig()
|
||||
device_id = int(os.getenv('DEVICE_ID', default=args.device_id))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
|
||||
device_id=device_id)
|
||||
print("Load dataset...")
|
||||
train_net, test_net = ModelBuilder(model_config).get_train_eval_net()
|
||||
auc_metric = AUCMetric()
|
||||
model = Model(train_net, eval_network=test_net, metrics={"AUC": auc_metric})
|
||||
ds_test = get_mindrecord_dataset(args.dataset_path, train_mode=False)
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(train_net, param_dict)
|
||||
print("Training started...")
|
||||
res = model.eval(ds_test, dataset_sink_mode=False)
|
||||
out_str = f'AUC: {list(res.values())[0]}'
|
||||
print(res)
|
||||
print(out_str)

@@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""export ckpt to model"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import export, load_checkpoint
|
||||
from src.config import ModelConfig
|
||||
from src.fat_deepffm import ModelBuilder
|
||||
|
||||
parser = argparse.ArgumentParser(description="deepfm export")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1000, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="deepfm", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
args = parser.parse_args()
|
||||
set_seed(1)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = ModelConfig()
|
||||
|
||||
model_builder = ModelBuilder(config)
|
||||
_, network = model_builder.get_train_eval_net()
|
||||
network.set_train(False)
|
||||
|
||||
load_checkpoint(args.ckpt_file, net=network)
|
||||
|
||||
batch_ids = Tensor(np.zeros([config.batch_size, config.cats_dim]).astype(np.int32))
|
||||
batch_wts = Tensor(np.zeros([config.batch_size, config.dense_dim]).astype(np.float32))
|
||||
labels = Tensor(np.zeros([config.batch_size, 1]).astype(np.float32))
|
||||
|
||||
input_data = [batch_ids, batch_wts, labels]
|
||||
export(network, *input_data, file_name=args.file_name, file_format=args.file_format)

@@ -0,0 +1,26 @@
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
echo "Please run the script as: "
|
||||
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH RANK_TABLE_FILE"
|
||||
echo "for example: bash scripts/run_alone_train.sh 0 data/mindrecord/ Ascend False"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in logx/output.log"
|
||||
|
||||
export LANG="zh_CN.UTF-8"
|
||||
export DEVICE_ID=$1
|
||||
echo "start training"
|
||||
python train.py --dataset_path=$2 --device_target=$3 --do_eval=$4 >output.log 2>&1 &
|
||||
cd ../

@@ -0,0 +1,35 @@
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
echo "Please run the script as: "
|
||||
echo "for example: bash scripts/run_distribute_train.sh /dataset_path 8 scripts/hccl_8p.json False"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in logx/output.log"
|
||||
|
||||
export LANG="zh_CN.UTF-8"
|
||||
export RANK_SIZE=$2
|
||||
export RANK_TABLE_FILE=$3
|
||||
for ((i = 0; i < RANK_SIZE; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
rm -rf Task$i
|
||||
mkdir ./Task$i
|
||||
cp *.py ./Task$i
|
||||
cp -r ./src ./Task$i
|
||||
cd ./Task$i || exit
|
||||
python train.py --dataset_path=$1 --do_eval=$4 >output.log 2>&1 &
|
||||
cd ../
|
||||
done

@@ -0,0 +1,25 @@
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
|
||||
echo "Please run the script as: "
|
||||
echo "for example: bash scripts/run_eval.sh 0 Ascend /dataset_path /ckpt_path"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in logx/output.log"
|
||||
|
||||
export LANG="zh_CN.UTF-8"
|
||||
export DEVICE_ID=$1
|
||||
echo "start training"
|
||||
python eval.py --device_target=$2 --dataset_path=$3 --ckpt_path=$4 >eval_output.log 2>&1 &
|
||||
cd ../

@@ -0,0 +1,101 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
""" Callback"""
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
||||
def add_write(file_path, out_str):
|
||||
""" Write info"""
|
||||
with open(file_path, 'a+', encoding='utf-8') as file_out:
|
||||
file_out.write(out_str + '\n')
|
||||
|
||||
|
||||
class AUCCallBack(Callback):
|
||||
""" AUCCallBack"""
|
||||
def __init__(self, model, eval_dataset, eval_file_path):
|
||||
self.model = model
|
||||
self.eval_dataset = eval_dataset
|
||||
self.eval_file_path = eval_file_path
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
start_time = time.time()
|
||||
out = self.model.eval(self.eval_dataset)
|
||||
eval_time = int(time.time() - start_time)
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
out_str = "{} AUC:{}; eval_time{}s".format(
|
||||
time_str, out.values(), eval_time)
|
||||
print(out_str)
|
||||
add_write(self.eval_file_path, out_str)
|
||||
|
||||
|
||||
class LossCallback(Callback):
|
||||
""" LossCallBack"""
|
||||
def __init__(self, loss_file_path, per_print_times=1):
|
||||
super(LossCallback, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.loss_file_path = loss_file_path
|
||||
|
||||
def step_end(self, run_context):
|
||||
""" run after step_end """
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs
|
||||
|
||||
if isinstance(loss, (tuple, list)):
|
||||
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
|
||||
loss = loss[0]
|
||||
|
||||
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
|
||||
loss = np.mean(loss.asnumpy())
|
||||
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
|
||||
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
|
||||
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch))
|
||||
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
|
||||
with open(self.loss_file_path, "a+") as loss_file:
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
|
||||
time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
|
||||
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
|
||||
|
||||
|
||||
class TimeMonitor(Callback):
|
||||
""" TimeMonitor"""
|
||||
def __init__(self, data_size):
|
||||
super(TimeMonitor, self).__init__()
|
||||
self.data_size = data_size
|
||||
self.epoch_time = None
|
||||
self.step_time = None
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
epoch_mseconds = (time.time() - self.epoch_time) * 1000
|
||||
per_step_mseconds = epoch_mseconds / self.data_size
|
||||
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
print(f"step time {step_mseconds}", flush=True)

@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""config"""
|
||||
|
||||
class ModelConfig:
    """Model config: hyper-parameters shared by training, evaluation and export."""
    vocb_size = 184965                     # vocabulary size of the categorical feature ids
    batch_size = 1000
    emb_dim = 8                            # width of each field-aware embedding
    lr_end = 1e-4                          # learning rate at the start of the linear decay (see src/lr_generator.py)
    lr_init = 0                            # learning rate the linear decay ends at
    epsilon = 1e-8                         # Adam epsilon
    loss_scale = 1                         # static loss scale used by the training wrapper
    epoch_size = 30
    steps_per_epoch = 5166
    repeat_size = 1
    weight_bias_init = ['normal', 'normal']
    deep_layer_args = [[1024, 512, 128, 32, 1], "relu"]   # MLP layer widths and activation
    att_layer_args = [676, "relu"]         # attention hidden size (676 = 26 * 26) and activation
    keep_prob = 0.6                        # dropout keep probability
    ckpt_path = "./data/"
    keep_checkpoint_max = 50
    cats_dim = 26                          # number of categorical (sparse) fields
    dense_dim = 13                         # number of continuous (dense) fields
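

# Usage sketch: the derived sizes used elsewhere in this repository follow directly
# from these fields (src/fat_deepffm.py builds cats_dim * cats_dim field-aware
# embeddings of width emb_dim):
#
#     cfg = ModelConfig()
#     assert cfg.att_layer_args[0] == cfg.cats_dim * cfg.cats_dim   # 676 attention inputs
#     embed_output_dim = cfg.cats_dim * cfg.cats_dim * cfg.emb_dim  # 5408, the MLP input width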

@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Get dataset"""
|
||||
import os
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
|
||||
rank_size=None, rank_id=None, line_per_sample=1000):
|
||||
"""Get Mindrecord dataset"""
|
||||
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
|
||||
file_suffix_name = '00' if train_mode else '0'
|
||||
shuffle = train_mode
|
||||
if rank_size is not None and rank_id is not None:
|
||||
data_set = ds.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
|
||||
columns_list=['cats_vals', 'num_vals', 'label'],
|
||||
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
|
||||
num_parallel_workers=8)
|
||||
else:
|
||||
data_set = ds.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
|
||||
columns_list=['cats_vals', 'num_vals', 'label'],
|
||||
shuffle=shuffle, num_parallel_workers=8)
|
||||
data_set = data_set.batch(int(batch_size / line_per_sample), drop_remainder=True)
|
||||
data_set = data_set.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 26),
|
||||
np.array(y).flatten().reshape(batch_size, 13),
|
||||
np.array(z).flatten().reshape(batch_size, 1))),
|
||||
input_columns=['cats_vals', 'num_vals', 'label'],
|
||||
column_order=['cats_vals', 'num_vals', 'label'],
|
||||
num_parallel_workers=8)
|
||||
data_set = data_set.repeat(epochs)
|
||||
return data_set
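

# Usage sketch (assumes MindRecord files generated by src/preprocess_data.py exist
# under ./data/mindrecord; shapes follow the map() above with batch_size=1000):
#
#     ds_eval = get_mindrecord_dataset("./data/mindrecord", train_mode=False)
#     for cats, dense, label in ds_eval.create_tuple_iterator():
#         print(cats.shape, dense.shape, label.shape)   # (1000, 26) (1000, 13) (1000, 1)
#         break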

@@ -0,0 +1,382 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
""" Fat-deepFFM"""
|
||||
import numpy as np
|
||||
from src.lr_generator import get_warmup_linear_lr
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer, Uniform
|
||||
|
||||
import mindspore.ops as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
from mindspore import Parameter, ParameterTuple
|
||||
from mindspore import Tensor
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.nn import Adam, DistributedGradReducer, Dropout
|
||||
|
||||
from mindspore.context import ParallelMode
|
||||
|
||||
from mindspore.parallel._utils import _get_parallel_mode, _get_gradients_mean, _get_device_num
|
||||
|
||||
|
||||
def init_method(method, shape, name, max_val=1.0):
|
||||
""" Initialize weight"""
|
||||
params = None
|
||||
if method in ['uniform']:
|
||||
params = Parameter(initializer(Uniform(max_val), shape, mstype.float32), name=name)
|
||||
elif method == "one":
|
||||
params = Parameter(initializer("ones", shape, mstype.float32), name=name)
|
||||
elif method == 'zero':
|
||||
params = Parameter(initializer("zeros", shape, mstype.float32), name=name)
|
||||
elif method == "normal":
|
||||
params = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(dtype=np.float32)),
|
||||
name=name)
|
||||
return params
|
||||
|
||||
|
||||
class DenseFeaturesLinear(nn.Cell):
|
||||
""" First order linear combination of dense features"""
|
||||
|
||||
def __init__(self, nume_dims=13, output_dim=1):
|
||||
super().__init__()
|
||||
self.dense = DenseLayer(nume_dims, output_dim, ['normal', 'normal'],
|
||||
"relu", use_dropout=False, use_act=False, use_bn=False)
|
||||
|
||||
def construct(self, x):
|
||||
res = self.dense(x)
|
||||
return res
|
||||
|
||||
|
||||
class DenseHighFeaturesLinear(nn.Cell):
|
||||
"""High-order linear combinations of dense features"""
|
||||
|
||||
def __init__(self, num_dims=13, output_dim=1):
|
||||
super().__init__()
|
||||
self.dense_d3_1 = DenseLayer(num_dims, 512, ['normal', 'normal'], "relu")
|
||||
self.dense_d3_2 = DenseLayer(512, 512, ['normal', 'normal'], "relu")
|
||||
self.dense_d3_3 = DenseLayer(512, output_dim, ['normal', 'normal'], "relu",
|
||||
use_dropout=False, use_act=False, use_bn=False)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.dense_d3_1(x)
|
||||
x = self.dense_d3_2(x)
|
||||
res = self.dense_d3_3(x)
|
||||
return res
|
||||
|
||||
|
||||
# First-order FFM term over the categorical (sparse) features
|
||||
class SparseFeaturesLinear(nn.Cell):
|
||||
"""First-order linear combination of sparse features"""
|
||||
|
||||
def __init__(self, config, output_dim=1):
|
||||
super().__init__()
|
||||
self.weight = Parameter(Tensor(
|
||||
np.random.normal(loc=0.0, scale=0.01, size=[config.vocb_size, output_dim]).astype(dtype=np.float32)))
|
||||
self.reduceSum = P.ReduceSum(keep_dims=True)
|
||||
self.gather = P.Gather()
|
||||
self.squeeze = P.Squeeze(2)
|
||||
|
||||
def construct(self, x): # [b,26]
|
||||
res = self.gather(self.weight, x, 0)
|
||||
res = self.reduceSum(res, 1)
|
||||
res = self.squeeze(res)
|
||||
return res
|
||||
|
||||
|
||||
class SparseFeaturesFFMEmbedding(nn.Cell):
|
||||
"""The sparse features are dense"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_field = 26
|
||||
self.gather = P.Gather()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.weights = []
|
||||
for _ in range(self.num_field):
|
||||
weight = Parameter(Tensor(
|
||||
np.random.normal(loc=0.0, scale=0.01, size=[config.vocb_size, config.emb_dim]).astype(
|
||||
dtype=np.float32)))
|
||||
self.weights.append(weight)
|
||||
|
||||
def construct(self, x):
|
||||
xs = ()
|
||||
for i in range(self.num_field):
|
||||
xs += (self.gather(self.weights[i], x, 0),)
|
||||
xs = self.concat(xs)
|
||||
return xs
|
||||
|
||||
|
||||
class DenseLayer(nn.Cell):
|
||||
"""Full connection layer templates"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, convert_dtype=True,
|
||||
use_dropout=True, use_act=True, use_bn=True):
|
||||
super(DenseLayer, self).__init__()
|
||||
weight_init, bias_init = weight_bias_init
|
||||
self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
|
||||
self.bias = init_method(bias_init, [output_dim], name="bias")
|
||||
self.act_func = self._init_activation(act_str)
|
||||
self.matmul = P.MatMul(transpose_b=False)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.cast = P.Cast()
|
||||
self.dropout = Dropout(keep_prob=keep_prob)
|
||||
self.mul = P.Mul()
|
||||
self.realDiv = P.RealDiv()
|
||||
self.convert_dtype = convert_dtype
|
||||
self.use_act = use_act
|
||||
self.use_dropout = use_dropout
|
||||
self.use_bn = use_bn
|
||||
self.bn = nn.BatchNorm1d(output_dim)
|
||||
|
||||
def _init_activation(self, act_str):
|
||||
act_func = None
|
||||
act_str = act_str.lower()
|
||||
if act_str == "relu":
|
||||
act_func = P.ReLU()
|
||||
elif act_str == "sigmoid":
|
||||
act_func = P.Sigmoid()
|
||||
elif act_str == "tanh":
|
||||
act_func = P.Tanh()
|
||||
return act_func
|
||||
|
||||
def construct(self, x):
|
||||
"""Construct function"""
|
||||
if self.convert_dtype:
|
||||
x = self.cast(x, mstype.float16)
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
bias = self.cast(self.bias, mstype.float16)
|
||||
wx = self.matmul(x, weight)
|
||||
wx = self.bias_add(wx, bias)
|
||||
if self.use_bn:
|
||||
wx = self.bn(wx)
|
||||
if self.use_act:
|
||||
wx = self.act_func(wx)
|
||||
wx = self.cast(wx, mstype.float32)
|
||||
else:
|
||||
wx = self.matmul(x, self.weight)
|
||||
wx = self.bias_add(wx, self.bias)
|
||||
if self.use_bn:
|
||||
wx = self.bn(wx)
|
||||
if self.use_act:
|
||||
wx = self.act_func(wx)
|
||||
if self.use_dropout:
|
||||
wx = self.dropout(wx)
|
||||
return wx
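

# Usage sketch (illustrative; convert_dtype=False skips the float16 cast so the layer
# also runs outside Ascend):
#
#     layer = DenseLayer(16, 8, ['normal', 'normal'], "relu", convert_dtype=False)
#     out = layer(Tensor(np.ones((4, 16), np.float32)))   # -> shape (4, 8)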
|
||||
|
||||
|
||||
class AttentionFeaturelayer(nn.Cell):
|
||||
"""Attentional mechanism"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.cats_field = config.cats_dim
|
||||
self.weight_bias_init = config.weight_bias_init
|
||||
self.att_dim, self.att_layer_act = config.att_layer_args
|
||||
        # attention part
|
||||
self.att_conv = nn.Conv1d(in_channels=config.emb_dim, out_channels=1, kernel_size=1, stride=1)
|
||||
self.att_bn = nn.BatchNorm1d(676)
|
||||
self.att_re = nn.ReLU()
|
||||
self.dense_att_1 = DenseLayer(self.cats_field * self.cats_field, self.att_dim, self.weight_bias_init,
|
||||
self.att_layer_act)
|
||||
self.dense_att_2 = DenseLayer(self.att_dim, self.cats_field * self.cats_field, self.weight_bias_init,
|
||||
self.att_layer_act)
|
||||
self.transpose = P.Transpose()
|
||||
self.squeeze = P.Squeeze(axis=1)
|
||||
self.mul = P.Mul()
|
||||
self.addDim = P.ExpandDims()
|
||||
|
||||
def construct(self, x): # [b,676,8]
|
||||
        tx = self.transpose(x, (0, 2, 1))  # transpose to [b, emb_dim, 676]
|
||||
tx = self.att_conv(tx) # [b ,1, 676]
|
||||
att_xs = self.att_re(self.att_bn(self.squeeze(tx))) # (b,676)
|
||||
att_xs = self.dense_att_1(att_xs) # [b, 256]
|
||||
att_xs = self.dense_att_2(att_xs) # [b, 676]
|
||||
att_xs = self.addDim(att_xs, 2) # [b,676,1]
|
||||
out = self.mul(x, att_xs)
|
||||
return out
|
||||
|
||||
|
||||
# Second-order FFM interactions over the categorical features
|
||||
class FieldAwareFactorizationMachine(nn.Cell):
|
||||
""" Sparse feature crossover"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.num_fields = 26
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.mul = P.Mul()
|
||||
self.stack = P.Stack(axis=0)
|
||||
self.cat = P.Concat(axis=0)
|
||||
self.sum = P.ReduceSum(keep_dims=True)
|
||||
self.squeeze = P.Squeeze(axis=2)
|
||||
self.squeeze1 = P.Squeeze(axis=1)
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x): # [b,676,8]
|
||||
""" Sparse feature crossover """
|
||||
ix = ()
|
||||
for i in range(25):
|
||||
for j in range(i + 1, 26):
|
||||
m = 26 * j + i
|
||||
n = 26 * i + j
|
||||
ix += (self.squeeze1(self.mul(x[::, m:m + 1:1, ::], x[::, n:n + 1:1, ::])),)
|
||||
ix1 = self.stack(ix[:190]) # [190 b 8]
|
||||
ix2 = self.stack(ix[190:]) # [135 b 8]
|
||||
ix = self.cat((ix1, ix2)) # [325 b 8]
|
||||
ix = self.sum(ix, 2) # [325 b 1]
|
||||
ix = self.squeeze(ix) # [325 b]
|
||||
ix = self.transpose(ix, (1, 0)) # [b 325]
|
||||
ix = self.sum(ix, 1) # [b 1]
|
||||
return ix
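

# Index bookkeeping for the loop above, written out as a plain-Python sanity check:
# the 676 rows of x are laid out as 26 blocks of 26 rows, so the embedding of input
# field i under field j's embedding table sits at row 26 * j + i.
#
#     pairs = [(26 * j + i, 26 * i + j) for i in range(25) for j in range(i + 1, 26)]
#     assert len(pairs) == 26 * 25 // 2                     # 325 field pairs in total
#     assert len(pairs[:190]) + len(pairs[190:]) == 325     # matches the two stacks above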
|
||||
|
||||
|
||||
# Deep MLP component
|
||||
class MultiLayerPerceptron(nn.Cell):
|
||||
"""Deep network layer"""
|
||||
|
||||
def __init__(self, config, input_dim):
|
||||
super().__init__()
|
||||
self.weight_bias_init = config.weight_bias_init
|
||||
self.att_dim, self.att_layer_act = config.deep_layer_args
|
||||
self.keep_prob = config.keep_prob
|
||||
self.flatten = nn.Flatten()
|
||||
self.d_dense = DenseLayer(config.dense_dim, input_dim, self.weight_bias_init,
|
||||
self.att_layer_act, self.keep_prob)
|
||||
self.dense1 = DenseLayer(input_dim, self.att_dim[0], self.weight_bias_init,
|
||||
self.att_layer_act, self.keep_prob)
|
||||
self.dense2 = DenseLayer(self.att_dim[0], self.att_dim[1], self.weight_bias_init,
|
||||
self.att_layer_act, self.keep_prob)
|
||||
self.dense3 = DenseLayer(self.att_dim[1], self.att_dim[2], self.weight_bias_init,
|
||||
self.att_layer_act, self.keep_prob)
|
||||
self.dense4 = DenseLayer(self.att_dim[2], self.att_dim[3], self.weight_bias_init,
|
||||
self.att_layer_act, self.keep_prob)
|
||||
self.dense5 = DenseLayer(self.att_dim[3], self.att_dim[4], self.weight_bias_init,
|
||||
self.att_layer_act, self.keep_prob, use_dropout=False, use_bn=False, use_act=False)
|
||||
|
||||
def construct(self, d, x):
|
||||
x = self.flatten(x) + self.d_dense(d)
|
||||
x = self.dense1(x)
|
||||
x = self.dense2(x)
|
||||
x = self.dense3(x)
|
||||
x = self.dense4(x)
|
||||
x = self.dense5(x)
|
||||
return x
|
||||
|
||||
|
||||
class Fat_DeepFFM(nn.Cell):
|
||||
""""The general model"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
        self.dense_1st = DenseFeaturesLinear()  # first-order dense (numerical) features
|
||||
self.dense_high = DenseHighFeaturesLinear()
|
||||
        self.sparse_1st = SparseFeaturesLinear(config)  # first-order categorical features
|
||||
self.FFMEmb = SparseFeaturesFFMEmbedding(config)
|
||||
self.attention = AttentionFeaturelayer(config)
|
||||
self.ffm = FieldAwareFactorizationMachine()
|
||||
self.embed_output_dim = 26 * 26 * config.emb_dim
|
||||
self.mlp = MultiLayerPerceptron(config, self.embed_output_dim)
|
||||
|
||||
def construct(self, cats_vals, num_vals):
|
||||
""""cats_vals:[b,13], num_vals:[b 26]"""
|
||||
X_dense, X_sparse = num_vals, cats_vals
|
||||
FFME = self.FFMEmb(X_sparse) # [b,676,8]
|
||||
dense_1st_res = self.dense_1st(X_dense) # [b,1]
|
||||
dense_1st__high_res = self.dense_high(X_dense)
|
||||
sparse_1st_res = self.sparse_1st(X_sparse) # [b,1]
|
||||
attention_res = self.attention(FFME) # [b,676,8]
|
||||
ffm_res = self.ffm(attention_res) # [b,1]
|
||||
        mlp_res = self.mlp(X_dense, FFME)  # [b,1]
|
||||
res = dense_1st_res + dense_1st__high_res + sparse_1st_res + ffm_res + mlp_res # [b,1]
|
||||
return res
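

# Shape walk-through of one forward pass (comments only, matching construct above):
#     cats_vals [b, 26]  --FFMEmb-->     FFME [b, 676, emb_dim]
#     num_vals  [b, 13]  --dense_1st / dense_high-->  [b, 1] each
#     cats_vals [b, 26]  --sparse_1st--> [b, 1]
#     FFME --attention--> [b, 676, emb_dim] --ffm--> [b, 1]
#     (num_vals, FFME) --mlp--> [b, 1]
# The final logit is the sum of the five [b, 1] terms; the sigmoid is applied later
# in PredictWithSigmoid.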
|
||||
|
||||
|
||||
class NetWithLossClass(nn.Cell):
|
||||
"""Get the model results"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.loss = P.SigmoidCrossEntropyWithLogits()
|
||||
|
||||
def construct(self, cats_vals, num_vals, label):
|
||||
predict = self.network(cats_vals, num_vals)
|
||||
loss = self.loss(predict, label)
|
||||
return loss
|
||||
|
||||
|
||||
class TrainStepWrap(nn.Cell):
|
||||
"""Reverse passing"""
|
||||
|
||||
def __init__(self, network, config):
|
||||
super(TrainStepWrap, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_train()
|
||||
self.lr = get_warmup_linear_lr(config.lr_init, config.lr_end, config.epoch_size * config.steps_per_epoch)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = Adam(self.weights, learning_rate=self.lr, eps=config.epsilon,
|
||||
loss_scale=config.loss_scale)
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = config.loss_scale
|
||||
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = None
|
||||
parallel_mode = _get_parallel_mode()
|
||||
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
|
||||
|
||||
def construct(self, cats_vals, num_vals, label):
|
||||
weights = self.weights
|
||||
loss = self.network(cats_vals, num_vals, label)
|
||||
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)  # gradient sensitivity filled with the loss scale
|
||||
grads = self.grad(self.network, weights)(cats_vals, num_vals, label, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads))
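

# Note on the step above: `sens` is a tensor with the same shape and dtype as the
# loss, filled with config.loss_scale, so backpropagation starts from the scaled
# sensitivity; in data-parallel mode the gradients are additionally averaged across
# devices by DistributedGradReducer before the Adam update consumes them.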
|
||||
|
||||
|
||||
class ModelBuilder:
|
||||
"""Get the model"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_train_eval_net(self):
|
||||
deepfm_net = Fat_DeepFFM(self.config)
|
||||
train_net = NetWithLossClass(deepfm_net)
|
||||
train_net = TrainStepWrap(train_net, self.config)
|
||||
test_net = PredictWithSigmoid(deepfm_net)
|
||||
return train_net, test_net
|
||||
|
||||
|
||||
class PredictWithSigmoid(nn.Cell):
|
||||
"""Model to predict"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.sigmoid = P.Sigmoid()
|
||||
|
||||
def construct(self, cats_vals, num_vals, label):
|
||||
logits = self.network(cats_vals, num_vals)
|
||||
pred_probs = self.sigmoid(logits)
|
||||
return logits, pred_probs, label

@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""lr"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _generate_linear_lr(lr_init, lr_end, total_steps, warmup_steps, useWarmup=False):
|
||||
""" warmup lr"""
|
||||
lr_each_step = []
|
||||
if useWarmup:
|
||||
for i in range(0, total_steps):
|
||||
lrate = lr_init + (lr_end - lr_init) * i / warmup_steps
|
||||
if i >= warmup_steps:
|
||||
lrate = lr_end - (lr_end - lr_init) * (i - warmup_steps) / (total_steps - warmup_steps)
|
||||
lr_each_step.append(lrate)
|
||||
else:
|
||||
for i in range(total_steps):
|
||||
lrate = lr_end - (lr_end - lr_init) * i / total_steps
|
||||
lr_each_step.append(lrate)
|
||||
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def get_warmup_linear_lr(lr_init, lr_end, total_steps, warmup_steps=10):
|
||||
lr_each_step = _generate_linear_lr(lr_init, lr_end, total_steps, warmup_steps)
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
lr = get_warmup_linear_lr(0, 1e-4, 1000)
|
||||
print(lr.size)
|
||||
print(lr)

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""
|
||||
Area under curve (AUC) metric
|
||||
"""
|
||||
from mindspore.nn.metrics import Metric
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
|
||||
class AUCMetric(Metric):
|
||||
"""Area under cure metric"""
|
||||
|
||||
def __init__(self):
|
||||
super(AUCMetric, self).__init__()
|
||||
self.pred_probs = []
|
||||
self.true_labels = []
|
||||
|
||||
def clear(self):
|
||||
"""Clear the internal evaluation result."""
|
||||
self.pred_probs = []
|
||||
self.true_labels = []
|
||||
|
||||
def update(self, *inputs):
|
||||
"""Update list of predicts and labels."""
|
||||
batch_predict = inputs[1].asnumpy()
|
||||
batch_label = inputs[2].asnumpy()
|
||||
self.pred_probs.extend(batch_predict.flatten().tolist())
|
||||
self.true_labels.extend(batch_label.flatten().tolist())
|
||||
|
||||
def eval(self):
|
||||
if len(self.true_labels) != len(self.pred_probs):
|
||||
raise RuntimeError('true_labels.size() is not equal to pred_probs.size()')
|
||||
auc = roc_auc_score(self.true_labels, self.pred_probs)
|
||||
return auc
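

# Usage sketch (standalone, with made-up values; update() reads the probabilities
# from inputs[1] and the labels from inputs[2], matching PredictWithSigmoid's outputs):
#
#     from mindspore import Tensor
#     import numpy as np
#     metric = AUCMetric()
#     metric.clear()
#     logits = Tensor(np.array([[0.2], [0.8]], np.float32))   # ignored by update()
#     probs = Tensor(np.array([[0.3], [0.9]], np.float32))
#     labels = Tensor(np.array([[0.0], [1.0]], np.float32))
#     metric.update(logits, probs, labels)
#     print(metric.eval())                                     # 1.0 for this toy batch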

@@ -0,0 +1,285 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Download raw data and preprocessed data."""
|
||||
import argparse
|
||||
import collections
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
|
||||
class StatsDict:
|
||||
"""preprocessed data"""
|
||||
|
||||
def __init__(self, field_size, dense_dim, slot_dim, skip_id_convert):
|
||||
self.field_size = field_size # 40
|
||||
self.dense_dim = dense_dim # 13
|
||||
self.slot_dim = slot_dim # 26
|
||||
self.skip_id_convert = bool(skip_id_convert)
|
||||
|
||||
self.val_cols = ["val_{}".format(i + 1) for i in range(self.dense_dim)]
|
||||
self.cat_cols = ["cat_{}".format(i + 1) for i in range(self.slot_dim)]
|
||||
|
||||
self.val_min_dict = {col: 0 for col in self.val_cols}
|
||||
self.val_max_dict = {col: 0 for col in self.val_cols}
|
||||
|
||||
self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}
|
||||
|
||||
self.oov_prefix = "OOV"
|
||||
|
||||
self.cat2id_dict = {}
|
||||
self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
|
||||
self.cat2id_dict.update(
|
||||
{self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)})
|
||||
|
||||
def stats_vals(self, val_list):
|
||||
"""Handling weights column"""
|
||||
assert len(val_list) == len(self.val_cols)
|
||||
|
||||
def map_max_min(i, val):
|
||||
key = self.val_cols[i]
|
||||
if val != "":
|
||||
if float(val) > self.val_max_dict[key]:
|
||||
self.val_max_dict[key] = float(val)
|
||||
if float(val) < self.val_min_dict[key]:
|
||||
self.val_min_dict[key] = float(val)
|
||||
|
||||
for i, val in enumerate(val_list):
|
||||
map_max_min(i, val)
|
||||
|
||||
def stats_cats(self, cat_list):
|
||||
"""Handling cats column"""
|
||||
|
||||
assert len(cat_list) == len(self.cat_cols)
|
||||
|
||||
def map_cat_count(i, cat):
|
||||
key = self.cat_cols[i]
|
||||
self.cat_count_dict[key][cat] += 1
|
||||
|
||||
for i, cat in enumerate(cat_list):
|
||||
map_cat_count(i, cat)
|
||||
|
||||
def save_dict(self, dict_path, prefix=""):
|
||||
with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
|
||||
pickle.dump(self.val_max_dict, file_wrt)
|
||||
with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
|
||||
pickle.dump(self.val_min_dict, file_wrt)
|
||||
with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
|
||||
pickle.dump(self.cat_count_dict, file_wrt)
|
||||
|
||||
def load_dict(self, dict_path, prefix=""):
|
||||
with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
|
||||
self.val_max_dict = pickle.load(file_wrt)
|
||||
with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
|
||||
self.val_min_dict = pickle.load(file_wrt)
|
||||
with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
|
||||
self.cat_count_dict = pickle.load(file_wrt)
|
||||
print("val_max_dict.items()[:50]:{}".format(list(self.val_max_dict.items())))
|
||||
print("val_min_dict.items()[:50]:{}".format(list(self.val_min_dict.items())))
|
||||
|
||||
def get_cat2id(self, threshold=100):
|
||||
for key, cat_count_d in self.cat_count_dict.items():
|
||||
new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
|
||||
for cat_str, _ in new_cat_count_d.items():
|
||||
self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
|
||||
print("cat2id_dict.size:{}".format(len(self.cat2id_dict)))
|
||||
print("cat2id.dict.items()[:50]:{}".format(list(self.cat2id_dict.items())[:50]))
|
||||
|
||||
def map_cat2id(self, values, cats):
|
||||
"""Cat to id"""
|
||||
|
||||
def minmax_scale_value(i, val):
|
||||
max_v = float(self.val_max_dict["val_{}".format(i + 1)])
|
||||
return float(val) * 1.0 / max_v
|
||||
|
||||
dense_list = []
|
||||
spare_list = []
|
||||
for i, val in enumerate(values):
|
||||
if val == "":
|
||||
dense_list.append(0)
|
||||
else:
|
||||
dense_list.append(minmax_scale_value(i, float(val)))
|
||||
|
||||
for i, cat_str in enumerate(cats):
|
||||
key = "cat_{}".format(i + 1) + "_" + cat_str
|
||||
if key in self.cat2id_dict:
|
||||
spare_list.append(self.cat2id_dict[key])
|
||||
else:
|
||||
spare_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
|
||||
return dense_list, spare_list
|
||||
|
||||
|
||||
def mkdir_path(file_path):
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(file_path)
|
||||
|
||||
|
||||
def statsdata(file_path, dict_output_path, recommendation_dataset_stats_dict, dense_dim=13, slot_dim=26):
|
||||
"""Preprocess data and save data"""
|
||||
with open(file_path, encoding="utf-8") as file_in:
|
||||
errorline_list = []
|
||||
count = 0
|
||||
for line in file_in:
|
||||
count += 1
|
||||
line = line.strip("\n")
|
||||
items = line.split("\t")
|
||||
if len(items) != (dense_dim + slot_dim + 1):
|
||||
errorline_list.append(count)
|
||||
print("Found line length: {}, suppose to be {}, the line is {}".format(len(items),
|
||||
dense_dim + slot_dim + 1, line))
|
||||
continue
|
||||
if count % 1000000 == 0:
|
||||
print("Have handled {}w lines.".format(count // 10000))
|
||||
values = items[1: dense_dim + 1]
|
||||
cats = items[dense_dim + 1:]
|
||||
|
||||
assert len(values) == dense_dim, "values.size: {}".format(len(values))
|
||||
assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
|
||||
recommendation_dataset_stats_dict.stats_vals(values)
|
||||
recommendation_dataset_stats_dict.stats_cats(cats)
|
||||
recommendation_dataset_stats_dict.save_dict(dict_output_path)
|
||||
|
||||
|
||||
def random_split_trans2mindrecord(input_file_path, output_file_path, recommendation_dataset_stats_dict,
|
||||
part_rows=100000, line_per_sample=1000, train_line_count=None,
|
||||
test_size=0.1, seed=2020, dense_dim=13, slot_dim=26):
|
||||
"""Random split data and save mindrecord"""
|
||||
if train_line_count is None:
|
||||
raise ValueError("Please provide training file line count")
|
||||
test_size = int(train_line_count * test_size)
|
||||
all_indices = [i for i in range(train_line_count)]
|
||||
np.random.seed(seed)
|
||||
np.random.shuffle(all_indices)
|
||||
print("all_indices.size:{}".format(len(all_indices)))
|
||||
test_indices_set = set(all_indices[:test_size])
|
||||
print("test_indices_set.size:{}".format(len(test_indices_set)))
|
||||
print("-----------------------" * 10 + "\n" * 2)
|
||||
|
||||
train_data_list = []
|
||||
test_data_list = []
|
||||
cats_list = []
|
||||
dense_list = []
|
||||
label_list = []
|
||||
|
||||
writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 1)
|
||||
writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 1)
|
||||
|
||||
schema = {"label": {"type": "float32", "shape": [-1]}, "num_vals": {"type": "float32", "shape": [-1]},
|
||||
"cats_vals": {"type": "int32", "shape": [-1]}}
|
||||
writer_train.add_schema(schema, "CRITEO_TRAIN")
|
||||
writer_test.add_schema(schema, "CRITEO_TEST")
|
||||
|
||||
with open(input_file_path, encoding="utf-8") as file_in:
|
||||
items_error_size_lineCount = []
|
||||
count = 0
|
||||
train_part_number = 0
|
||||
test_part_number = 0
|
||||
for i, line in enumerate(file_in):
|
||||
count += 1
|
||||
if count % 1000000 == 0:
|
||||
print("Have handle {}w lines.".format(count // 10000))
|
||||
line = line.strip("\n")
|
||||
items = line.split("\t")
|
||||
if len(items) != (1 + dense_dim + slot_dim):
|
||||
items_error_size_lineCount.append(i)
|
||||
continue
|
||||
label = float(items[0])
|
||||
values = items[1:1 + dense_dim]
|
||||
cats = items[1 + dense_dim:]
|
||||
|
||||
assert len(values) == dense_dim, "values.size: {}".format(len(values))
|
||||
assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
|
||||
|
||||
dense, cats = recommendation_dataset_stats_dict.map_cat2id(values, cats)
|
||||
|
||||
dense_list.extend(dense)
|
||||
cats_list.extend(cats)
|
||||
label_list.append(label)
|
||||
|
||||
if count % line_per_sample == 0:
|
||||
if i not in test_indices_set:
|
||||
train_data_list.append({"cats_vals": np.array(cats_list, dtype=np.int32),
|
||||
"num_vals": np.array(dense_list, dtype=np.float32),
|
||||
"label": np.array(label_list, dtype=np.float32)
|
||||
})
|
||||
else:
|
||||
test_data_list.append({"cats_vals": np.array(cats_list, dtype=np.int32),
|
||||
"num_vals": np.array(dense_list, dtype=np.float32),
|
||||
"label": np.array(label_list, dtype=np.float32)
|
||||
})
|
||||
if train_data_list and len(train_data_list) % part_rows == 0:
|
||||
writer_train.write_raw_data(train_data_list)
|
||||
train_data_list.clear()
|
||||
train_part_number += 1
|
||||
|
||||
if test_data_list and len(test_data_list) % part_rows == 0:
|
||||
writer_test.write_raw_data(test_data_list)
|
||||
test_data_list.clear()
|
||||
test_part_number += 1
|
||||
|
||||
cats_list.clear()
|
||||
dense_list.clear()
|
||||
label_list.clear()
|
||||
|
||||
    # write any remaining samples and finalize the MindRecord files
    if train_data_list:
        writer_train.write_raw_data(train_data_list)
    if test_data_list:
        writer_test.write_raw_data(test_data_list)
    writer_train.commit()
    writer_test.commit()

    print("-------------" * 10)
    print("items_error_size_lineCount.size(): {}.".format(len(items_error_size_lineCount)))
    print("-------------" * 10)
    np.save("items_error_size_lineCount.npy", items_error_size_lineCount)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Recommendation dataset")
    parser.add_argument("--data_path", type=str, default="./data/",
                        help='The path of the data file')
    parser.add_argument("--dense_dim", type=int, default=13,
                        help='The number of your continuous (dense) fields')
    parser.add_argument("--slot_dim", type=int, default=26,
                        help='The number of your sparse fields, also known as category features.')
    parser.add_argument("--threshold", type=int, default=100,
                        help='Category values with a frequency below this threshold are mapped to OOV, '
                             'which reduces the vocabulary size.')
    parser.add_argument("--train_line_count", type=int, default=45840617,
                        help='The number of examples in your dataset')
    parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1],
                        help='Skip the id conversion and use the original id as the final id.')

    args, _ = parser.parse_known_args()
    data_path = args.data_path
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", device_id=device_id)
    target_field_size = args.dense_dim + args.slot_dim
    stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
                      skip_id_convert=args.skip_id_convert)
    data_file_path = data_path + "train.txt"
    stats_output_path = data_path + "stats_dict/"
    mkdir_path(stats_output_path)
    statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim)

    stats.load_dict(dict_path=stats_output_path, prefix="")
    stats.get_cat2id(threshold=args.threshold)

    output_path = data_path + "mindrecord/"
    mkdir_path(output_path)
    # pass the command-line values through so that --train_line_count, --dense_dim and --slot_dim take effect
    random_split_trans2mindrecord(data_file_path, output_path, stats, part_rows=100000,
                                  train_line_count=args.train_line_count, line_per_sample=1000,
                                  test_size=0.1, seed=2020,
                                  dense_dim=args.dense_dim, slot_dim=args.slot_dim)
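A minimal sketch of invoking the preprocessing step, assuming the script above is saved as `preprocess_data.py` and the raw Criteo file sits at `data/train.txt` (both names are assumptions, not confirmed by this commit); the flags map directly to the argparse options defined above:

```bash
# DEVICE_ID is read by the script before it sets the MindSpore context
export DEVICE_ID=0
python preprocess_data.py \
    --data_path ./data/ \
    --dense_dim 13 \
    --slot_dim 26 \
    --threshold 100 \
    --train_line_count 45840617 \
    --skip_id_convert 0
```

The MindRecord output should then appear under `data/mindrecord/`, which matches the default `--dataset_path` of the training script below.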
@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""train_criteo."""
import argparse
import os

from src.callback import AUCCallBack
from src.callback import TimeMonitor, LossCallback
from src.config import ModelConfig
from src.dataset import get_mindrecord_dataset
from src.fat_deepffm import ModelBuilder
from src.metrics import AUCMetric
from mindspore import context, Model
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--dataset_path', type=str, default="./data/mindrecord", help='Dataset path')
parser.add_argument('--ckpt_path', type=str, default="Fat-DeepFFM", help='Checkpoint path')
parser.add_argument('--eval_file_name', type=str, default="./auc.log",
                    help='Auc log file path. Default: "./auc.log"')
parser.add_argument('--loss_file_name', type=str, default="./loss.log",
                    help='Loss log file path. Default: "./loss.log"')
parser.add_argument('--device_target', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
                    help="device target, support Ascend, GPU and CPU.")
# argparse does not parse booleans natively (bool("False") is True), so convert the flag explicitly
parser.add_argument('--do_eval', type=lambda x: x.lower() == 'true', default=False,
                    help="Whether to run evaluation on the test set while training.")
args = parser.parse_args()
rank_size = int(os.environ.get("RANK_SIZE", 1))
print("rank_size", rank_size)

set_seed(1)

if __name__ == '__main__':
    model_config = ModelConfig()
    if rank_size > 1:
        # distributed training: one process per device, data-parallel with gradient averaging
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          all_reduce_fusion_config=[9, 11])
        init()
        rank_id = get_rank()
    else:
        # single-device training
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)
        rank_size = None
        rank_id = None
    print("load dataset...")
    ds_train = get_mindrecord_dataset(args.dataset_path, train_mode=True, epochs=1, batch_size=model_config.batch_size,
                                      rank_size=rank_size, rank_id=rank_id, line_per_sample=1000)
    train_net, test_net = ModelBuilder(model_config).get_train_eval_net()
    auc_metric = AUCMetric()
    model = Model(train_net, eval_network=test_net, metrics={"AUC": auc_metric})
    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    loss_callback = LossCallback(args.loss_file_name)
    cb = [loss_callback, time_callback]
    # only device 0 saves checkpoints and runs the optional evaluation callback
    if rank_size == 1 or device_id == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size() * model_config.epoch_size,
                                     keep_checkpoint_max=model_config.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_path, config=config_ck)
        cb += [ckpoint_cb]
    if args.do_eval and device_id == 0:
        ds_test = get_mindrecord_dataset(args.dataset_path, train_mode=False)
        eval_callback = AUCCallBack(model, ds_test, eval_file_path=args.eval_file_name)
        cb.append(eval_callback)
    print("Training started...")
    model.train(model_config.epoch_size, train_dataset=ds_train, callbacks=cb, dataset_sink_mode=True)
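For reference, a minimal single-device launch of the training script above might look like the following, assuming it is saved as `train.py` at the repository root (the file name is an assumption); the flags come from the argparse definitions above:

```bash
# DEVICE_ID must be set: train.py reads it to select the device
export DEVICE_ID=0
export RANK_SIZE=1
python train.py --dataset_path ./data/mindrecord --ckpt_path Fat-DeepFFM --device_target Ascend --do_eval False
```

With the defaults above, loss values are appended to `./loss.log`, and passing `--do_eval True` additionally writes AUC values to `./auc.log` on device 0.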