forked from mindspore-Ecosystem/mindspore
dcn
parent 9726051f58
commit bc9f96c5cb
@@ -0,0 +1,220 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [Deep&Cross Description](#deepcross-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->
# Deep&Cross Description

Deep & Cross Network (DCN) was proposed in 2017 by Google and Stanford for click-through rate (CTR) prediction in advertising. Compared with Wide & Deep, also from Google, DCN requires no manual feature engineering to obtain high-order cross features; compared with the FM family of models, DCN is more computationally efficient and can extract higher-order cross features.

[Paper](https://arxiv.org/pdf/1708.05123.pdf)
# Model Architecture

The DCN model starts with an embedding and stacking layer, followed by a Cross Network and a Deep Network running in parallel; a final combination layer merges the outputs of the two networks to produce the prediction.
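The core of the Cross Network is the per-layer update x_{l+1} = x_0 * (x_l · w_l) + b_l + x_l. As an illustration only (a minimal NumPy sketch, not code from this repository), one cross-layer step corresponding to the `CrossLayer` cell used by this model could look like:

```python
import numpy as np

def cross_layer_step(x0, xl, w, b):
    """One cross step: x_{l+1} = x0 * (xl . w) + b + xl, applied per sample."""
    xlw = xl @ w                       # (batch,), one scalar interaction weight per sample
    return x0 * xlw[:, None] + b + xl  # explicit crosses of x0 with xl, plus a residual term

rng = np.random.default_rng(0)
batch, dim = 4, 6
x0 = rng.normal(size=(batch, dim)).astype(np.float32)
w = rng.normal(size=dim).astype(np.float32)   # learnable weight of this cross layer
b = np.zeros(dim, dtype=np.float32)           # learnable bias of this cross layer
print(cross_layer_step(x0, x0, w, b).shape)   # (4, 6): same shape as the input
```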
# Dataset

Dataset used: the dataset used in [1] Guo H, Tang R, Ye Y, et al. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction[J]. 2017.
# Environment Requirements

- Hardware (GPU)
    - Prepare the hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)

# Quick Start

After installing MindSpore via the official website, you can follow the steps below for training and evaluation:
1. Clone the code.

    ```bash
    git clone https://gitee.com/mindspore/mindspore.git
    cd mindspore/model_zoo/official/recommend/deep_and_cross
    ```

2. Download the dataset.

    > Please refer to [1] to obtain the download link.

    ```bash
    mkdir -p data/origin_data && cd data/origin_data
    wget DATA_LINK
    tar -zxvf dac.tar.gz
    ```
3. Preprocess the data with the following script. Processing may take about an hour, and the generated MindRecord files are stored under data/mindrecord.

    ```bash
    python src/preprocess_data.py --data_path=./data/ --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
    ```
4. Start training.

    Once the dataset is ready, you can train and evaluate the model on GPU.

    Single-GPU training command:

    ```bash
    # single-GPU training example
    python train.py --device_target="GPU" > output.train.log 2>&1 &
    # or
    sh scripts/run_train_gpu.sh
    ```
    8-GPU training command:

    ```bash
    # 8-GPU training example
    sh scripts/run_train_multi_gpu.sh
    ```
5. Start evaluation.

    After training, evaluate the model as follows.

    ```bash
    python eval.py --ckpt_path=CHECKPOINT_PATH
    # or
    sh scripts/run_eval.sh CHECKPOINT_PATH
    ```
# Script Description

## Script and Sample Code

```bash
├── model_zoo
    ├── README.md                      // descriptions of all the models
    ├── deep_and_cross
        ├── README.md                  // Deep and Cross description
        ├── scripts
        │   ├── run_train_gpu.sh       // shell script for single-GPU training
        │   ├── run_train_multi_gpu.sh // shell script for 8-GPU training
        │   ├── run_eval.sh            // shell script for evaluation
        ├── src
        │   ├── datasets.py            // dataset creation
        │   ├── deep_and_cross.py      // Deep and Cross architecture
        │   ├── callbacks.py           // callback definitions
        │   ├── config.py              // parameter configuration
        │   ├── metrics.py             // AUC metric definition
        │   ├── preprocess_data.py     // data preprocessing, generates MindRecord files
        ├── train.py                   // training script
        ├── eval.py                    // evaluation script
```
## Script Parameters

Both training and evaluation parameters can be configured in config.py.

- Configuration for the Deep and Cross model on the Criteo dataset:

    ```python
    self.device_target = "GPU"          # target device
    self.device_id = 0                  # ID of the device used for training or evaluation
    self.epochs = 10                    # number of training epochs
    self.batch_size = 16000             # batch size
    self.deep_layer_dim = [1024, 1024]  # sizes of the deep layers
    self.cross_layer_num = 6            # number of cross layers
    self.eval_file_name = "eval.log"    # evaluation output file
    self.loss_file_name = "loss.log"    # loss output file
    self.ckpt_path = "./checkpoints/"   # checkpoint output directory
    self.dataset_type = "mindrecord"    # dataset format
    self.is_distributed = 0             # whether to run distributed training
    ```

For more configuration details, see the script `config.py`.
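As a hedged illustration (this snippet is not part of the repository), the values above are the defaults set in `DeepCrossConfig.__init__`, and `argparse_init()` lets command-line flags such as `--epochs` or `--batch_size` override them:

```python
from src.config import DeepCrossConfig

config = DeepCrossConfig()   # defaults: epochs=10, batch_size=16000, cross_layer_num=6, ...
config.argparse_init()       # e.g. passing --epochs=5 --batch_size=8000 overrides the defaults
print(config.epochs, config.batch_size, config.cross_layer_num)
```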
## Training Process

### Training

- Running on GPU

    ```bash
    sh scripts/run_train_gpu.sh
    ```

    The above bash command runs in the background; you can check the results in the output.train.log file.

    After training, you can find the checkpoint files under the default `./checkpoints/` folder.

### Distributed Training

- Running on GPU

    ```bash
    sh scripts/run_train_multi_gpu.sh
    ```

    The above shell script runs distributed training in the background. You can check the results in the output.multi_gpu.train.log file.
## Evaluation Process

### Evaluation

- Evaluating the Criteo dataset on GPU

    Check the checkpoint path used for evaluation before running the commands below. Set the checkpoint path to a relative path, for example "./checkpoints/deep_and_cross-10_2582.ckpt".

    ```bash
    python eval.py --ckpt_path=[CHECKPOINT_PATH] > eval.log 2>&1 &
    ```

    The above python command runs in the background; you can check the results in the eval.log file.

    Alternatively,

    ```bash
    sh scripts/run_eval.sh [CHECKPOINT_PATH]
    ```

    The above shell script runs in the background; you can check the results in the output.eval.log file.
# Model Description

## Performance

### Evaluation Performance

#### Criteo Dataset

| Parameters                 | Single GPU                                        |
| -------------------------- | ------------------------------------------------- |
| Resource                   | NV Tesla V100-32G                                 |
| Uploaded Date              | 2021-06-30                                        |
| MindSpore Version          | 1.2.0                                             |
| Dataset                    | Criteo                                            |
| Training Parameters        | epoch=10, steps=2582, batch_size=16000, lr=0.0001 |
| Optimizer                  | Adam                                              |
| Loss Function              | Sigmoid cross entropy                             |
| Outputs                    | Probability                                       |
| Loss                       | 0.4388                                            |
| Speed                      | 107-110 ms/step                                   |
| Total Time                 | about 2800 s                                      |
| Checkpoint for Fine-tuning | 75M (.ckpt file)                                  |
| Inference AUC              | 0.803786                                          |

# ModelZoo Homepage

Please check out the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.deep_and_cross import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, DeepCrossModel
|
||||
from src.callbacks import EvalCallBack
|
||||
from src.datasets import create_dataset, DataType
|
||||
from src.metrics import AUCMetric
|
||||
from src.config import DeepCrossConfig
|
||||
|
||||
|
||||
def get_DCN_net(configure):
|
||||
"""
|
||||
Get network of deep&cross model.
|
||||
"""
|
||||
DCN_net = DeepCrossModel(configure)
|
||||
|
||||
loss_net = NetWithLossClass(DCN_net)
|
||||
train_net = TrainStepWrap(loss_net)
|
||||
eval_net = PredictWithSigmoid(DCN_net)
|
||||
|
||||
return train_net, eval_net
|
||||
|
||||
class ModelBuilder():
|
||||
"""
|
||||
Build the model.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_hook(self):
|
||||
pass
|
||||
|
||||
def get_net(self, configure):
|
||||
return get_DCN_net(configure)
|
||||
|
||||
def test_eval(configure):
|
||||
"""
|
||||
test_eval
|
||||
"""
|
||||
data_path = configure.data_path
|
||||
batch_size = configure.batch_size
|
||||
field_size = configure.field_size
|
||||
if configure.dataset_type == "tfrecord":
|
||||
dataset_type = DataType.TFRECORD
|
||||
elif configure.dataset_type == "mindrecord":
|
||||
dataset_type = DataType.MINDRECORD
|
||||
else:
|
||||
dataset_type = DataType.H5
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||
batch_size=batch_size, data_type=dataset_type, target_column=field_size+1)
|
||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||
net_builder = ModelBuilder()
|
||||
train_net, eval_net = net_builder.get_net(configure)
|
||||
ckpt_path = configure.ckpt_path
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(eval_net, param_dict)
|
||||
auc_metric = AUCMetric()
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric, configure)
|
||||
model.eval(ds_eval, callbacks=eval_callback)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = DeepCrossConfig()
|
||||
config.argparse_init()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
test_eval(config)
@@ -0,0 +1,32 @@
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_eval.sh ckpt_filename "
|
||||
echo "for example: bash run_eval.sh 0_8-24_1012.ckpt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
|
||||
CKPT=$1
|
||||
|
||||
loc=$(pwd)
|
||||
|
||||
ckpt_path=$loc"/"$CKPT
|
||||
echo "ckpt_path: "$ckpt_path
|
||||
python eval.py --ckpt_path=$ckpt_path > output.eval.log 2>&1 &
|
||||
@@ -0,0 +1,26 @@
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_train_gpu.sh"
|
||||
echo "for example: bash src/run_train_gpu.sh"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
|
||||
python train.py \
|
||||
--device_target="GPU" > output.train.log 2>&1 &
@@ -0,0 +1,27 @@
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_train_multi_gpu.sh"
|
||||
echo "for example: bash run_train_multi_gpu.sh"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
|
||||
|
||||
mpirun --allow-run-as-root -n 8 python train.py \
|
||||
--device_target="GPU" --is_distributed=1 > output.multi_gpu.train.log 2>&1 &
@@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""callbacks"""
|
||||
|
||||
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_rank
|
||||
|
||||
def add_write(file_path, out_str):
|
||||
"""
|
||||
add lines to the file
|
||||
"""
|
||||
with open(file_path, 'a+', encoding="utf-8") as file_out:
|
||||
file_out.write(out_str + "\n")
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF, terminate the training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0, do NOT print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print the loss every per_print_times steps. Default: 1.
|
||||
"""
|
||||
def __init__(self, config=None, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("per_print_times must be in and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.config = config
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""Monitor the loss in training."""
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs.asnumpy()
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
cur_num = cb_params.cur_step_num
|
||||
rank_id = 0
|
||||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
|
||||
ParallelMode.DATA_PARALLEL):
|
||||
rank_id = get_rank()
|
||||
|
||||
print("===loss===", rank_id, cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||
loss, flush=True)
|
||||
# raise ValueError
|
||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
|
||||
loss_file = open(self.config.loss_file_name, "a+")
|
||||
loss_file.write("epoch: %s, step: %s, loss: %s" %
|
||||
(cb_params.cur_epoch_num, cur_step_in_epoch, loss))
|
||||
loss_file.write("\n")
|
||||
loss_file.close()
|
||||
print("epoch: %s, step: %s, loss: %s" %
|
||||
(cb_params.cur_epoch_num, cur_step_in_epoch, loss))
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in evaluating.
|
||||
|
||||
If the loss is NAN or INF, terminate evaluating.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0, do NOT print loss.
|
||||
|
||||
Args:
|
||||
print_per_step (int): Print the loss every print_per_step steps. Default: 1.
|
||||
"""
|
||||
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False):
|
||||
super(EvalCallBack, self).__init__()
|
||||
if not isinstance(print_per_step, int) or print_per_step < 0:
|
||||
raise ValueError("print_per_step must be int and >= 0.")
|
||||
self.print_per_step = print_per_step
|
||||
self.model = model
|
||||
self.eval_dataset = eval_dataset
|
||||
self.aucMetric = auc_metric
|
||||
self.aucMetric.clear()
|
||||
self.eval_file_name = config.eval_file_name
|
||||
self.eval_values = []
|
||||
self.host_device_mix = host_device_mix
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""
|
||||
epoch end
|
||||
"""
|
||||
self.aucMetric.clear()
|
||||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
context.set_auto_parallel_context(strategy_ckpt_save_file="",
|
||||
strategy_ckpt_load_file="./strategy_train.ckpt")
|
||||
rank_id = 0
|
||||
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
|
||||
ParallelMode.DATA_PARALLEL):
|
||||
rank_id = get_rank()
|
||||
start_time = time.time()
|
||||
out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix))
|
||||
end_time = time.time()
|
||||
eval_time = int(end_time - start_time)
|
||||
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
out_str = "{} == Rank: {} == EvalCallBack model.eval(): {}; eval_time: {}s".\
|
||||
format(time_str, rank_id, out.values(), eval_time)
|
||||
print(out_str)
|
||||
self.eval_values = out.values()
|
||||
add_write(self.eval_file_name, out_str)
@@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""deep and cross config"""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def argparse_init():
|
||||
"""
|
||||
argparse_init
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='DeepCross')
|
||||
parser.add_argument("--device_target", type=str, default="GPU", choices=["Ascend", "GPU"],
|
||||
help="device where the code will be implemented. (Default: GPU)")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--data_path", type=str, default="./data/mindrecord",
|
||||
help="This should be set to the same directory given to the data_download's data_dir argument")
|
||||
parser.add_argument("--epochs", type=int, default=10, help="Total train epochs")
|
||||
parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ")
|
||||
parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.")
|
||||
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
|
||||
parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
|
||||
parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.")
|
||||
parser.add_argument("--emb_dim", type=int, default=30, help="The dense embedding dimension of sparse feature.")
|
||||
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 1024],
|
||||
help="The dimension of all deep layers.")
|
||||
parser.add_argument("--cross_layer_num", type=int, default=6, help="Cross layer num")
|
||||
parser.add_argument("--deep_layer_act", type=str, default='relu',
|
||||
help="The activation function of all deep layers.")
|
||||
parser.add_argument("--keep_prob", type=float, default=1.0, help="The keep rate in dropout layer.")
|
||||
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
|
||||
parser.add_argument("--output_path", type=str, default="./output/")
|
||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
|
||||
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
|
||||
parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
|
||||
parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
|
||||
parser.add_argument("--dataset_type", type=str, default="mindrecord", help="tfrecord/mindrecord/hd5")
|
||||
parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not")
|
||||
parser.add_argument("--is_distributed", type=int, default=0, help="0/1")
|
||||
return parser
|
||||
|
||||
|
||||
class DeepCrossConfig():
|
||||
"""
|
||||
DeepCrossConfig
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.device_target = "GPU"
|
||||
self.device_id = 0
|
||||
self.data_path = "./test_raw_data/"
|
||||
self.full_batch = False
|
||||
self.epochs = 10
|
||||
self.batch_size = 16000
|
||||
self.eval_batch_size = 16000
|
||||
self.field_size = 39
|
||||
self.vocab_size = 200000
|
||||
self.emb_dim = 27
|
||||
self.deep_layer_dim = [1024, 1024]
|
||||
self.cross_layer_num = 6
|
||||
self.deep_layer_act = 'relu'
|
||||
self.weight_bias_init = ['normal', 'normal']
|
||||
self.emb_init = 'normal'
|
||||
self.init_args = [-0.01, 0.01]
|
||||
self.dropout_flag = False
|
||||
self.keep_prob = 1.0
|
||||
|
||||
self.output_path = "./output"
|
||||
self.eval_file_name = "eval.log"
|
||||
self.loss_file_name = "loss.log"
|
||||
self.ckpt_path = "./checkpoints/"
|
||||
self.host_device_mix = 0
|
||||
self.dataset_type = "mindrecord"
|
||||
self.parameter_server = 0
|
||||
self.is_distributed = 0
|
||||
|
||||
|
||||
def argparse_init(self):
|
||||
"""
|
||||
argparse_init
|
||||
"""
|
||||
parser = argparse_init()
|
||||
args, _ = parser.parse_known_args()
|
||||
self.device_target = args.device_target
|
||||
self.device_id = args.device_id
|
||||
self.data_path = args.data_path
|
||||
self.epochs = args.epochs
|
||||
self.full_batch = bool(args.full_batch)
|
||||
self.batch_size = args.batch_size
|
||||
self.eval_batch_size = args.eval_batch_size
|
||||
self.field_size = args.field_size
|
||||
self.vocab_size = args.vocab_size
|
||||
self.emb_dim = args.emb_dim
|
||||
self.deep_layer_dim = args.deep_layer_dim
|
||||
self.deep_layer_act = args.deep_layer_act
|
||||
self.keep_prob = args.keep_prob
|
||||
self.weight_bias_init = ['normal', 'normal']
|
||||
self.emb_init = 'normal'
|
||||
self.init_args = [-0.01, 0.01]
|
||||
self.dropout_flag = bool(args.dropout_flag)
|
||||
self.cross_layer_num = args.cross_layer_num
|
||||
|
||||
self.output_path = args.output_path
|
||||
self.eval_file_name = args.eval_file_name
|
||||
self.loss_file_name = args.loss_file_name
|
||||
self.ckpt_path = args.ckpt_path
|
||||
self.host_device_mix = args.host_device_mix
|
||||
self.dataset_type = args.dataset_type
|
||||
self.parameter_server = args.parameter_server
|
||||
self.is_distributed = args.is_distributed
@@ -0,0 +1,358 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_dataset."""
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
"""
|
||||
Enumerate supported dataset format.
|
||||
"""
|
||||
MINDRECORD = 1
|
||||
TFRECORD = 2
|
||||
H5 = 3
|
||||
|
||||
|
||||
class H5Dataset():
|
||||
"""
|
||||
H5DataSet
|
||||
"""
|
||||
input_length = 39
|
||||
|
||||
def __init__(self, data_path, train_mode=True, train_num_of_parts=21,
|
||||
test_num_of_parts=3):
|
||||
self._hdf_data_dir = data_path
|
||||
self._is_training = train_mode
|
||||
|
||||
if self._is_training:
|
||||
self._file_prefix = 'train'
|
||||
self._num_of_parts = train_num_of_parts
|
||||
else:
|
||||
self._file_prefix = 'test'
|
||||
self._num_of_parts = test_num_of_parts
|
||||
|
||||
self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix,
|
||||
self._num_of_parts)
|
||||
print("data_size: {}".format(self.data_size))
|
||||
|
||||
def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts):
|
||||
size = 0
|
||||
for part in range(num_of_parts):
|
||||
_y = pd.read_hdf(os.path.join(hdf_data_dir,
|
||||
file_prefix + '_output_part_' + str(
|
||||
part) + '.h5'))
|
||||
size += _y.shape[0]
|
||||
return size
|
||||
|
||||
def _iterate_hdf_files_(self, num_of_parts=None,
|
||||
shuffle_block=False):
|
||||
"""
|
||||
iterate among hdf files (blocks). When the whole dataset is finished, the iterator restarts
from the beginning, so the data stream never stops
|
||||
:param train_mode: True or False; when True this file iterator goes through the train set,
otherwise through the eval set
|
||||
:param num_of_parts: number of files
|
||||
:param shuffle_block: shuffle block files at every round
|
||||
:return: input_hdf_file_name, output_hdf_file_name, finish_flag
|
||||
"""
|
||||
parts = np.arange(num_of_parts)
|
||||
while True:
|
||||
if shuffle_block:
|
||||
for _ in range(int(shuffle_block)):
|
||||
np.random.shuffle(parts)
|
||||
for i, p in enumerate(parts):
|
||||
yield os.path.join(self._hdf_data_dir,
|
||||
self._file_prefix + '_input_part_' + str(
|
||||
p) + '.h5'), \
|
||||
os.path.join(self._hdf_data_dir,
|
||||
self._file_prefix + '_output_part_' + str(
|
||||
p) + '.h5'), i + 1 == len(parts)
|
||||
|
||||
def _generator(self, X, y, batch_size, shuffle=True):
|
||||
"""
|
||||
should be accessed only in private
|
||||
:param X:
|
||||
:param y:
|
||||
:param batch_size:
|
||||
:param shuffle:
|
||||
:return:
|
||||
"""
|
||||
number_of_batches = np.ceil(1. * X.shape[0] / batch_size)
|
||||
counter = 0
|
||||
finished = False
|
||||
sample_index = np.arange(X.shape[0])
|
||||
if shuffle:
|
||||
for _ in range(int(shuffle)):
|
||||
np.random.shuffle(sample_index)
|
||||
assert X.shape[0] > 0
|
||||
while True:
|
||||
batch_index = sample_index[
|
||||
batch_size * counter: batch_size * (counter + 1)]
|
||||
X_batch = X[batch_index]
|
||||
y_batch = y[batch_index]
|
||||
counter += 1
|
||||
yield X_batch, y_batch, finished
|
||||
if counter == number_of_batches:
|
||||
counter = 0
|
||||
finished = True
|
||||
|
||||
def batch_generator(self, batch_size=1000,
|
||||
random_sample=False, shuffle_block=False):
|
||||
"""
|
||||
:param train_mode: True or False; False is eval mode,
|
||||
:param batch_size
|
||||
:param num_of_parts: number of files
|
||||
:param random_sample: if True, will shuffle
|
||||
:param shuffle_block: shuffle file blocks at every round
|
||||
:return:
|
||||
"""
|
||||
|
||||
for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts,
|
||||
shuffle_block):
|
||||
start = stop = None
|
||||
X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values
|
||||
y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values
|
||||
data_gen = self._generator(X_all, y_all, batch_size,
|
||||
shuffle=random_sample)
|
||||
finished = False
|
||||
|
||||
while not finished:
|
||||
X, y, finished = data_gen.__next__()
|
||||
X_id = X[:, 0:self.input_length]
|
||||
X_va = X[:, self.input_length:]
|
||||
yield np.array(X_id.astype(dtype=np.int32)), np.array(
|
||||
X_va.astype(dtype=np.float32)), np.array(
|
||||
y.astype(dtype=np.float32))
|
||||
|
||||
|
||||
def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
|
||||
"""
|
||||
get_h5_dataset
|
||||
"""
|
||||
data_para = {
|
||||
'batch_size': batch_size,
|
||||
}
|
||||
if train_mode:
|
||||
data_para['random_sample'] = True
|
||||
data_para['shuffle_block'] = True
|
||||
|
||||
h5_dataset = H5Dataset(data_path=data_dir, train_mode=train_mode)
|
||||
numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)
|
||||
|
||||
def _iter_h5_data():
|
||||
train_eval_gen = h5_dataset.batch_generator(**data_para)
|
||||
for _ in range(0, numbers_of_batch, 1):
|
||||
yield train_eval_gen.__next__()
|
||||
|
||||
ds = de.GeneratorDataset(_iter_h5_data(), ["ids", "weights", "labels"])
|
||||
ds = ds.repeat(epochs)
|
||||
return ds
|
||||
|
||||
|
||||
def _padding_func(batch_size, manual_shape, target_column, field_size=39):
|
||||
"""
|
||||
get padding_func
|
||||
"""
|
||||
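# With manual_shape, pad each sample's ids from field_size up to target_column using the last id of each
# embedding shard as the filler value, and pad the corresponding weights with zeros.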
if manual_shape:
|
||||
generate_concat_offset = [item[0]+item[1] for item in manual_shape]
|
||||
part_size = int(target_column / len(generate_concat_offset))
|
||||
filled_value = []
|
||||
for i in range(field_size, target_column):
|
||||
filled_value.append(generate_concat_offset[i//part_size]-1)
|
||||
print("Filed Value:", filled_value)
|
||||
|
||||
def padding_func(x, y, z):
|
||||
x = np.array(x).flatten().reshape(batch_size, field_size)
|
||||
y = np.array(y).flatten().reshape(batch_size, field_size)
|
||||
z = np.array(z).flatten().reshape(batch_size, 1)
|
||||
|
||||
x_id = np.ones((batch_size, target_column - field_size),
|
||||
dtype=np.int32) * filled_value
|
||||
x_id = np.concatenate([x, x_id.astype(dtype=np.int32)], axis=1)
|
||||
mask = np.concatenate(
|
||||
[y, np.zeros((batch_size, target_column-39), dtype=np.float32)], axis=1)
|
||||
return (x_id, mask, z)
|
||||
else:
|
||||
def padding_func(x, y, z):
|
||||
x = np.array(x).flatten().reshape(batch_size, field_size)
|
||||
y = np.array(y).flatten().reshape(batch_size, field_size)
|
||||
z = np.array(z).flatten().reshape(batch_size, 1)
|
||||
return (x, y, z)
|
||||
return padding_func
|
||||
|
||||
|
||||
def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
|
||||
line_per_sample=1000, rank_size=None, rank_id=None,
|
||||
manual_shape=None, target_column=40):
|
||||
"""
|
||||
get_tf_dataset
|
||||
"""
|
||||
dataset_files = []
|
||||
file_prefix_name = 'train' if train_mode else 'test'
|
||||
shuffle = train_mode
|
||||
for (dirpath, _, filenames) in os.walk(data_dir):
|
||||
for filename in filenames:
|
||||
if file_prefix_name in filename and "tfrecord" in filename:
|
||||
dataset_files.append(os.path.join(dirpath, filename))
|
||||
schema = de.Schema()
|
||||
schema.add_column('feat_ids', de_type=mstype.int32)
|
||||
schema.add_column('feat_vals', de_type=mstype.float32)
|
||||
schema.add_column('label', de_type=mstype.float32)
|
||||
if rank_size is not None and rank_id is not None:
|
||||
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8,
|
||||
num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True)
|
||||
else:
|
||||
ds = de.TFRecordDataset(dataset_files=dataset_files,
|
||||
shuffle=shuffle, schema=schema, num_parallel_workers=8)
|
||||
ds = ds.batch(int(batch_size / line_per_sample),
|
||||
drop_remainder=True)
|
||||
|
||||
ds = ds.map(operations=_padding_func(batch_size, manual_shape, target_column),
|
||||
input_columns=['feat_ids', 'feat_vals', 'label'],
|
||||
column_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
|
||||
ds = ds.repeat(epochs)
|
||||
return ds
|
||||
|
||||
|
||||
def _get_mindrecord_dataset(directory, train_mode=True, epochs=10, batch_size=16000,
|
||||
line_per_sample=1000, rank_size=None, rank_id=None,
|
||||
manual_shape=None, target_column=40):
|
||||
"""
|
||||
Get dataset with mindrecord format.
|
||||
|
||||
Args:
|
||||
directory (str): Dataset directory.
|
||||
train_mode (bool): Whether dataset is use for train or eval (default=True).
|
||||
epochs (int): Dataset epoch size (default=1).
|
||||
batch_size (int): Dataset batch size (default=1000).
|
||||
line_per_sample (int): The number of sample per line (default=1000).
|
||||
rank_size (int): The number of device, not necessary for single device (default=None).
|
||||
rank_id (int): Id of device, not necessary for single device (default=None).
|
||||
|
||||
Returns:
|
||||
Dataset.
|
||||
"""
|
||||
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
|
||||
file_suffix_name = '00' if train_mode else '0'
|
||||
shuffle = train_mode
|
||||
|
||||
if rank_size is not None and rank_id is not None:
|
||||
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
|
||||
columns_list=['feat_ids', 'feat_vals', 'label'],
|
||||
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
|
||||
num_parallel_workers=8)
|
||||
else:
|
||||
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
|
||||
columns_list=['feat_ids', 'feat_vals', 'label'],
|
||||
shuffle=shuffle, num_parallel_workers=8)
|
||||
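# Each MindRecord row packs line_per_sample original samples, so batching batch_size / line_per_sample rows
# yields an effective batch of batch_size samples after the reshape in _padding_func.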
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
|
||||
ds = ds.map(_padding_func(batch_size, manual_shape, target_column, target_column-1),
|
||||
input_columns=['feat_ids', 'feat_vals', 'label'],
|
||||
column_order=['feat_ids', 'feat_vals', 'label'],
|
||||
num_parallel_workers=8)
|
||||
ds = ds.repeat(epochs)
|
||||
return ds
|
||||
|
||||
|
||||
def _get_vocab_size(target_column_number, worker_size, total_vocab_size, multiply=False, per_vocab_size=None):
|
||||
"""
|
||||
get_vocab_size
|
||||
"""
|
||||
# Vocabulary size of each of the 39 fields (13 dense + 26 categorical)
|
||||
inidival_vocabs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 691, 540, 20855, 23639, 182, 15,
|
||||
10091, 347, 4, 16366, 4494, 21293, 3103, 27, 6944, 22366, 11, 3267, 1610,
|
||||
5, 21762, 14, 15, 15030, 61, 12220]
|
||||
|
||||
new_vocabs = inidival_vocabs + [1] * \
|
||||
(target_column_number - len(inidival_vocabs))
|
||||
part_size = int(target_column_number / worker_size)
|
||||
|
||||
# According to the workers, we merge some fields into the same part
|
||||
new_vocab_size = []
|
||||
for i in range(0, target_column_number, part_size):
|
||||
new_vocab_size.append(sum(new_vocabs[i: i + part_size]))
|
||||
|
||||
index_offsets = [0]
|
||||
|
||||
# The gold feature numbers are used to calculate the offset
|
||||
features = [item for item in new_vocab_size]
|
||||
|
||||
# If per_vocab_size is given, use it as the vocab size of every part
|
||||
if per_vocab_size is not None:
|
||||
new_vocab_size = [per_vocab_size] * worker_size
|
||||
else:
|
||||
# Expands the vocabulary of each field by the multiplier
|
||||
if multiply is True:
|
||||
cur_sum = sum(new_vocab_size)
|
||||
k = total_vocab_size/cur_sum
|
||||
new_vocab_size = [
|
||||
math.ceil(int(item*k)/worker_size)*worker_size for item in new_vocab_size]
|
||||
new_vocab_size = [(item // 8 + 1)*8 for item in new_vocab_size]
|
||||
|
||||
else:
|
||||
if total_vocab_size > sum(new_vocab_size):
|
||||
new_vocab_size[-1] = total_vocab_size - \
|
||||
sum(new_vocab_size[:-1])
|
||||
new_vocab_size = [item for item in new_vocab_size]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Please providede the correct vocab size, now is {}".format(total_vocab_size))
|
||||
|
||||
for i in range(worker_size-1):
|
||||
off = index_offsets[i] + features[i]
|
||||
index_offsets.append(off)
|
||||
|
||||
print("the offset: ", index_offsets)
|
||||
manual_shape = tuple(
|
||||
((new_vocab_size[i], index_offsets[i]) for i in range(worker_size)))
|
||||
vocab_total = sum(new_vocab_size)
|
||||
return manual_shape, vocab_total
|
||||
|
||||
|
||||
def compute_manual_shape(config, worker_size):
|
||||
target_column = (config.field_size // worker_size + 1) * worker_size
|
||||
config.field_size = target_column
|
||||
manual_shape, vocab_total = _get_vocab_size(target_column, worker_size, total_vocab_size=config.vocab_size,
|
||||
per_vocab_size=None, multiply=False)
|
||||
config.manual_shape = manual_shape
|
||||
config.vocab_size = int(vocab_total)
|
||||
|
||||
|
||||
def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
|
||||
data_type=DataType.TFRECORD, line_per_sample=1000,
|
||||
rank_size=None, rank_id=None, manual_shape=None, target_column=40):
|
||||
"""
|
||||
create_dataset
|
||||
"""
|
||||
if data_type == DataType.TFRECORD:
|
||||
return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
|
||||
line_per_sample, rank_size=rank_size, rank_id=rank_id,
|
||||
manual_shape=manual_shape, target_column=target_column)
|
||||
if data_type == DataType.MINDRECORD:
|
||||
return _get_mindrecord_dataset(data_dir, train_mode, epochs, batch_size,
|
||||
line_per_sample, rank_size=rank_size, rank_id=rank_id,
|
||||
manual_shape=manual_shape, target_column=target_column)
|
||||
|
||||
if rank_size > 1:
|
||||
raise RuntimeError("please use tfrecord dataset.")
|
||||
return _get_h5_dataset(data_dir, train_mode, epochs, batch_size)
@@ -0,0 +1,306 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""deep and cross net"""
|
||||
|
||||
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore import ParameterTuple
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Dropout
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.common.initializer import Uniform, initializer
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import Squeeze
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
ms_type = mstype.float32
|
||||
ds_type = mstype.int32
|
||||
|
||||
|
||||
def init_method(method, shape, name, max_val=1.0):
|
||||
'''
|
||||
parameter init method
|
||||
'''
|
||||
if method in ['uniform']:
|
||||
params = Parameter(initializer(
|
||||
Uniform(max_val), shape, ms_type), name=name)
|
||||
elif method == "one":
|
||||
params = Parameter(initializer("ones", shape, ms_type), name=name)
|
||||
elif method == 'zero':
|
||||
params = Parameter(initializer("zeros", shape, ms_type), name=name)
|
||||
elif method == "normal":
|
||||
params = Parameter(initializer("normal", shape, ms_type), name=name)
|
||||
return params
|
||||
|
||||
|
||||
def normal_weight(shape, num_units):
|
||||
norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
|
||||
return Tensor(norm)
|
||||
|
||||
|
||||
|
||||
class DenseLayer(nn.Cell):
|
||||
"""
|
||||
Dense Layer for DCN Model;
|
||||
Containing: activation, matmul, bias_add;
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
|
||||
keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False):
|
||||
super(DenseLayer, self).__init__()
|
||||
weight_init, bias_init = weight_bias_init
|
||||
self.weight = init_method(
|
||||
weight_init, [input_dim, output_dim], name="weight")
|
||||
self.bias = init_method(bias_init, [output_dim], name="bias")
|
||||
self.act_func = self._init_activation(act_str)
|
||||
self.matmul = P.MatMul(transpose_b=False)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.cast = P.Cast()
|
||||
self.dropout = Dropout(keep_prob=keep_prob)
|
||||
self.use_activation = use_activation
|
||||
self.convert_dtype = convert_dtype
|
||||
self.drop_out = drop_out
|
||||
|
||||
def _init_activation(self, act_str):
|
||||
act_str = act_str.lower()
|
||||
if act_str == "relu":
|
||||
act_func = P.ReLU()
|
||||
elif act_str == "sigmoid":
|
||||
act_func = P.Sigmoid()
|
||||
elif act_str == "tanh":
|
||||
act_func = P.Tanh()
|
||||
return act_func
|
||||
|
||||
def construct(self, x):
|
||||
'''
|
||||
Construct Dense layer
|
||||
'''
|
||||
if self.training and self.drop_out:
|
||||
x = self.dropout(x)
|
||||
if self.convert_dtype:
|
||||
x = self.cast(x, mstype.float16)
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
bias = self.cast(self.bias, mstype.float16)
|
||||
wx = self.matmul(x, weight)
|
||||
wx = self.bias_add(wx, bias)
|
||||
if self.use_activation:
|
||||
wx = self.act_func(wx)
|
||||
wx = self.cast(wx, mstype.float32)
|
||||
else:
|
||||
wx = self.matmul(x, self.weight)
|
||||
wx = self.bias_add(wx, self.bias)
|
||||
if self.use_activation:
|
||||
wx = self.act_func(wx)
|
||||
return wx
|
||||
|
||||
|
||||
class CrossLayer(nn.Cell):
|
||||
"""
|
||||
Cross Layer for DCN Model;
|
||||
"""
|
||||
def __init__(self, cross_raw_dim, cross_col_dim, weight_bias_init, convert_dtype=True):
|
||||
super(CrossLayer, self).__init__()
|
||||
weight_init, bias_init = weight_bias_init
|
||||
self.cross_weight = init_method(weight_init, [cross_col_dim, 1], name="weight")
|
||||
self.cross_bias = init_method(bias_init, [cross_col_dim, 1], name="bias")
|
||||
self.matmul = nn.MatMul()
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.tensor_add = P.TensorAdd()
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.convert_dtype = convert_dtype
|
||||
self.squeeze = Squeeze(2)
|
||||
|
||||
def construct(self, inputs, x_0):
|
||||
'''
|
||||
Construct Cross layer
|
||||
'''
|
||||
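# One cross step: y_l = x_0 * (x_l . w) + b + x_l, i.e. explicit feature crossing plus a residual connection.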
x_0 = self.expand_dims(x_0, 2)
|
||||
x_l = self.expand_dims(inputs, 2)
|
||||
x_lw = C.tensor_dot(x_l, self.cross_weight, ((1,), (0,)))
|
||||
dot = self.matmul(x_0, x_lw)
|
||||
y_l = dot + self.cross_bias + x_l
|
||||
y_l = self.squeeze(y_l)
|
||||
return y_l
|
||||
|
||||
|
||||
class EmbeddingLookup(nn.Cell):
|
||||
"""
|
||||
An embedding lookup table with a fixed dictionary and size.
|
||||
|
||||
Args:
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
"""
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
embedding_size,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02):
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size))
|
||||
self.expand = P.ExpandDims()
|
||||
self.shape_flat = (-1,)
|
||||
self.gather = P.Gather()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, input_ids):
|
||||
"""Get a embeddings lookup table with a fixed dictionary and size."""
|
||||
input_shape = self.shape(input_ids)
|
||||
|
||||
flat_ids = self.reshape(input_ids, self.shape_flat)
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
|
||||
else:
|
||||
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
||||
|
||||
out_shape = input_shape + (self.embedding_size,)
|
||||
output = self.reshape(output_for_reshape, out_shape)
|
||||
return output, self.embedding_table
|
||||
|
||||
|
||||
class DeepCrossModel(nn.Cell):
|
||||
"""
|
||||
Deep and Cross Model
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(DeepCrossModel, self).__init__()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.reshape = P.Reshape()
|
||||
self.deep_reshape = P.Reshape()
|
||||
self.deep_mul = P.Mul()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.emb_dim = config.emb_dim
|
||||
self.batch_size = config.batch_size
|
||||
self.field_size = config.field_size
|
||||
self.input_size = self.field_size * self.emb_dim
|
||||
self.deep_layer_dim = config.deep_layer_dim
|
||||
self.deep_embeddinglookup = EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||
self.cross_layer_num = config.cross_layer_num
|
||||
self.keep_prob = config.keep_prob
|
||||
|
||||
self.cross_layer_1 = CrossLayer(self.batch_size,
|
||||
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
|
||||
self.cross_layer_2 = CrossLayer(self.batch_size,
|
||||
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
|
||||
self.cross_layer_3 = CrossLayer(self.batch_size,
|
||||
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
|
||||
self.cross_layer_4 = CrossLayer(self.batch_size,
|
||||
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
|
||||
self.cross_layer_5 = CrossLayer(self.batch_size,
|
||||
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
|
||||
self.cross_layer_6 = CrossLayer(self.batch_size,
|
||||
self.input_size, weight_bias_init=['normal', 'normal'], convert_dtype=False)
|
||||
|
||||
self.dense_layer_1 = DenseLayer(self.input_size, self.deep_layer_dim[0],
|
||||
weight_bias_init=['normal', 'normal'], act_str="relu",
|
||||
keep_prob=self.keep_prob, convert_dtype=False, drop_out=False)
|
||||
self.dense_layer_2 = DenseLayer(self.deep_layer_dim[0], self.deep_layer_dim[1],
|
||||
weight_bias_init=['normal', 'normal'], act_str="relu",
|
||||
keep_prob=self.keep_prob, convert_dtype=False, drop_out=False)
|
||||
self.dense_layer_3 = DenseLayer(self.input_size + self.deep_layer_dim[1],
|
||||
1, weight_bias_init=['normal', 'normal'], act_str="sigmoid",
|
||||
keep_prob=self.keep_prob, use_activation=False,
|
||||
convert_dtype=False, drop_out=False)
|
||||
|
||||
def construct(self, id_hldr, wt_hldr):
|
||||
"""dcn construct"""
|
||||
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
||||
deep_id_embs, _ = self.deep_embeddinglookup(id_hldr)
|
||||
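# Weight each field's embedding by its feature value (min-max scaled dense values, 1.0 for categorical features).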
vx = self.deep_mul(deep_id_embs, mask)
|
||||
input_x = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
|
||||
d_1 = self.dense_layer_1(input_x)
|
||||
d_2 = self.dense_layer_2(d_1)
|
||||
c_1 = self.cross_layer_1(input_x, input_x)
|
||||
c_2 = self.cross_layer_2(c_1, input_x)
|
||||
c_3 = self.cross_layer_3(c_2, input_x)
|
||||
c_4 = self.cross_layer_4(c_3, input_x)
|
||||
c_5 = self.cross_layer_5(c_4, input_x)
|
||||
c_6 = self.cross_layer_6(c_5, input_x)
|
||||
out = self.concat((d_2, c_6))
|
||||
out = self.dense_layer_3(out)
|
||||
return out
|
||||
|
||||
|
||||
class NetWithLossClass(nn.Cell):
|
||||
"""
|
||||
NetWithLossClass definition.
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||
self.loss = P.SigmoidCrossEntropyWithLogits()
|
||||
self.network = network
|
||||
self.ReduceMean = P.ReduceMean(keep_dims=False)
|
||||
|
||||
def construct(self, batch_ids, batch_wts, label):
|
||||
predict = self.network(batch_ids, batch_wts)
|
||||
log_loss = self.loss(predict, label)
|
||||
mean_log_loss = self.ReduceMean(log_loss)
|
||||
loss = mean_log_loss
|
||||
return loss
|
||||
|
||||
|
||||
class TrainStepWrap(nn.Cell):
|
||||
"""
|
||||
TrainStepWrap definition
|
||||
"""
|
||||
def __init__(self, network, lr=0.0001, eps=1e-8, loss_scale=1000.0):
|
||||
super(TrainStepWrap, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.set_train()
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = loss_scale
|
||||
|
||||
def construct(self, batch_ids, batch_wts, label):
|
||||
weights = self.weights
|
||||
loss = self.network(batch_ids, batch_wts, label)
|
||||
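# Back-propagate with sens filled with loss_scale; the Adam optimizer above was built with the same
# loss_scale, so the gradient scaling is undone inside the optimizer.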
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
class PredictWithSigmoid(nn.Cell):
|
||||
"""
|
||||
Eval model with sigmoid.
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.sigmoid = P.Sigmoid()
|
||||
|
||||
def construct(self, batch_ids, batch_wts, labels):
|
||||
logits = self.network(batch_ids, batch_wts)
|
||||
pred_probs = self.sigmoid(logits)
|
||||
return logits, pred_probs, labels
@@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Area under curve metric
|
||||
"""
|
||||
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from mindspore import context
|
||||
from mindspore.nn.metrics import Metric
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
|
||||
class AUCMetric(Metric):
|
||||
"""
|
||||
Area under curve metric
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(AUCMetric, self).__init__()
|
||||
self.clear()
|
||||
self.full_batch = context.get_auto_parallel_context("full_batch")
|
||||
|
||||
def clear(self):
|
||||
"""Clear the internal evaluation result."""
|
||||
self.true_labels = []
|
||||
self.pred_probs = []
|
||||
|
||||
def update(self, *inputs): # inputs
|
||||
"""Update list of predicts and labels."""
|
||||
all_predict = inputs[1].asnumpy().flatten().tolist() # predict
|
||||
all_label = inputs[2].asnumpy().flatten().tolist() # label
|
||||
self.pred_probs.extend(all_predict)
|
||||
if self.full_batch:
|
||||
rank_id = get_rank()
|
||||
group_size = get_group_size()
|
||||
gap = len(all_label) // group_size
|
||||
self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap])
|
||||
else:
|
||||
self.true_labels.extend(all_label)
|
||||
|
||||
def eval(self):
|
||||
if len(self.true_labels) != len(self.pred_probs):
|
||||
raise RuntimeError(
|
||||
'true_labels.size is not equal to pred_probs.size()')
|
||||
|
||||
auc = roc_auc_score(self.true_labels, self.pred_probs)
|
||||
print("====" * 20 + " auc_metric end")
|
||||
print("====" * 20 + " auc: {}".format(auc))
|
||||
return auc
@@ -0,0 +1,292 @@
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Download raw data and preprocessed data."""
|
||||
import os
|
||||
import pickle
|
||||
import collections
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
|
||||
class StatsDict():
|
||||
"""preprocessed data"""
|
||||
|
||||
def __init__(self, field_size, dense_dim, slot_dim, skip_id_convert):
|
||||
self.field_size = field_size
|
||||
self.dense_dim = dense_dim
|
||||
self.slot_dim = slot_dim
|
||||
self.skip_id_convert = bool(skip_id_convert)
|
||||
|
||||
self.val_cols = ["val_{}".format(i + 1) for i in range(self.dense_dim)]
|
||||
self.cat_cols = ["cat_{}".format(i + 1) for i in range(self.slot_dim)]
|
||||
|
||||
self.val_min_dict = {col: 0 for col in self.val_cols}
|
||||
self.val_max_dict = {col: 0 for col in self.val_cols}
|
||||
|
||||
self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}
|
||||
|
||||
self.oov_prefix = "OOV"
|
||||
|
||||
self.cat2id_dict = {}
|
||||
self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
|
||||
self.cat2id_dict.update(
|
||||
{self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)})
|
||||
|
||||
def stats_vals(self, val_list):
|
||||
"""Handling weights column"""
|
||||
assert len(val_list) == len(self.val_cols)
|
||||
|
||||
def map_max_min(i, val):
|
||||
key = self.val_cols[i]
|
||||
if val != "":
|
||||
if float(val) > self.val_max_dict[key]:
|
||||
self.val_max_dict[key] = float(val)
|
||||
if float(val) < self.val_min_dict[key]:
|
||||
self.val_min_dict[key] = float(val)
|
||||
|
||||
for i, val in enumerate(val_list):
|
||||
map_max_min(i, val)
|
||||
|
||||
def stats_cats(self, cat_list):
|
||||
"""Handling cats column"""
|
||||
|
||||
assert len(cat_list) == len(self.cat_cols)
|
||||
|
||||
def map_cat_count(i, cat):
|
||||
key = self.cat_cols[i]
|
||||
self.cat_count_dict[key][cat] += 1
|
||||
|
||||
for i, cat in enumerate(cat_list):
|
||||
map_cat_count(i, cat)
|
||||
|

    def save_dict(self, dict_path, prefix=""):
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_max_dict, file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_min_dict, file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.cat_count_dict, file_wrt)

    def load_dict(self, dict_path, prefix=""):
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_max_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_min_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.cat_count_dict = pickle.load(file_wrt)
        print("val_max_dict.items()[:50]:{}".format(list(self.val_max_dict.items())[:50]))
        print("val_min_dict.items()[:50]:{}".format(list(self.val_min_dict.items())[:50]))

    def get_cat2id(self, threshold=100):
        # Categories seen more than `threshold` times get their own id; the rest fall back to
        # the per-slot OOV bucket in map_cat2id().
        for key, cat_count_d in self.cat_count_dict.items():
            new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
            for cat_str, _ in new_cat_count_d.items():
                self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
        print("cat2id_dict.size:{}".format(len(self.cat2id_dict)))
        print("cat2id_dict.items()[:50]:{}".format(list(self.cat2id_dict.items())[:50]))

    def map_cat2id(self, values, cats):
        """Cat to id"""

        def minmax_scale_value(i, val):
            max_v = float(self.val_max_dict["val_{}".format(i + 1)])
            return float(val) * 1.0 / max_v

        id_list = []
        weight_list = []
        for i, val in enumerate(values):
            if val == "":
                id_list.append(i)
                weight_list.append(0)
            else:
                key = "val_{}".format(i + 1)
                id_list.append(self.cat2id_dict[key])
                weight_list.append(minmax_scale_value(i, float(val)))

        for i, cat_str in enumerate(cats):
            key = "cat_{}".format(i + 1) + "_" + cat_str
            if key in self.cat2id_dict:
                if self.skip_id_convert is True:
                    # For synthetic data the generated ids already lie in [0, max_vocab], but when the
                    # number of examples is smaller than vocab_size / slot_nums the ids would still be
                    # remapped to [0, real_vocab], where real_vocab is the actual vocab size rather than
                    # max_vocab. A simple way to alleviate this is to skip the id conversion and keep the
                    # synthetic data id as the final id.
                    id_list.append(cat_str)
                else:
                    id_list.append(self.cat2id_dict[key])
            else:
                id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
            weight_list.append(1.0)
        return id_list, weight_list
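
# Example with hypothetical toy sizes: with dense_dim=2 and slot_dim=1, cat2id_dict starts as
# {"val_1": 0, "val_2": 1, "OOVcat_1": 2}, and get_cat2id() appends one id per frequent
# category string. For a row with values=["5", ""] and cats=["a1b2c3"], where "a1b2c3" is
# below the frequency threshold, map_cat2id() returns ids=[0, 1, 2] and
# weights=[5.0 / observed max of val_1, 0, 1.0]: dense features keep their column id and
# carry a weight scaled by the column max, categorical features carry a constant weight of 1.0.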


def mkdir_path(file_path):
    if not os.path.exists(file_path):
        os.makedirs(file_path)


def statsdata(file_path, dict_output_path, recommendation_dataset_stats_dict, dense_dim=13, slot_dim=26):
    """Preprocess data and save data"""
    with open(file_path, encoding="utf-8") as file_in:
        errorline_list = []
        count = 0
        for line in file_in:
            count += 1
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != (dense_dim + slot_dim + 1):
                errorline_list.append(count)
                print("Found line length: {}, supposed to be {}, the line is {}".format(
                    len(items), dense_dim + slot_dim + 1, line))
                continue
            if count % 1000000 == 0:
                print("Have handled {}w lines.".format(count // 10000))
            values = items[1: dense_dim + 1]
            cats = items[dense_dim + 1:]

            assert len(values) == dense_dim, "values.size: {}".format(len(values))
            assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
            recommendation_dataset_stats_dict.stats_vals(values)
            recommendation_dataset_stats_dict.stats_cats(cats)
    recommendation_dataset_stats_dict.save_dict(dict_output_path)


def random_split_trans2mindrecord(input_file_path, output_file_path, recommendation_dataset_stats_dict,
                                  part_rows=2000000, line_per_sample=1000, train_line_count=None,
                                  test_size=0.1, seed=2020, dense_dim=13, slot_dim=26):
    """Random split data and save mindrecord"""
    if train_line_count is None:
        raise ValueError("Please provide training file line count")
    test_size = int(train_line_count * test_size)
    all_indices = [i for i in range(train_line_count)]
    np.random.seed(seed)
    np.random.shuffle(all_indices)
    print("all_indices.size:{}".format(len(all_indices)))
    test_indices_set = set(all_indices[:test_size])
    print("test_indices_set.size:{}".format(len(test_indices_set)))
    print("-----------------------" * 10 + "\n" * 2)

    train_data_list = []
    test_data_list = []
    ids_list = []
    wts_list = []
    label_list = []

    writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21)
    writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3)

    schema = {"label": {"type": "float32", "shape": [-1]}, "feat_vals": {"type": "float32", "shape": [-1]},
              "feat_ids": {"type": "int32", "shape": [-1]}}
    writer_train.add_schema(schema, "CRITEO_TRAIN")
    writer_test.add_schema(schema, "CRITEO_TEST")

    with open(input_file_path, encoding="utf-8") as file_in:
        items_error_size_lineCount = []
        count = 0
        train_part_number = 0
        test_part_number = 0
        for i, line in enumerate(file_in):
            count += 1
            if count % 1000000 == 0:
                print("Have handled {}w lines.".format(count // 10000))
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != (1 + dense_dim + slot_dim):
                items_error_size_lineCount.append(i)
                continue
            label = float(items[0])
            values = items[1:1 + dense_dim]
            cats = items[1 + dense_dim:]

            assert len(values) == dense_dim, "values.size: {}".format(len(values))
            assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))

            ids, wts = recommendation_dataset_stats_dict.map_cat2id(values, cats)

            ids_list.extend(ids)
            wts_list.extend(wts)
            label_list.append(label)

            # Every `line_per_sample` input lines are packed into one MindRecord row.
            if count % line_per_sample == 0:
                if i not in test_indices_set:
                    train_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32),
                                            "feat_vals": np.array(wts_list, dtype=np.float32),
                                            "label": np.array(label_list, dtype=np.float32)
                                            })
                else:
                    test_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32),
                                           "feat_vals": np.array(wts_list, dtype=np.float32),
                                           "label": np.array(label_list, dtype=np.float32)
                                           })
                if train_data_list and len(train_data_list) % part_rows == 0:
                    writer_train.write_raw_data(train_data_list)
                    train_data_list.clear()
                    train_part_number += 1

                if test_data_list and len(test_data_list) % part_rows == 0:
                    writer_test.write_raw_data(test_data_list)
                    test_data_list.clear()
                    test_part_number += 1

                ids_list.clear()
                wts_list.clear()
                label_list.clear()

    if train_data_list:
        writer_train.write_raw_data(train_data_list)
    if test_data_list:
        writer_test.write_raw_data(test_data_list)
    writer_train.commit()
    writer_test.commit()

    print("-------------" * 10)
    print("items_error_size_lineCount.size(): {}.".format(len(items_error_size_lineCount)))
    print("-------------" * 10)
    np.save("items_error_size_lineCount.npy", items_error_size_lineCount)
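
# With the defaults (dense_dim=13, slot_dim=26, line_per_sample=1000), each MindRecord row
# therefore holds 1000 packed samples: feat_ids and feat_vals have length 39 * 1000 = 39000
# and label has length 1000. The train/test assignment is also made per 1000-line block, so
# the split is roughly 90/10 at block granularity.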


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Recommendation dataset")
    parser.add_argument("--data_path", type=str, default="./data/", help='The path of the data file')
    parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continuous fields')
    parser.add_argument("--slot_dim", type=int, default=26,
                        help='The number of your sparse fields, also known as category features.')
    parser.add_argument("--threshold", type=int, default=100,
                        help='Word frequency below this will be regarded as OOV. It aims to reduce the vocab size')
    parser.add_argument("--train_line_count", type=int, help='The number of examples in your dataset')
    parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1],
                        help='Skip the id convert, regarding the original id as the final id.')

    args, _ = parser.parse_known_args()
    data_path = args.data_path

    target_field_size = args.dense_dim + args.slot_dim
    stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
                      skip_id_convert=args.skip_id_convert)
    data_file_path = data_path + "origin_data/train.txt"
    stats_output_path = data_path + "stats_dict/"
    mkdir_path(stats_output_path)
    statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim)

    stats.load_dict(dict_path=stats_output_path, prefix="")
    stats.get_cat2id(threshold=args.threshold)

    in_file_path = data_path + "origin_data/train.txt"
    output_path = data_path + "mindrecord/"
    mkdir_path(output_path)
    random_split_trans2mindrecord(in_file_path, output_path, stats, part_rows=2000000,
                                  train_line_count=args.train_line_count, line_per_sample=1000,
                                  test_size=0.1, seed=2020, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
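
# Running this script with --data_path=./data/ leaves the statistics dictionaries
# (val_max_dict.pkl, val_min_dict.pkl, cat_count_dict.pkl) in ./data/stats_dict/ and the
# converted dataset in ./data/mindrecord/, written as 21 training shards
# (train_input_part.mindrecord*) and 3 test shards (test_input_part.mindrecord*).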


@ -0,0 +1,97 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train dcn"""


from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.communication.management import init
from src.deep_and_cross import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, DeepCrossModel
from src.callbacks import LossCallBack
from src.datasets import create_dataset, DataType
from src.config import DeepCrossConfig


def get_DCN_net(configure):
    """
    Get network of deep&cross model.
    """
    DCN_net = DeepCrossModel(configure)
    loss_net = NetWithLossClass(DCN_net)
    train_net = TrainStepWrap(loss_net)
    eval_net = PredictWithSigmoid(DCN_net)
    return train_net, eval_net
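
# The wrappers are defined in src/deep_and_cross.py (not shown in this file); judging by their
# names, NetWithLossClass attaches the training loss to DeepCrossModel, TrainStepWrap adds the
# optimizer and gradient step, and PredictWithSigmoid outputs sigmoid probabilities for evaluation.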


class ModelBuilder():
    """
    Build the model.
    """
    def __init__(self):
        pass

    def get_hook(self):
        pass

    def get_net(self, configure):
        return get_DCN_net(configure)


def test_train(configure):
    """
    test_train
    """
    if configure.is_distributed:
        # Initialize the communication backend before building the dataset and model.
        if configure.device_target == "Ascend":
            context.set_context(device_id=configure.device_id, device_target=configure.device_target)
            init()
        elif configure.device_target == "GPU":
            context.set_context(device_target=configure.device_target)
            init("nccl")
        context.set_context(mode=context.GRAPH_MODE)
    else:
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=configure.device_target, device_id=configure.device_id)

    data_path = configure.data_path
    batch_size = configure.batch_size
    field_size = configure.field_size
    target_column = field_size + 1
    print("field_size: {} ".format(field_size))
    epochs = configure.epochs
    if configure.dataset_type == "tfrecord":
        dataset_type = DataType.TFRECORD
    elif configure.dataset_type == "mindrecord":
        dataset_type = DataType.MINDRECORD
    else:
        dataset_type = DataType.H5
    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                              batch_size=batch_size, data_type=dataset_type, target_column=target_column)
    print("ds_train.size: {}".format(ds_train.get_dataset_size()))
    net_builder = ModelBuilder()
    train_net, _ = net_builder.get_net(configure)
    train_net.set_train()
    model = Model(train_net)
    callback = LossCallBack(config=configure)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                  keep_checkpoint_max=5)
    ckpoint_cb = ModelCheckpoint(prefix='deepcross_train', directory=configure.ckpt_path, config=ckptconfig)
    model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb])


if __name__ == "__main__":
    config = DeepCrossConfig()
    config.argparse_init()
    test_train(config)