This commit is contained in:
zsw12138 2021-08-04 14:31:56 +00:00
parent 9726051f58
commit bc9f96c5cb
12 changed files with 1747 additions and 0 deletions


@ -0,0 +1,220 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [Deep&Cross Description](#deepcross-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

# Deep&Cross Description

Deep & Cross Network (DCN) is a 2017 joint work of Google and Stanford for click-through-rate (CTR) prediction in advertising scenarios. Compared with Wide & Deep, also from Google, DCN requires no manual feature engineering to obtain high-order cross features; compared with the FM family of models, DCN is more computationally efficient and can extract higher-order feature crosses.

[Paper](https://arxiv.org/pdf/1708.05123.pdf)

# Model Architecture

The DCN model begins with an embedding and stacking layer, followed by a Cross Network and a Deep Network running in parallel, and ends with a combination layer that merges the outputs of the Cross Network and the Deep Network into the final prediction.

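The heart of the Cross Network is the per-layer feature crossing x_{l+1} = x_0 * (x_l^T w_l) + b_l + x_l, which `CrossLayer` in `src/deep_and_cross.py` implements with MindSpore operators. The NumPy snippet below is only a sketch of that formula for illustration; the function and array names are made up and it is not the code this model runs.

```python
import numpy as np

def cross_layer(x0, xl, w, b):
    """One DCN cross layer: x_{l+1} = x0 * (xl . w) + b + xl.

    x0, xl: (batch, d) -- stacked embeddings / previous layer output
    w, b:   (d,)       -- per-layer weight and bias
    """
    xl_w = xl @ w                        # (batch,) scalar projection per sample
    return x0 * xl_w[:, None] + b + xl   # (batch, d)

# toy shapes; in this model d = field_size * emb_dim
x0 = np.random.randn(4, 8).astype(np.float32)
w = np.random.randn(8).astype(np.float32)
b = np.random.randn(8).astype(np.float32)
print(cross_layer(x0, x0, w, b).shape)  # (4, 8)
```

Because each layer only adds a d-dimensional weight and bias, stacking `cross_layer_num` such layers (6 by default) models high-order feature interactions at low cost.
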
# Dataset

Dataset used: the Criteo dataset used in [1] Guo H, Tang R, Ye Y, et al. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction[J]. 2017.

# Environment Requirements

- Hardware (GPU)
    - Prepare the hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)

# Quick Start

After installing MindSpore via the official website, you can start training and evaluation as follows:

1. Clone the code.

    ```bash
    git clone https://gitee.com/mindspore/mindspore.git
    cd mindspore/model_zoo/official/recommend/deep_and_cross
    ```

2. Download the dataset.

    > Please refer to [1] to obtain the download link.

    ```bash
    mkdir -p data/origin_data && cd data/origin_data
    wget DATA_LINK
    tar -zxvf dac.tar.gz
    ```

3. Use this script to preprocess the data. Processing may take about an hour, and the generated MindRecord data is stored under the data/mindrecord path (a sketch of the resulting record layout follows this list).

    ```bash
    python src/preprocess_data.py  --data_path=./data/ --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
    ```

4. Start training.

    Once the dataset is ready, the model can be trained and evaluated on GPU.

    Single-GPU training command:

    ```bash
    # single-GPU training example
    python train.py --device_target="GPU" > output.train.log 2>&1 &
    # or
    sh scripts/run_train_gpu.sh
    ```

    8-GPU training command:

    ```bash
    # 8-GPU training example
    sh scripts/run_train_multi_gpu.sh
    ```

5. Start evaluation.

    After training, evaluate the model as follows:

    ```bash
    python eval.py --ckpt_path=CHECKPOINT_PATH
    # or
    sh scripts/run_eval.sh CHECKPOINT_PATH
    ```

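For reference on what step 3 produces: each MindRecord sample written by `src/preprocess_data.py` packs `line_per_sample` (1000) raw Criteo lines, and the loader in `src/datasets.py` batches `batch_size / line_per_sample` of those samples per training step. The snippet below is only a sketch of the resulting shapes under the default settings; the arrays are placeholders, not code from this repository.

```python
import numpy as np

line_per_sample = 1000   # raw lines packed into one MindRecord sample (preprocess_data.py)
field_size = 39          # 13 dense + 26 categorical features
batch_size = 16000       # default in config.py

# one MindRecord sample as written by random_split_trans2mindrecord()
sample = {
    "feat_ids":  np.zeros(line_per_sample * field_size, dtype=np.int32),
    "feat_vals": np.zeros(line_per_sample * field_size, dtype=np.float32),
    "label":     np.zeros(line_per_sample, dtype=np.float32),
}

# the dataset loader batches batch_size // line_per_sample = 16 samples and its
# map function reshapes them into the per-step tensors fed to the network:
ids    = np.zeros((batch_size, field_size), dtype=np.int32)    # feat_ids
vals   = np.zeros((batch_size, field_size), dtype=np.float32)  # feat_vals
labels = np.zeros((batch_size, 1), dtype=np.float32)           # label
```
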
# Script Description

## Script and Sample Code

```bash
├── model_zoo
    ├── README.md                        // descriptions of all models
    ├── deep_and_cross
        ├── README.md                    // Deep and Cross description
        ├── scripts
        │   ├── run_train_gpu.sh         // shell script for single-GPU training
        │   ├── run_train_multi_gpu.sh   // shell script for 8-GPU training
        │   ├── run_eval.sh              // shell script for evaluation
        ├── src
        │   ├── datasets.py              // dataset creation
        │   ├── deep_and_cross.py        // Deep and Cross network architecture
        │   ├── callbacks.py             // callback definitions
        │   ├── config.py                // parameter configuration
        │   ├── metrics.py               // AUC metric definition
        │   ├── preprocess_data.py       // data preprocessing, generates MindRecord files
        ├── train.py                     // training script
        ├── eval.py                      // evaluation script
```

## Script Parameters

Both training and evaluation parameters can be configured in config.py.

- Parameters for the Deep&Cross model and the Criteo dataset.

    ```python
    self.device_target = "GPU"            # device to run on
    self.device_id = 0                    # ID of the device used for training or evaluation
    self.epochs = 10                      # number of training epochs
    self.batch_size = 16000               # batch size
    self.deep_layer_dim = [1024, 1024]    # dimensions of the deep layers
    self.cross_layer_num = 6              # number of cross layers
    self.eval_file_name = "eval.log"      # file for evaluation results
    self.loss_file_name = "loss.log"      # file for loss results
    self.ckpt_path = "./checkpoints/"     # checkpoint output directory
    self.dataset_type = "mindrecord"      # dataset format
    self.is_distributed = 0               # whether to run distributed training
    ```

For more configuration details, see the script `config.py`.

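These defaults can be overridden from the command line: `train.py` and `eval.py` both create a `DeepCrossConfig` and call `argparse_init()`, which overwrites the defaults with any recognized flags. A minimal sketch of that flow (the flag values mentioned in the comment are only examples):

```python
from src.config import DeepCrossConfig

config = DeepCrossConfig()   # defaults as listed above
config.argparse_init()       # e.g. `--batch_size=8000 --epochs=5` would overwrite the defaults
print(config.batch_size, config.epochs, config.ckpt_path)
```
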
## Training Process

### Training

- Running on GPU

    ```bash
    sh scripts/run_train_gpu.sh
    ```

    The above bash command runs in the background; you can view the results in the output.train.log file.

    After training, you can find the checkpoint files in the default `./checkpoints/` folder.

### Distributed Training

- Running on GPU

    ```bash
    sh scripts/run_train_multi_gpu.sh
    ```

    The above shell script runs distributed training in the background. You can view the results in the output.multi_gpu.train.log file.

## Evaluation Process

### Evaluation

- Evaluating the Criteo dataset on GPU

    Before running the commands below, check the checkpoint path used for evaluation. Set the checkpoint path as a relative path, for example "./checkpoints/deep_and_cross-10_2582.ckpt".

    ```bash
    python eval.py --ckpt_path=[CHECKPOINT_PATH] > eval.log 2>&1 &
    ```

    The above python command runs in the background; you can view the results in the eval.log file.

    Alternatively,

    ```bash
    sh scripts/run_eval.sh [CHECKPOINT_PATH]
    ```

    The above shell script runs in the background; you can view the results in the output.eval.log file.

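Evaluation reports the AUC computed by `AUCMetric` in `src/metrics.py`. If you only need probabilities from an already trained checkpoint, the building blocks used by `eval.py` can also be assembled by hand. The sketch below is only an illustration: it assumes the example checkpoint path above and input tensors shaped as produced by `src/datasets.py`; it is not a script shipped with the model.

```python
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import DeepCrossConfig
from src.deep_and_cross import DeepCrossModel, PredictWithSigmoid

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
config = DeepCrossConfig()
pred_net = PredictWithSigmoid(DeepCrossModel(config))   # backbone plus sigmoid head
load_param_into_net(pred_net, load_checkpoint("./checkpoints/deep_and_cross-10_2582.ckpt"))
pred_net.set_train(False)
# pred_net(batch_ids, batch_wts, labels) returns (logits, probabilities, labels)
```
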
# Model Description

## Performance

### Evaluation Performance

#### Criteo Dataset

| Parameters                 | GPU (single device)                                 |
| -------------------------- | --------------------------------------------------- |
| Resource                   | NV Tesla V100-32G                                   |
| Uploaded date              | 2021-06-30                                          |
| MindSpore version          | 1.2.0                                               |
| Dataset                    | Criteo                                              |
| Training parameters        | epoch=10, steps=2582, batch_size=16000, lr=0.0001   |
| Optimizer                  | Adam                                                |
| Loss function              | Sigmoid cross entropy                               |
| Output                     | Probability                                         |
| Loss                       | 0.4388                                              |
| Speed                      | 107-110 ms/step                                     |
| Total time                 | about 2800 s                                        |
| Checkpoint for fine tuning | 75M (.ckpt file)                                    |
| Inference AUC              | 0.803786                                            |

# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.deep_and_cross import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, DeepCrossModel
from src.callbacks import EvalCallBack
from src.datasets import create_dataset, DataType
from src.metrics import AUCMetric
from src.config import DeepCrossConfig
def get_DCN_net(configure):
"""
Get network of deep&cross model.
"""
DCN_net = DeepCrossModel(configure)
loss_net = NetWithLossClass(DCN_net)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(DCN_net)
return train_net, eval_net
class ModelBuilder():
"""
Build the model.
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_net(self, configure):
return get_DCN_net(configure)
def test_eval(configure):
"""
test_eval
"""
data_path = configure.data_path
batch_size = configure.batch_size
field_size = configure.field_size
if configure.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif configure.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
dataset_type = DataType.H5
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, data_type=dataset_type, target_column=field_size+1)
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(configure)
ckpt_path = configure.ckpt_path
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(eval_net, param_dict)
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, configure)
model.eval(ds_eval, callbacks=eval_callback)
if __name__ == "__main__":
config = DeepCrossConfig()
config.argparse_init()
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
test_eval(config)


@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh ckpt_filename "
echo "for example: bash run_eval.sh 0_8-24_1012.ckpt"
echo "=============================================================================================================="
CKPT=$1
loc=$(pwd)
ckpt_path=$loc"/"$CKPT
echo "ckpt_path: "$ckpt_path
python eval.py --ckpt_path=$ckpt_path > output.eval.log 2>&1 &


@ -0,0 +1,26 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_train_gpu.sh"
echo "for example: bash src/run_train_gpu.sh"
echo "=============================================================================================================="
python train.py \
--device_target="GPU" > output.train.log 2>&1 &


@ -0,0 +1,27 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_train_multi_gpu.sh"
echo "for example: bash run_train_multi_gpu.sh"
echo "=============================================================================================================="
mpirun --allow-run-as-root -n 8 python train.py \
--device_target="GPU" --is_distributed=1 > output.multi_gpu.train.log 2>&1 &


@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""callbacks"""
import time
from mindspore.train.callback import Callback
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank
def add_write(file_path, out_str):
"""
add lines to the file
"""
with open(file_path, 'a+', encoding="utf-8") as file_out:
file_out.write(out_str + "\n")
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF, terminate the training.
Note:
If per_print_times is 0, do NOT print loss.
Args:
        per_print_times (int): Print the loss every per_print_times steps. Default: 1.
"""
def __init__(self, config=None, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("per_print_times must be in and >= 0.")
self._per_print_times = per_print_times
self.config = config
def step_end(self, run_context):
"""Monitor the loss in training."""
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
rank_id = 0
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
ParallelMode.DATA_PARALLEL):
rank_id = get_rank()
print("===loss===", rank_id, cb_params.cur_epoch_num, cur_step_in_epoch,
loss, flush=True)
# raise ValueError
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
loss_file = open(self.config.loss_file_name, "a+")
loss_file.write("epoch: %s, step: %s, loss: %s" %
(cb_params.cur_epoch_num, cur_step_in_epoch, loss))
loss_file.write("\n")
loss_file.close()
print("epoch: %s, step: %s, loss: %s" %
(cb_params.cur_epoch_num, cur_step_in_epoch, loss))
class EvalCallBack(Callback):
"""
    Run evaluation at the end of each training epoch and record the AUC result.
    Note:
        If print_per_step is 0, do NOT print.
    Args:
        print_per_step (int): Print the evaluation result every print_per_step steps. Default: 1.
"""
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False):
super(EvalCallBack, self).__init__()
if not isinstance(print_per_step, int) or print_per_step < 0:
raise ValueError("print_per_step must be int and >= 0.")
self.print_per_step = print_per_step
self.model = model
self.eval_dataset = eval_dataset
self.aucMetric = auc_metric
self.aucMetric.clear()
self.eval_file_name = config.eval_file_name
self.eval_values = []
self.host_device_mix = host_device_mix
def epoch_end(self, run_context):
"""
epoch end
"""
self.aucMetric.clear()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
rank_id = 0
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
ParallelMode.DATA_PARALLEL):
rank_id = get_rank()
start_time = time.time()
out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix))
end_time = time.time()
eval_time = int(end_time - start_time)
        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
out_str = "{} == Rank: {} == EvalCallBack model.eval(): {}; eval_time: {}s".\
format(time_str, rank_id, out.values(), eval_time)
print(out_str)
self.eval_values = out.values()
add_write(self.eval_file_name, out_str)


@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""deep and cross config"""
import argparse
def argparse_init():
"""
argparse_init
"""
parser = argparse.ArgumentParser(description='DeepCross')
parser.add_argument("--device_target", type=str, default="GPU", choices=["Ascend", "GPU"],
help="device where the code will be implemented. (Default: GPU)")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--data_path", type=str, default="./data/mindrecord",
help="This should be set to the same directory given to the data_download's data_dir argument")
parser.add_argument("--epochs", type=int, default=10, help="Total train epochs")
parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ")
parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.")
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.")
parser.add_argument("--emb_dim", type=int, default=30, help="The dense embedding dimension of sparse feature.")
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 1024],
help="The dimension of all deep layers.")
parser.add_argument("--cross_layer_num", type=int, default=6, help="Cross layer num")
parser.add_argument("--deep_layer_act", type=str, default='relu',
help="The activation function of all deep layers.")
parser.add_argument("--keep_prob", type=float, default=1.0, help="The keep rate in dropout layer.")
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
parser.add_argument("--dataset_type", type=str, default="mindrecord", help="tfrecord/mindrecord/hd5")
parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not")
parser.add_argument("--is_distributed", type=int, default=0, help="0/1")
return parser
class DeepCrossConfig():
"""
DeepCrossConfig
"""
def __init__(self):
self.device_target = "GPU"
self.device_id = 0
self.data_path = "./test_raw_data/"
self.full_batch = False
self.epochs = 10
self.batch_size = 16000
self.eval_batch_size = 16000
self.field_size = 39
self.vocab_size = 200000
self.emb_dim = 27
self.deep_layer_dim = [1024, 1024]
self.cross_layer_num = 6
self.deep_layer_act = 'relu'
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.keep_prob = 1.0
self.output_path = "./output"
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.host_device_mix = 0
self.dataset_type = "mindrecord"
self.parameter_server = 0
self.is_distributed = 0
def argparse_init(self):
"""
argparse_init
"""
parser = argparse_init()
args, _ = parser.parse_known_args()
self.device_target = args.device_target
self.device_id = args.device_id
self.data_path = args.data_path
self.epochs = args.epochs
self.full_batch = bool(args.full_batch)
self.batch_size = args.batch_size
self.eval_batch_size = args.eval_batch_size
self.field_size = args.field_size
self.vocab_size = args.vocab_size
self.emb_dim = args.emb_dim
self.deep_layer_dim = args.deep_layer_dim
self.deep_layer_act = args.deep_layer_act
self.keep_prob = args.keep_prob
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = bool(args.dropout_flag)
self.cross_layer_num = args.cross_layer_num
self.output_path = args.output_path
self.eval_file_name = args.eval_file_name
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path
self.host_device_mix = args.host_device_mix
self.dataset_type = args.dataset_type
self.parameter_server = args.parameter_server
self.is_distributed = args.is_distributed


@ -0,0 +1,358 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_dataset."""
import os
import math
from enum import Enum
import numpy as np
import pandas as pd
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype
class DataType(Enum):
"""
Enumerate supported dataset format.
"""
MINDRECORD = 1
TFRECORD = 2
H5 = 3
class H5Dataset():
"""
H5DataSet
"""
input_length = 39
def __init__(self, data_path, train_mode=True, train_num_of_parts=21,
test_num_of_parts=3):
self._hdf_data_dir = data_path
self._is_training = train_mode
if self._is_training:
self._file_prefix = 'train'
self._num_of_parts = train_num_of_parts
else:
self._file_prefix = 'test'
self._num_of_parts = test_num_of_parts
self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix,
self._num_of_parts)
print("data_size: {}".format(self.data_size))
def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts):
size = 0
for part in range(num_of_parts):
_y = pd.read_hdf(os.path.join(hdf_data_dir,
file_prefix + '_output_part_' + str(
part) + '.h5'))
size += _y.shape[0]
return size
def _iterate_hdf_files_(self, num_of_parts=None,
shuffle_block=False):
"""
        Iterate over the hdf files (blocks). When the whole dataset has been consumed, the iterator
        restarts from the beginning, so the data stream never stops.
        :param num_of_parts: number of files
        :param shuffle_block: shuffle the block files at every round
        :return: input_hdf_file_name, output_hdf_file_name, finish_flag
"""
parts = np.arange(num_of_parts)
while True:
if shuffle_block:
for _ in range(int(shuffle_block)):
np.random.shuffle(parts)
for i, p in enumerate(parts):
yield os.path.join(self._hdf_data_dir,
self._file_prefix + '_input_part_' + str(
p) + '.h5'), \
os.path.join(self._hdf_data_dir,
self._file_prefix + '_output_part_' + str(
p) + '.h5'), i + 1 == len(parts)
def _generator(self, X, y, batch_size, shuffle=True):
"""
should be accessed only in private
:param X:
:param y:
:param batch_size:
:param shuffle:
:return:
"""
number_of_batches = np.ceil(1. * X.shape[0] / batch_size)
counter = 0
finished = False
sample_index = np.arange(X.shape[0])
if shuffle:
for _ in range(int(shuffle)):
np.random.shuffle(sample_index)
assert X.shape[0] > 0
while True:
batch_index = sample_index[
batch_size * counter: batch_size * (counter + 1)]
X_batch = X[batch_index]
y_batch = y[batch_index]
counter += 1
yield X_batch, y_batch, finished
if counter == number_of_batches:
counter = 0
finished = True
def batch_generator(self, batch_size=1000,
random_sample=False, shuffle_block=False):
"""
        Yield batches from the hdf blocks.
        :param batch_size: batch size
        :param random_sample: if True, shuffle samples within each block
        :param shuffle_block: shuffle file blocks at every round
        :return: generator of (X_id, X_va, y) numpy arrays
"""
for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts,
shuffle_block):
start = stop = None
X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values
y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values
data_gen = self._generator(X_all, y_all, batch_size,
shuffle=random_sample)
finished = False
while not finished:
X, y, finished = data_gen.__next__()
X_id = X[:, 0:self.input_length]
X_va = X[:, self.input_length:]
yield np.array(X_id.astype(dtype=np.int32)), np.array(
X_va.astype(dtype=np.float32)), np.array(
y.astype(dtype=np.float32))
def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
"""
get_h5_dataset
"""
data_para = {
'batch_size': batch_size,
}
if train_mode:
data_para['random_sample'] = True
data_para['shuffle_block'] = True
h5_dataset = H5Dataset(data_path=data_dir, train_mode=train_mode)
numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)
def _iter_h5_data():
train_eval_gen = h5_dataset.batch_generator(**data_para)
for _ in range(0, numbers_of_batch, 1):
yield train_eval_gen.__next__()
ds = de.GeneratorDataset(_iter_h5_data(), ["ids", "weights", "labels"])
ds = ds.repeat(epochs)
return ds
def _padding_func(batch_size, manual_shape, target_column, field_size=39):
"""
get padding_func
"""
if manual_shape:
generate_concat_offset = [item[0]+item[1] for item in manual_shape]
part_size = int(target_column / len(generate_concat_offset))
filled_value = []
for i in range(field_size, target_column):
filled_value.append(generate_concat_offset[i//part_size]-1)
print("Filed Value:", filled_value)
def padding_func(x, y, z):
x = np.array(x).flatten().reshape(batch_size, field_size)
y = np.array(y).flatten().reshape(batch_size, field_size)
z = np.array(z).flatten().reshape(batch_size, 1)
x_id = np.ones((batch_size, target_column - field_size),
dtype=np.int32) * filled_value
x_id = np.concatenate([x, x_id.astype(dtype=np.int32)], axis=1)
mask = np.concatenate(
[y, np.zeros((batch_size, target_column-39), dtype=np.float32)], axis=1)
return (x_id, mask, z)
else:
def padding_func(x, y, z):
x = np.array(x).flatten().reshape(batch_size, field_size)
y = np.array(y).flatten().reshape(batch_size, field_size)
z = np.array(z).flatten().reshape(batch_size, 1)
return (x, y, z)
return padding_func
def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None,
manual_shape=None, target_column=40):
"""
get_tf_dataset
"""
dataset_files = []
file_prefix_name = 'train' if train_mode else 'test'
shuffle = train_mode
for (dirpath, _, filenames) in os.walk(data_dir):
for filename in filenames:
if file_prefix_name in filename and "tfrecord" in filename:
dataset_files.append(os.path.join(dirpath, filename))
schema = de.Schema()
schema.add_column('feat_ids', de_type=mstype.int32)
schema.add_column('feat_vals', de_type=mstype.float32)
schema.add_column('label', de_type=mstype.float32)
if rank_size is not None and rank_id is not None:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8,
num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True)
else:
ds = de.TFRecordDataset(dataset_files=dataset_files,
shuffle=shuffle, schema=schema, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample),
drop_remainder=True)
ds = ds.map(operations=_padding_func(batch_size, manual_shape, target_column),
input_columns=['feat_ids', 'feat_vals', 'label'],
column_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def _get_mindrecord_dataset(directory, train_mode=True, epochs=10, batch_size=16000,
line_per_sample=1000, rank_size=None, rank_id=None,
manual_shape=None, target_column=40):
"""
Get dataset with mindrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
file_suffix_name = '00' if train_mode else '0'
shuffle = train_mode
if rank_size is not None and rank_id is not None:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
num_parallel_workers=8)
else:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
shuffle=shuffle, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(_padding_func(batch_size, manual_shape, target_column, target_column-1),
input_columns=['feat_ids', 'feat_vals', 'label'],
column_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def _get_vocab_size(target_column_number, worker_size, total_vocab_size, multiply=False, per_vocab_size=None):
"""
get_vocab_size
"""
# Only 39
inidival_vocabs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 691, 540, 20855, 23639, 182, 15,
10091, 347, 4, 16366, 4494, 21293, 3103, 27, 6944, 22366, 11, 3267, 1610,
5, 21762, 14, 15, 15030, 61, 12220]
new_vocabs = inidival_vocabs + [1] * \
(target_column_number - len(inidival_vocabs))
part_size = int(target_column_number / worker_size)
# According to the workers, we merge some fields into the same part
new_vocab_size = []
for i in range(0, target_column_number, part_size):
new_vocab_size.append(sum(new_vocabs[i: i + part_size]))
index_offsets = [0]
    # The original feature counts are used to calculate the offsets
    features = [item for item in new_vocab_size]
    # If per_vocab_size is given, use it directly as the per-worker vocab size
if per_vocab_size is not None:
new_vocab_size = [per_vocab_size] * worker_size
else:
# Expands the vocabulary of each field by the multiplier
if multiply is True:
cur_sum = sum(new_vocab_size)
k = total_vocab_size/cur_sum
new_vocab_size = [
math.ceil(int(item*k)/worker_size)*worker_size for item in new_vocab_size]
new_vocab_size = [(item // 8 + 1)*8 for item in new_vocab_size]
else:
if total_vocab_size > sum(new_vocab_size):
new_vocab_size[-1] = total_vocab_size - \
sum(new_vocab_size[:-1])
new_vocab_size = [item for item in new_vocab_size]
else:
raise ValueError(
"Please providede the correct vocab size, now is {}".format(total_vocab_size))
for i in range(worker_size-1):
off = index_offsets[i] + features[i]
index_offsets.append(off)
print("the offset: ", index_offsets)
manual_shape = tuple(
((new_vocab_size[i], index_offsets[i]) for i in range(worker_size)))
vocab_total = sum(new_vocab_size)
return manual_shape, vocab_total
def compute_manual_shape(config, worker_size):
target_column = (config.field_size // worker_size + 1) * worker_size
config.field_size = target_column
manual_shape, vocab_total = _get_vocab_size(target_column, worker_size, total_vocab_size=config.vocab_size,
per_vocab_size=None, multiply=False)
config.manual_shape = manual_shape
config.vocab_size = int(vocab_total)
def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
data_type=DataType.TFRECORD, line_per_sample=1000,
rank_size=None, rank_id=None, manual_shape=None, target_column=40):
"""
create_dataset
"""
if data_type == DataType.TFRECORD:
return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
line_per_sample, rank_size=rank_size, rank_id=rank_id,
manual_shape=manual_shape, target_column=target_column)
if data_type == DataType.MINDRECORD:
return _get_mindrecord_dataset(data_dir, train_mode, epochs, batch_size,
line_per_sample, rank_size=rank_size, rank_id=rank_id,
manual_shape=manual_shape, target_column=target_column)
    if rank_size is not None and rank_size > 1:
raise RuntimeError("please use tfrecord dataset.")
return _get_h5_dataset(data_dir, train_mode, epochs, batch_size)


@ -0,0 +1,306 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""deep and cross net"""
import numpy as np
from mindspore import nn
from mindspore import ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam
from mindspore.common.initializer import Uniform, initializer
from mindspore import Tensor
from mindspore.ops import Squeeze
from mindspore.common.parameter import Parameter
ms_type = mstype.float32
ds_type = mstype.int32
def init_method(method, shape, name, max_val=1.0):
'''
parameter init method
'''
if method in ['uniform']:
params = Parameter(initializer(
Uniform(max_val), shape, ms_type), name=name)
elif method == "one":
params = Parameter(initializer("ones", shape, ms_type), name=name)
elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal":
params = Parameter(initializer("normal", shape, ms_type), name=name)
return params
def normal_weight(shape, num_units):
norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
return Tensor(norm)
class DenseLayer(nn.Cell):
"""
Dense Layer for DCN Model;
Containing: activation, matmul, bias_add;
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(
weight_init, [input_dim, output_dim], name="weight")
self.bias = init_method(bias_init, [output_dim], name="bias")
self.act_func = self._init_activation(act_str)
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
self.dropout = Dropout(keep_prob=keep_prob)
self.use_activation = use_activation
self.convert_dtype = convert_dtype
self.drop_out = drop_out
def _init_activation(self, act_str):
act_str = act_str.lower()
if act_str == "relu":
act_func = P.ReLU()
elif act_str == "sigmoid":
act_func = P.Sigmoid()
elif act_str == "tanh":
act_func = P.Tanh()
return act_func
def construct(self, x):
'''
Construct Dense layer
'''
if self.training and self.drop_out:
x = self.dropout(x)
if self.convert_dtype:
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
bias = self.cast(self.bias, mstype.float16)
wx = self.matmul(x, weight)
wx = self.bias_add(wx, bias)
if self.use_activation:
wx = self.act_func(wx)
wx = self.cast(wx, mstype.float32)
else:
wx = self.matmul(x, self.weight)
wx = self.bias_add(wx, self.bias)
if self.use_activation:
wx = self.act_func(wx)
return wx
class CrossLayer(nn.Cell):
"""
Cross Layer for DCN Model;
"""
def __init__(self, cross_raw_dim, cross_col_dim, weight_bias_init, convert_dtype=True):
super(CrossLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.cross_weight = init_method(weight_init, [cross_col_dim, 1], name="weight")
self.cross_bias = init_method(bias_init, [cross_col_dim, 1], name="bias")
self.matmul = nn.MatMul()
self.bias_add = P.BiasAdd()
self.tensor_add = P.TensorAdd()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.expand_dims = P.ExpandDims()
self.convert_dtype = convert_dtype
self.squeeze = Squeeze(2)
def construct(self, inputs, x_0):
'''
Construct Cross layer
'''
x_0 = self.expand_dims(x_0, 2)
x_l = self.expand_dims(inputs, 2)
x_lw = C.tensor_dot(x_l, self.cross_weight, ((1,), (0,)))
dot = self.matmul(x_0, x_lw)
y_l = dot + self.cross_bias + x_l
y_l = self.squeeze(y_l)
return y_l
class EmbeddingLookup(nn.Cell):
"""
    An embedding lookup table with a fixed dictionary and size.
Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
"""
def __init__(self,
vocab_size,
embedding_size,
use_one_hot_embeddings=False,
initializer_range=0.02):
super(EmbeddingLookup, self).__init__()
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size))
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.Gather()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = P.Shape()
def construct(self, input_ids):
"""Get a embeddings lookup table with a fixed dictionary and size."""
input_shape = self.shape(input_ids)
flat_ids = self.reshape(input_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
out_shape = input_shape + (self.embedding_size,)
output = self.reshape(output_for_reshape, out_shape)
return output, self.embedding_table
class DeepCrossModel(nn.Cell):
"""
Deep and Cross Model
"""
def __init__(self, config):
super(DeepCrossModel, self).__init__()
self.concat = P.Concat(axis=1)
self.reshape = P.Reshape()
self.deep_reshape = P.Reshape()
self.deep_mul = P.Mul()
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.batch_size = config.batch_size
self.field_size = config.field_size
self.input_size = self.field_size * self.emb_dim
self.deep_layer_dim = config.deep_layer_dim
self.deep_embeddinglookup = EmbeddingLookup(self.vocab_size, self.emb_dim)
self.cross_layer_num = config.cross_layer_num
self.keep_prob = config.keep_prob
self.cross_layer_1 = CrossLayer(self.batch_size,
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
self.cross_layer_2 = CrossLayer(self.batch_size,
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
self.cross_layer_3 = CrossLayer(self.batch_size,
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
self.cross_layer_4 = CrossLayer(self.batch_size,
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
self.cross_layer_5 = CrossLayer(self.batch_size,
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
self.cross_layer_6 = CrossLayer(self.batch_size,
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
self.dense_layer_1 = DenseLayer(self.input_size, self.deep_layer_dim[0],
weight_bias_init=['normal', 'normal'], act_str="relu",
keep_prob=self.keep_prob, convert_dtype=False, drop_out=False)
self.dense_layer_2 = DenseLayer(self.deep_layer_dim[0], self.deep_layer_dim[1],
weight_bias_init=['normal', 'normal'], act_str="relu",
keep_prob=self.keep_prob, convert_dtype=False, drop_out=False)
self.dense_layer_3 = DenseLayer(self.input_size + self.deep_layer_dim[1],
1, weight_bias_init=['normal', 'normal'], act_str="sigmoid",
keep_prob=self.keep_prob, use_activation=False,
convert_dtype=False, drop_out=False)
def construct(self, id_hldr, wt_hldr):
"""dcn construct"""
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
deep_id_embs, _ = self.deep_embeddinglookup(id_hldr)
vx = self.deep_mul(deep_id_embs, mask)
input_x = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
d_1 = self.dense_layer_1(input_x)
d_2 = self.dense_layer_2(d_1)
c_1 = self.cross_layer_1(input_x, input_x)
c_2 = self.cross_layer_2(c_1, input_x)
c_3 = self.cross_layer_3(c_2, input_x)
c_4 = self.cross_layer_4(c_3, input_x)
c_5 = self.cross_layer_5(c_4, input_x)
c_6 = self.cross_layer_6(c_5, input_x)
out = self.concat((d_2, c_6))
out = self.dense_layer_3(out)
return out
class NetWithLossClass(nn.Cell):
"""
NetWithLossClass definition.
"""
def __init__(self, network):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.loss = P.SigmoidCrossEntropyWithLogits()
self.network = network
self.ReduceMean = P.ReduceMean(keep_dims=False)
def construct(self, batch_ids, batch_wts, label):
predict = self.network(batch_ids, batch_wts)
log_loss = self.loss(predict, label)
mean_log_loss = self.ReduceMean(log_loss)
loss = mean_log_loss
return loss
class TrainStepWrap(nn.Cell):
"""
TrainStepWrap definition
"""
def __init__(self, network, lr=0.0001, eps=1e-8, loss_scale=1000.0):
super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = loss_scale
def construct(self, batch_ids, batch_wts, label):
weights = self.weights
loss = self.network(batch_ids, batch_wts, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
return F.depend(loss, self.optimizer(grads))
class PredictWithSigmoid(nn.Cell):
"""
Eval model with sigmoid.
"""
def __init__(self, network):
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
self.network = network
self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels):
logits = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels


@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Area under curve (AUC) metric
"""
from sklearn.metrics import roc_auc_score
from mindspore import context
from mindspore.nn.metrics import Metric
from mindspore.communication.management import get_rank, get_group_size
class AUCMetric(Metric):
"""
    Area under curve (AUC) metric
"""
def __init__(self):
super(AUCMetric, self).__init__()
self.clear()
self.full_batch = context.get_auto_parallel_context("full_batch")
def clear(self):
"""Clear the internal evaluation result."""
self.true_labels = []
self.pred_probs = []
def update(self, *inputs): # inputs
"""Update list of predicts and labels."""
all_predict = inputs[1].asnumpy().flatten().tolist() # predict
all_label = inputs[2].asnumpy().flatten().tolist() # label
self.pred_probs.extend(all_predict)
if self.full_batch:
rank_id = get_rank()
group_size = get_group_size()
gap = len(all_label) // group_size
self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap])
else:
self.true_labels.extend(all_label)
def eval(self):
if len(self.true_labels) != len(self.pred_probs):
raise RuntimeError(
'true_labels.size is not equal to pred_probs.size()')
auc = roc_auc_score(self.true_labels, self.pred_probs)
print("====" * 20 + " auc_metric end")
print("====" * 20 + " auc: {}".format(auc))
return auc


@ -0,0 +1,292 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Download raw data and preprocessed data."""
import os
import pickle
import collections
import argparse
import numpy as np
from mindspore.mindrecord import FileWriter
class StatsDict():
"""preprocessed data"""
def __init__(self, field_size, dense_dim, slot_dim, skip_id_convert):
self.field_size = field_size
self.dense_dim = dense_dim
self.slot_dim = slot_dim
self.skip_id_convert = bool(skip_id_convert)
self.val_cols = ["val_{}".format(i + 1) for i in range(self.dense_dim)]
self.cat_cols = ["cat_{}".format(i + 1) for i in range(self.slot_dim)]
self.val_min_dict = {col: 0 for col in self.val_cols}
self.val_max_dict = {col: 0 for col in self.val_cols}
self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}
self.oov_prefix = "OOV"
self.cat2id_dict = {}
self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
self.cat2id_dict.update(
{self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)})
def stats_vals(self, val_list):
"""Handling weights column"""
assert len(val_list) == len(self.val_cols)
def map_max_min(i, val):
key = self.val_cols[i]
if val != "":
if float(val) > self.val_max_dict[key]:
self.val_max_dict[key] = float(val)
if float(val) < self.val_min_dict[key]:
self.val_min_dict[key] = float(val)
for i, val in enumerate(val_list):
map_max_min(i, val)
def stats_cats(self, cat_list):
"""Handling cats column"""
assert len(cat_list) == len(self.cat_cols)
def map_cat_count(i, cat):
key = self.cat_cols[i]
self.cat_count_dict[key][cat] += 1
for i, cat in enumerate(cat_list):
map_cat_count(i, cat)
def save_dict(self, dict_path, prefix=""):
with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
pickle.dump(self.val_max_dict, file_wrt)
with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
pickle.dump(self.val_min_dict, file_wrt)
with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
pickle.dump(self.cat_count_dict, file_wrt)
def load_dict(self, dict_path, prefix=""):
with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
self.val_max_dict = pickle.load(file_wrt)
with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
self.val_min_dict = pickle.load(file_wrt)
with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
self.cat_count_dict = pickle.load(file_wrt)
print("val_max_dict.items()[:50]:{}".format(list(self.val_max_dict.items())))
print("val_min_dict.items()[:50]:{}".format(list(self.val_min_dict.items())))
def get_cat2id(self, threshold=100):
for key, cat_count_d in self.cat_count_dict.items():
new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
for cat_str, _ in new_cat_count_d.items():
self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
print("cat2id_dict.size:{}".format(len(self.cat2id_dict)))
print("cat2id.dict.items()[:50]:{}".format(list(self.cat2id_dict.items())[:50]))
def map_cat2id(self, values, cats):
"""Cat to id"""
def minmax_scale_value(i, val):
max_v = float(self.val_max_dict["val_{}".format(i + 1)])
return float(val) * 1.0 / max_v
id_list = []
weight_list = []
for i, val in enumerate(values):
if val == "":
id_list.append(i)
weight_list.append(0)
else:
key = "val_{}".format(i + 1)
id_list.append(self.cat2id_dict[key])
weight_list.append(minmax_scale_value(i, float(val)))
for i, cat_str in enumerate(cats):
key = "cat_{}".format(i + 1) + "_" + cat_str
if key in self.cat2id_dict:
if self.skip_id_convert is True:
                    # For synthetic data, the generated ids lie in [0, max_vocab], but if the number of examples
                    # is less than vocab_size / slot_nums, the ids would still be converted into [0, real_vocab],
                    # where real_vocab is the actual vocab size rather than max_vocab. A simple way to alleviate
                    # this is to skip the id conversion and keep the synthetic data id as the final id.
id_list.append(cat_str)
else:
id_list.append(self.cat2id_dict[key])
else:
id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
weight_list.append(1.0)
return id_list, weight_list
def mkdir_path(file_path):
if not os.path.exists(file_path):
os.makedirs(file_path)
def statsdata(file_path, dict_output_path, recommendation_dataset_stats_dict, dense_dim=13, slot_dim=26):
"""Preprocess data and save data"""
with open(file_path, encoding="utf-8") as file_in:
errorline_list = []
count = 0
for line in file_in:
count += 1
line = line.strip("\n")
items = line.split("\t")
if len(items) != (dense_dim + slot_dim + 1):
errorline_list.append(count)
print("Found line length: {}, suppose to be {}, the line is {}".format(len(items),
dense_dim + slot_dim + 1, line))
continue
if count % 1000000 == 0:
print("Have handled {}w lines.".format(count // 10000))
values = items[1: dense_dim + 1]
cats = items[dense_dim + 1:]
assert len(values) == dense_dim, "values.size: {}".format(len(values))
assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
recommendation_dataset_stats_dict.stats_vals(values)
recommendation_dataset_stats_dict.stats_cats(cats)
recommendation_dataset_stats_dict.save_dict(dict_output_path)
def random_split_trans2mindrecord(input_file_path, output_file_path, recommendation_dataset_stats_dict,
part_rows=2000000, line_per_sample=1000, train_line_count=None,
test_size=0.1, seed=2020, dense_dim=13, slot_dim=26):
"""Random split data and save mindrecord"""
if train_line_count is None:
raise ValueError("Please provide training file line count")
test_size = int(train_line_count * test_size)
all_indices = [i for i in range(train_line_count)]
np.random.seed(seed)
np.random.shuffle(all_indices)
print("all_indices.size:{}".format(len(all_indices)))
test_indices_set = set(all_indices[:test_size])
print("test_indices_set.size:{}".format(len(test_indices_set)))
print("-----------------------" * 10 + "\n" * 2)
train_data_list = []
test_data_list = []
ids_list = []
wts_list = []
label_list = []
writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21)
writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3)
schema = {"label": {"type": "float32", "shape": [-1]}, "feat_vals": {"type": "float32", "shape": [-1]},
"feat_ids": {"type": "int32", "shape": [-1]}}
writer_train.add_schema(schema, "CRITEO_TRAIN")
writer_test.add_schema(schema, "CRITEO_TEST")
with open(input_file_path, encoding="utf-8") as file_in:
items_error_size_lineCount = []
count = 0
train_part_number = 0
test_part_number = 0
for i, line in enumerate(file_in):
count += 1
if count % 1000000 == 0:
print("Have handle {}w lines.".format(count // 10000))
line = line.strip("\n")
items = line.split("\t")
if len(items) != (1 + dense_dim + slot_dim):
items_error_size_lineCount.append(i)
continue
label = float(items[0])
values = items[1:1 + dense_dim]
cats = items[1 + dense_dim:]
assert len(values) == dense_dim, "values.size: {}".format(len(values))
assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
ids, wts = recommendation_dataset_stats_dict.map_cat2id(values, cats)
ids_list.extend(ids)
wts_list.extend(wts)
label_list.append(label)
if count % line_per_sample == 0:
if i not in test_indices_set:
train_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32),
"feat_vals": np.array(wts_list, dtype=np.float32),
"label": np.array(label_list, dtype=np.float32)
})
else:
test_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32),
"feat_vals": np.array(wts_list, dtype=np.float32),
"label": np.array(label_list, dtype=np.float32)
})
if train_data_list and len(train_data_list) % part_rows == 0:
writer_train.write_raw_data(train_data_list)
train_data_list.clear()
train_part_number += 1
if test_data_list and len(test_data_list) % part_rows == 0:
writer_test.write_raw_data(test_data_list)
test_data_list.clear()
test_part_number += 1
ids_list.clear()
wts_list.clear()
label_list.clear()
if train_data_list:
writer_train.write_raw_data(train_data_list)
if test_data_list:
writer_test.write_raw_data(test_data_list)
writer_train.commit()
writer_test.commit()
print("-------------" * 10)
print("items_error_size_lineCount.size(): {}.".format(len(items_error_size_lineCount)))
print("-------------" * 10)
np.save("items_error_size_lineCount.npy", items_error_size_lineCount)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Recommendation dataset")
parser.add_argument("--data_path", type=str, default="./data/", help='The path of the data file')
parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continues fields')
parser.add_argument("--slot_dim", type=int, default=26,
                        help='The number of your sparse fields; they can also be called category features.')
parser.add_argument("--threshold", type=int, default=100,
help='Word frequency below this will be regarded as OOV. It aims to reduce the vocab size')
parser.add_argument("--train_line_count", type=int, help='The number of examples in your dataset')
parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1],
help='Skip the id convert, regarding the original id as the final id.')
args, _ = parser.parse_known_args()
data_path = args.data_path
target_field_size = args.dense_dim + args.slot_dim
stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
skip_id_convert=args.skip_id_convert)
data_file_path = data_path + "origin_data/train.txt"
stats_output_path = data_path + "stats_dict/"
mkdir_path(stats_output_path)
statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
stats.load_dict(dict_path=stats_output_path, prefix="")
stats.get_cat2id(threshold=args.threshold)
in_file_path = data_path + "origin_data/train.txt"
output_path = data_path + "mindrecord/"
mkdir_path(output_path)
random_split_trans2mindrecord(in_file_path, output_path, stats, part_rows=2000000,
train_line_count=args.train_line_count, line_per_sample=1000,
test_size=0.1, seed=2020, dense_dim=args.dense_dim, slot_dim=args.slot_dim)


@ -0,0 +1,97 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train dcn"""
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.communication.management import init
from src.deep_and_cross import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, DeepCrossModel
from src.callbacks import LossCallBack
from src.datasets import create_dataset, DataType
from src.config import DeepCrossConfig
def get_DCN_net(configure):
"""
Get network of deep&cross model.
"""
DCN_net = DeepCrossModel(configure)
loss_net = NetWithLossClass(DCN_net)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(DCN_net)
return train_net, eval_net
class ModelBuilder():
"""
Build the model.
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_net(self, configure):
return get_DCN_net(configure)
def test_train(configure):
"""
test_train
"""
if configure.is_distributed:
if configure.device_target == "Ascend":
context.set_context(device_id=configure.device_id, device_target=configure.device_target)
init()
elif configure.device_target == "GPU":
context.set_context(device_target=configure.device_target)
init("nccl")
context.set_context(mode=context.GRAPH_MODE)
else:
context.set_context(mode=context.GRAPH_MODE,
device_target=configure.device_target, device_id=configure.device_id)
data_path = configure.data_path
batch_size = configure.batch_size
field_size = configure.field_size
target_column = field_size + 1
print("field_size: {} ".format(field_size))
epochs = configure.epochs
if configure.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif configure.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
dataset_type = DataType.H5
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, data_type=dataset_type, target_column=target_column)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
net_builder = ModelBuilder()
train_net, _ = net_builder.get_net(configure)
train_net.set_train()
model = Model(train_net)
callback = LossCallBack(config=configure)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='deepcross_train', directory=configure.ckpt_path, config=ckptconfig)
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb])
if __name__ == "__main__":
config = DeepCrossConfig()
config.argparse_init()
test_train(config)