!10285 Noah research recommender system model AutoDis

From: @chenbo116
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-22 18:51:44 +08:00 committed by Gitee
commit c381b74d4a
12 changed files with 1405 additions and 0 deletions

View File

@ -73,6 +73,8 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [CenterNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet/README.md)
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
- [Recommender Systems](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend)
- [AutoDis](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)

View File

@ -0,0 +1,230 @@
# Contents
- [AutoDis Description](#AutoDis-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [AutoDis Description](#contents)
The common methods for numerical feature embedding are Normalization and Discretization. The former shares a single embedding for intra-field features and the latter transforms the features into categorical form through various discretization approaches. However, the first approach surfers from low capacity and the second one limits performance as well because the discretization rule cannot be optimized with the ultimate goal of CTR model.
To fill the gap of representing numerical features, in this paper, we propose AutoDis, a framework that discretizes features in numerical fields automatically and is optimized with CTR models in an end-to-end manner. Specifically, we introduce a set of meta-embeddings for each numerical field to model the relationship among the intra-field features and propose an automatic differentiable discretization and aggregation approach to capture the correlations between the numerical features and meta-embeddings. AutoDis is a valid framework to work with various popular deep CTR models and is able to improve the recommendation performance significantly.
[Paper](https://arxiv.org/abs/2012.08986): Huifeng Guo*, Bo Chen*, Ruiming Tang, Zhenguo Li, Xiuqiang He. AutoDis: Automatic Discretization for Embedding Numerical Features in CTR Prediction
# [Model Architecture](#contents)
AutoDis leverages a set of meta-embeddings for each numerical field, which are shared among all the intra-field feature values. Meta-embeddings learn the relationship across different feature values in this field with a manageable number of embedding parameters. Utilizing meta-embedding is able to avoid explosive embedding parameters introduced by assigning each numerical feature with an independent embedding simply. Besides, the embedding of a numerical feature is designed as a differentiable aggregation over the shared meta-embeddings, so that the discretization of numerical features can be optimized with the ultimate goal of deep CTR models in an end-to-end manner.
# [Dataset](#contents)
- [1] A dataset [Criteo](https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz) used in Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction[J]. 2017.
# [Environment Requirements](#contents)
- HardwareAscend/GPU
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- runing on Ascend
```python
# run training example
python train.py \
--dataset_path='dataset/train' \
--ckpt_path='./checkpoint' \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target='Ascend' \
--do_eval=True > ms_log/output.log 2>&1 &
# run evaluation example
python eval.py \
--dataset_path='dataset/test' \
--checkpoint_path='./checkpoint/autodis.ckpt' \
--device_target='Ascend' > ms_log/eval_output.log 2>&1 &
OR
sh scripts/run_eval.sh 0 Ascend /dataset_path /checkpoint_path/autodis.ckpt
```
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
Please follow the instructions in the link below:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```bash
.
└─autodis
├─README.md
├─mindspore_hub_conf.md # config for mindspore hub
├─scripts
├─run_standalone_train.sh # launch standalone training(1p) in Ascend or GPU
└─run_eval.sh # launch evaluating in Ascend or GPU
├─src
├─__init__.py # python init file
├─config.py # parameter configuration
├─callback.py # define callback function
├─autodis.py # AutoDis network
├─dataset.py # create dataset for AutoDis
├─eval.py # eval net
└─train.py # train net
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py
- train parameters
```python
optional arguments:
-h, --help show this help message and exit
--dataset_path DATASET_PATH
Dataset path
--ckpt_path CKPT_PATH
Checkpoint path
--eval_file_name EVAL_FILE_NAME
Auc log file path. Default: "./auc.log"
--loss_file_name LOSS_FILE_NAME
Loss log file path. Default: "./loss.log"
--do_eval DO_EVAL Do evaluation or not. Default: True
--device_target DEVICE_TARGET
Ascend or GPU. Default: Ascend
```
- eval parameters
```bash
optional arguments:
-h, --help show this help message and exit
--checkpoint_path CHECKPOINT_PATH
Checkpoint file path
--dataset_path DATASET_PATH
Dataset path
--device_target DEVICE_TARGET
Ascend or GPU. Default: Ascend
```
## [Training Process](#contents)
### Training
- running on Ascend
```python
python train.py \
--dataset_path='dataset/train' \
--ckpt_path='./checkpoint' \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target='Ascend' \
--do_eval=True > ms_log/output.log 2>&1 &
```
The python command above will run in the background, you can view the results through the file `ms_log/output.log`.
After training, you'll get some checkpoint files under `./checkpoint` folder by default. The loss value are saved in loss.log file.
```txt
2020-12-10 14:58:04 epoch: 1 step: 41257, loss is 0.44559600949287415
2020-12-10 15:06:59 epoch: 2 step: 41257, loss is 0.4370603561401367
...
```
The model checkpoint will be saved in the current directory.
## [Evaluation Process](#contents)
### Evaluation
- evaluation on dataset when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation.
```python
python eval.py \
--dataset_path='dataset/test' \
--checkpoint_path='./checkpoint/autodis.ckpt' \
--device_target='Ascend' > ms_log/eval_output.log 2>&1 &
OR
sh scripts/run_eval.sh 0 Ascend /dataset_path /checkpoint_path/autodis.ckpt
```
The above python command will run in the background. You can view the results through the file "eval_output.log". The accuracy is saved in auc.log file.
```txt
{'result': {'AUC': 0.8109881454077731, 'eval_time': 27.72783327102661s}}
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | AutoDis |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G |
| uploaded Date | 12/12/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | [1] |
| Training Parameters | epoch=15, batch_size=1000, lr=1e-5 |
| Optimizer | Adam |
| Loss Function | Sigmoid Cross Entropy With Logits |
| outputs | Accuracy |
| Loss | 0.42 |
| Speed | 1pc: 8.16 ms/step; |
| Total time | 1pc: 90 mins; |
| Parameters (M) | 16.5 |
| Checkpoint for Fine tuning | 191M (.ckpt file) |
| Scripts | [AutoDis script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis) |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | AutoDis |
| Resource | Ascend 910 |
| Uploaded Date | 12/12/2020 (month/day/year) |
| MindSpore Version | 0.3.0-alpha |
| Dataset | [1] |
| batch_size | 1000 |
| outputs | accuracy |
| AUC | 1pc: 0.8112; |
| Model for inference | 191M (.ckpt file) |
# [Description of Random Situation](#contents)
We set the random seed before training in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval_criteo."""
import os
import sys
import time
import argparse
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.autodis import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset, DataType
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend"],
help='Default: Ascend')
args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
def add_write(file_path, print_str):
with open(file_path, 'a+', encoding='utf-8') as file_out:
file_out.write(print_str + '\n')
if __name__ == '__main__':
data_config = DataConfig()
model_config = ModelConfig()
train_config = TrainConfig()
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=1, batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
model_builder = ModelBuilder(ModelConfig, TrainConfig)
train_net, eval_net = model_builder.get_train_eval_net()
train_net.set_train()
eval_net.set_train(False)
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(eval_net, param_dict)
start = time.time()
res = model.eval(ds_eval)
eval_time = time.time() - start
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
print(out_str)
add_write('./auc.log', str(out_str))

View File

@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.autodis import ModelBuilder
from src.config import ModelConfig, TrainConfig
def create_network(name, *args, **kwargs):
if name == 'autodis':
model_config = ModelConfig()
train_config = TrainConfig()
model_builder = ModelBuilder(model_config, train_config)
_, autodis_eval_net = model_builder.get_train_eval_net()
return autodis_eval_net
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -0,0 +1,34 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_eval.sh DEVICE_ID DEVICE_TARGET DATASET_PATH CHECKPOINT_PATH"
echo "for example: sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path"
echo "After running the script, the network runs in the background, The log will be generated in ms_log/eval_output.log"
export DEVICE_ID=$1
DEVICE_TARGET=$2
DATA_URL=$3
CHECKPOINT_PATH=$4
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python -u eval.py \
--dataset_path=$DATA_URL \
--checkpoint_path=$CHECKPOINT_PATH \
--device_target=$DEVICE_TARGET > ms_log/eval_output.log 2>&1 &

View File

@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_standalone_train.sh DEVICE_ID/CUDA_VISIBLE_DEVICES DEVICE_TARGET DATASET_PATH"
echo "for example: sh scripts/run_standalone_train.sh 0 GPU /dataset_path"
echo "After running the script, the network runs in the background, The log will be generated in ms_log/output.log"
DEVICE_TARGET=$2
if [ "$DEVICE_TARGET" = "GPU" ]
then
export CUDA_VISIBLE_DEVICES=$1
fi
if [ "$DEVICE_TARGET" = "Ascend" ]
then
export DEVICE_ID=$1
fi
DATA_URL=$3
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python -u train.py \
--dataset_path=$DATA_URL \
--ckpt_path="checkpoint" \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target=$DEVICE_TARGET \
--do_eval=True > ms_log/output.log 2>&1 &

View File

@ -0,0 +1,410 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_training """
import os
import numpy as np
from sklearn.metrics import roc_auc_score
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam
from mindspore.nn.metrics import Metric
from mindspore import nn, ParameterTuple, Parameter
from mindspore.common.initializer import Uniform, initializer, Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from .callback import EvalCallBack, LossCallBack
np_type = np.float32
ms_type = mstype.float32
ms_type_16 = mstype.float16
class AUCMetric(Metric):
"""AUC metric for AutoDis model."""
def __init__(self):
super(AUCMetric, self).__init__()
self.pred_probs = []
self.true_labels = []
def clear(self):
"""Clear the internal evaluation result."""
self.pred_probs = []
self.true_labels = []
def update(self, *inputs):
batch_predict = inputs[1].asnumpy()
batch_label = inputs[2].asnumpy()
self.pred_probs.extend(batch_predict.flatten().tolist())
self.true_labels.extend(batch_label.flatten().tolist())
def eval(self):
if len(self.true_labels) != len(self.pred_probs):
raise RuntimeError('true_labels.size() is not equal to pred_probs.size()')
auc = roc_auc_score(self.true_labels, self.pred_probs)
return auc
def init_method(method, shape, name, max_val=0.01):
"""
The method of init parameters.
Args:
method (str): The method uses to initialize parameter.
shape (list): The shape of parameter.
name (str): The name of parameter.
max_val (float): Max value in parameter when uses 'random' or 'uniform' to initialize parameter.
Returns:
Parameter.
"""
if method in ['random', 'uniform']:
params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
elif method == "one":
params = Parameter(initializer("ones", shape, ms_type), name=name)
elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal":
params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
return params
def init_var_dict(init_args, values):
"""
Init parameter.
Args:
init_args (list): Define max and min value of parameters.
values (list): Define name, shape and init method of parameters.
Returns:
dict, a dict ot Parameter.
"""
var_map = {}
_, _max_val = init_args
for key, shape, init_flag, data_type in values:
if key not in var_map.keys():
if init_flag in ['random', 'uniform']:
var_map[key] = Parameter(initializer(Uniform(_max_val), shape, data_type), name=key)
elif init_flag == "one":
var_map[key] = Parameter(initializer("ones", shape, data_type), name=key)
elif init_flag == "zero":
var_map[key] = Parameter(initializer("zeros", shape, data_type), name=key)
elif init_flag == 'normal':
var_map[key] = Parameter(initializer(Normal(_max_val), shape, data_type), name=key)
return var_map
class DenseLayer(nn.Cell):
"""
Dense Layer for Deep Layer of AutoDis Model;
Containing: activation, matmul, bias_add;
Args:
input_dim (int): the shape of weight at 0-aixs;
output_dim (int): the shape of weight at 1-aixs, and shape of bias
weight_bias_init (list): weight and bias init method, "random", "uniform", "one", "zero", "normal";
act_str (str): activation function method, "relu", "sigmoid", "tanh";
keep_prob (float): Dropout Layer keep_prob_rate;
scale_coef (float): input scale coefficient;
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
self.bias = init_method(bias_init, [output_dim], name="bias")
self.act_func = self._init_activation(act_str)
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
self.dropout = Dropout(keep_prob=keep_prob)
self.mul = P.Mul()
self.realDiv = P.RealDiv()
self.scale_coef = scale_coef
def _init_activation(self, act_str):
act_str = act_str.lower()
if act_str == "relu":
act_func = P.ReLU()
elif act_str == "sigmoid":
act_func = P.Sigmoid()
elif act_str == "tanh":
act_func = P.Tanh()
return act_func
def construct(self, x):
"""Dense Layer for Deep Layer of AutoDis Model."""
x = self.act_func(x)
if self.training:
x = self.dropout(x)
x = self.mul(x, self.scale_coef)
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
wx = self.matmul(x, weight)
wx = self.cast(wx, mstype.float32)
wx = self.realDiv(wx, self.scale_coef)
output = self.bias_add(wx, self.bias)
return output
class AutoDisModel(nn.Cell):
"""
From paper: "AutoDis: Automatic Discretization for Embedding Numerical Features in CTR Prediction"
Args:
batch_size (int): smaple_number of per step in training; (int, batch_size=128)
filed_size (int): input filed number, or called id_feature number; (int, filed_size=39)
vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000)
emb_dim (int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100)
deep_layer_args (list): Deep Layer args, layer_dim_list, layer_activator;
(int, deep_layer_args=[[100, 100, 100], "relu"])
init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds])
weight_bias_init (list): weight, bias init method for deep layers;
(list[str], weight_bias_init=['random', 'zero'])
keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
"""
def __init__(self, config):
super(AutoDisModel, self).__init__()
self.batch_size = config.batch_size
self.field_size = config.data_field_size
self.vocab_size = config.data_vocab_size
self.emb_dim = config.data_emb_dim
self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args
self.init_args = config.init_args
self.weight_bias_init = config.weight_bias_init
self.keep_prob = config.keep_prob
self.hash_size = config.hash_size
self.split_index = config.split_index
self.temperature = config.temperature
init_acts = [('W_l2', [self.vocab_size, 1], 'normal', ms_type),
('V_l2', [self.vocab_size, self.emb_dim], 'normal', ms_type),
('b', [1], 'normal', ms_type),
('logits', [self.split_index, self.hash_size], 'random', ms_type),
('autodis_embedding', [self.split_index, self.hash_size, self.emb_dim], 'random', ms_type_16)]
var_map = init_var_dict(self.init_args, init_acts)
self.fm_w = var_map["W_l2"]
self.fm_b = var_map["b"]
self.embedding_table = var_map["V_l2"]
self.logits = var_map["logits"]
self.autodis_embedding = var_map["autodis_embedding"]
# Deep Layers
self.deep_input_dims = self.field_size * self.emb_dim + 1
self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
# FM, linear Layers
self.Gatherv2 = P.GatherV2()
self.Mul = P.Mul()
self.ReduceSum = P.ReduceSum(keep_dims=False)
self.Reshape = P.Reshape()
self.Square = P.Square()
self.Shape = P.Shape()
self.Tile = P.Tile()
self.Concat = P.Concat(axis=1)
self.Cast = P.Cast()
# AutoDis
self.Slice = P.Slice()
self.BatchMatMul = P.BatchMatMul()
self.ExpandDims = P.ExpandDims()
self.Transpose = P.Transpose()
self.SoftMax = P.Softmax()
def construct(self, id_hldr, wt_hldr):
"""
Args:
id_hldr: batch ids; [bs, field_size]
wt_hldr: batch weights; [bs, field_size]
"""
con_wt_hldr = self.Slice(wt_hldr, (0, 0), (self.batch_size, self.split_index))
dis_id_hldr = self.Slice(id_hldr, (0, self.split_index), (self.batch_size, self.field_size - self.split_index))
dis_wt_hldr = self.Slice(wt_hldr, (0, self.split_index), (self.batch_size, self.field_size - self.split_index))
mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Linear layer
fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0)
wx = self.Mul(fm_id_weight, mask)
linear_out = self.ReduceSum(wx, 1)
# FM layer
# AutoDis embeddding
con_wt_hldr = self.ExpandDims(con_wt_hldr, 2)
h_logits = self.ExpandDims(self.logits, 0)
con_wt_hldr = self.Transpose(con_wt_hldr, (1, 0, 2))
h_logits = self.Transpose(h_logits, (1, 0, 2))
logits_score = self.Mul(con_wt_hldr, h_logits)
logits_norm_score = self.SoftMax(logits_score / self.temperature)
logits_norm_score = self.Cast(logits_norm_score, mstype.float16)
autodis_emb = self.BatchMatMul(logits_norm_score, self.autodis_embedding)
autodis_emb = self.Transpose(autodis_emb, (1, 0, 2))
autodis_emb = self.Cast(autodis_emb, mstype.float32)
dis_fm_id_embs = self.Gatherv2(self.embedding_table, dis_id_hldr, 0)
dis_mask = self.Reshape(dis_wt_hldr, (self.batch_size, self.field_size - self.split_index, 1))
dis_emb = self.Mul(dis_fm_id_embs, dis_mask)
vx = self.Concat((autodis_emb, dis_emb))
v1 = self.ReduceSum(vx, 1)
v1 = self.Square(v1)
v2 = self.Square(vx)
v2 = self.ReduceSum(v2, 1)
fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
fm_out = self.Reshape(fm_out, (-1, 1))
# Deep layer
b = self.Reshape(self.fm_b, (1, 1))
b = self.Tile(b, (self.batch_size, 1))
deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.Concat((deep_in, b))
deep_in = self.dense_layer_1(deep_in)
deep_in = self.dense_layer_2(deep_in)
deep_in = self.dense_layer_3(deep_in)
deep_out = self.dense_layer_4(deep_in)
out = linear_out + fm_out + deep_out
return out, fm_id_weight, dis_fm_id_embs, self.logits, self.autodis_embedding
class NetWithLossClass(nn.Cell):
"""
NetWithLossClass definition.
"""
def __init__(self, network, l2_coef=1e-6):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.loss = P.SigmoidCrossEntropyWithLogits()
self.network = network
self.l2_coef = l2_coef
self.Square = P.Square()
self.ReduceMean_false = P.ReduceMean(keep_dims=False)
self.ReduceSum_false = P.ReduceSum(keep_dims=False)
def construct(self, batch_ids, batch_wts, label):
"""
Construct NetWithLossClass
"""
predict, fm_id_weight, fm_id_embs, logits, autodis_embedding = self.network(batch_ids, batch_wts)
log_loss = self.loss(predict, label)
mean_log_loss = self.ReduceMean_false(log_loss)
l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight))
l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs))
l2_loss_logits = self.ReduceSum_false(self.Square(logits))
l2_loss_autodis_embedding = self.ReduceSum_false(self.Square(autodis_embedding))
l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w + l2_loss_logits + l2_loss_autodis_embedding) * 0.5
loss = mean_log_loss + l2_loss_all
return loss
class TrainStepWrap(nn.Cell):
"""
TrainStepWrap definition
"""
def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = loss_scale
def construct(self, batch_ids, batch_wts, label):
weights = self.weights
loss = self.network(batch_ids, batch_wts, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
return F.depend(loss, self.optimizer(grads))
class PredictWithSigmoid(nn.Cell):
"""
Eval model with sigmoid.
"""
def __init__(self, network):
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
self.network = network
self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels):
logits, _, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels
class ModelBuilder:
"""
Model builder for AutoDis.
Args:
model_config (ModelConfig): Model configuration.
train_config (TrainConfig): Train configuration.
"""
def __init__(self, model_config, train_config):
self.model_config = model_config
self.train_config = train_config
def get_callback_list(self, model=None, eval_dataset=None):
"""
Get callbacks which contains checkpoint callback, eval callback and loss callback.
Args:
model (Cell): The network is added callback (default=None).
eval_dataset (Dataset): Dataset for eval (default=None).
"""
callback_list = []
if self.train_config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps,
keep_checkpoint_max=self.train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix,
directory=self.train_config.output_path,
config=config_ck)
callback_list.append(ckpt_cb)
if self.train_config.eval_callback:
if model is None:
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format(
self.train_config.eval_callback, model))
if eval_dataset is None:
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() "
"args eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset))
auc_metric = AUCMetric()
eval_callback = EvalCallBack(model, eval_dataset, auc_metric,
eval_file_path=os.path.join(self.train_config.output_path,
self.train_config.eval_file_name))
callback_list.append(eval_callback)
if self.train_config.loss_callback:
loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path,
self.train_config.loss_file_name))
callback_list.append(loss_callback)
if callback_list:
return callback_list
return None
def get_train_eval_net(self):
autodis_net = AutoDisModel(self.model_config)
loss_net = NetWithLossClass(autodis_net, l2_coef=self.train_config.l2_coef)
train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate,
eps=self.train_config.epsilon,
loss_scale=self.train_config.loss_scale)
eval_net = PredictWithSigmoid(autodis_net)
return train_net, eval_net

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Defined callback for DeepFM.
"""
import time
from mindspore.train.callback import Callback
def add_write(file_path, out_str):
with open(file_path, 'a+', encoding='utf-8') as file_out:
file_out.write(out_str + '\n')
class EvalCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note
If per_print_times is 0 do not print loss.
"""
def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
super(EvalCallBack, self).__init__()
self.model = model
self.eval_dataset = eval_dataset
self.aucMetric = auc_metric
self.aucMetric.clear()
self.eval_file_path = eval_file_path
def epoch_end(self, run_context):
start_time = time.time()
out = self.model.eval(self.eval_dataset)
eval_time = int(time.time() - start_time)
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
out_str = "{} EvalCallBack metric{}; eval_time{}s".format(
time_str, out.values(), eval_time)
print(out_str)
add_write(self.eval_file_path, out_str)
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note
If per_print_times is 0 do not print loss.
Args
loss_file_path (str) The file absolute path, to save as loss_file;
per_print_times (int) Print loss every times. Default 1.
"""
def __init__(self, loss_file_path, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self.loss_file_path = loss_file_path
self._per_print_times = per_print_times
def step_end(self, run_context):
"""Monitor the loss in training."""
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
with open(self.loss_file_path, "a+") as loss_file:
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
print("epoch: {} step: {}, loss is {}\n".format(
cb_params.cur_epoch_num, cur_step_in_epoch, loss))
class TimeMonitor(Callback):
"""
Time monitor for calculating cost of each epoch.
Args
data_size (int) step size of an epoch.
"""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / self.data_size
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
step_mseconds = (time.time() - self.step_time) * 1000
print(f"step time {step_mseconds}", flush=True)

View File

@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
class DataConfig:
"""
Define parameters of dataset.
"""
data_vocab_size = 184965
train_num_of_parts = 21
test_num_of_parts = 3
batch_size = 1000
data_field_size = 39
# dataset format, 1: mindrecord, 2: tfrecord, 3: h5
data_format = 2
class ModelConfig:
"""
Define parameters of model.
"""
batch_size = DataConfig.batch_size
data_field_size = DataConfig.data_field_size
data_vocab_size = DataConfig.data_vocab_size
data_emb_dim = 80
deep_layer_args = [[400, 400, 512], "relu"]
init_args = [-0.01, 0.01]
weight_bias_init = ['normal', 'normal']
keep_prob = 0.9
split_index = 13
hash_size = 20
temperature = 1e-5
class TrainConfig:
"""
Define parameters of training.
"""
batch_size = DataConfig.batch_size
l2_coef = 1e-6
learning_rate = 1e-5
epsilon = 1e-8
loss_scale = 1024.0
train_epochs = 15
save_checkpoint = True
ckpt_file_name_prefix = "autodis"
save_checkpoint_steps = 1
keep_checkpoint_max = 15
eval_callback = True
loss_callback = True

View File

@ -0,0 +1,298 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Create train or eval dataset.
"""
import os
import math
from enum import Enum
import numpy as np
import pandas as pd
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype
from .config import DataConfig
class DataType(Enum):
"""
Enumerate supported dataset format.
"""
MINDRECORD = 1
TFRECORD = 2
H5 = 3
class H5Dataset():
"""
Create dataset with H5 format.
Args:
data_path (str): Dataset directory.
train_mode (bool): Whether dataset is used for train or eval (default=True).
train_num_of_parts (int): The number of train data file (default=21).
test_num_of_parts (int): The number of test data file (default=3).
"""
max_length = 39
def __init__(self, data_path, train_mode=True,
train_num_of_parts=DataConfig.train_num_of_parts,
test_num_of_parts=DataConfig.test_num_of_parts):
self._hdf_data_dir = data_path
self._is_training = train_mode
if self._is_training:
self._file_prefix = 'train'
self._num_of_parts = train_num_of_parts
else:
self._file_prefix = 'test'
self._num_of_parts = test_num_of_parts
self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, self._num_of_parts)
print("data_size: {}".format(self.data_size))
def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts):
size = 0
for part in range(num_of_parts):
_y = pd.read_hdf(os.path.join(hdf_data_dir, f'{file_prefix}_output_part_{str(part)}.h5'))
size += _y.shape[0]
return size
def _iterate_hdf_files_(self, num_of_parts=None,
shuffle_block=False):
"""
iterate among hdf files(blocks). when the whole data set is finished, the iterator restarts
from the beginning, thus the data stream will never stop
:param train_mode: True or false,false is eval_mode,
this file iterator will go through the train set
:param num_of_parts: number of files
:param shuffle_block: shuffle block files at every round
:return: input_hdf_file_name, output_hdf_file_name, finish_flag
"""
parts = np.arange(num_of_parts)
while True:
if shuffle_block:
for _ in range(int(shuffle_block)):
np.random.shuffle(parts)
for i, p in enumerate(parts):
yield os.path.join(self._hdf_data_dir, f'{self._file_prefix}_input_part_{str(p)}.h5'), \
os.path.join(self._hdf_data_dir, f'{self._file_prefix}_output_part_{str(p)}.h5'), \
i + 1 == len(parts)
def _generator(self, X, y, batch_size, shuffle=True):
"""
should be accessed only in private
:param X:
:param y:
:param batch_size:
:param shuffle:
:return:
"""
number_of_batches = np.ceil(1. * X.shape[0] / batch_size)
counter = 0
finished = False
sample_index = np.arange(X.shape[0])
if shuffle:
for _ in range(int(shuffle)):
np.random.shuffle(sample_index)
assert X.shape[0] > 0
while True:
batch_index = sample_index[batch_size * counter: batch_size * (counter + 1)]
X_batch = X[batch_index]
y_batch = y[batch_index]
counter += 1
yield X_batch, y_batch, finished
if counter == number_of_batches:
counter = 0
finished = True
def batch_generator(self, batch_size=1000,
random_sample=False, shuffle_block=False):
"""
:param train_mode: True or false,false is eval_mode,
:param batch_size
:param num_of_parts: number of files
:param random_sample: if True, will shuffle
:param shuffle_block: shuffle file blocks at every round
:return:
"""
for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts,
shuffle_block):
start = stop = None
X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values
y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values
data_gen = self._generator(X_all, y_all, batch_size,
shuffle=random_sample)
finished = False
while not finished:
X, y, finished = data_gen.__next__()
X_id = X[:, 0:self.max_length]
X_va = X[:, self.max_length:]
yield np.array(X_id.astype(dtype=np.int32)), \
np.array(X_va.astype(dtype=np.float32)), \
np.array(y.astype(dtype=np.float32))
def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000):
"""
Get dataset with h5 format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000)
Returns:
Dataset.
"""
data_para = {'batch_size': batch_size}
if train_mode:
data_para['random_sample'] = True
data_para['shuffle_block'] = True
h5_dataset = H5Dataset(data_path=directory, train_mode=train_mode)
numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)
def _iter_h5_data():
train_eval_gen = h5_dataset.batch_generator(**data_para)
for _ in range(0, numbers_of_batch, 1):
yield train_eval_gen.__next__()
ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"])
ds = ds.repeat(epochs)
return ds
def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
Get dataset with mindrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
file_suffix_name = '00' if train_mode else '0'
shuffle = train_mode
if rank_size is not None and rank_id is not None:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
num_parallel_workers=8)
else:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
shuffle=shuffle, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
column_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def _get_tf_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
Get dataset with tfrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
dataset_files = []
file_prefixt_name = 'train' if train_mode else 'test'
shuffle = train_mode
for (dir_path, _, filenames) in os.walk(directory):
for filename in filenames:
if file_prefixt_name in filename and 'tfrecord' in filename:
dataset_files.append(os.path.join(dir_path, filename))
schema = de.Schema()
schema.add_column('feat_ids', de_type=mstype.int32)
schema.add_column('feat_vals', de_type=mstype.float32)
schema.add_column('label', de_type=mstype.float32)
if rank_size is not None and rank_id is not None:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
schema=schema, num_parallel_workers=8,
num_shards=rank_size, shard_id=rank_id,
shard_equal_rows=True)
else:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
schema=schema, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (
np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
column_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def create_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
data_type=DataType.TFRECORD, line_per_sample=1000,
rank_size=None, rank_id=None):
"""
Get dataset.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
data_type (DataType): The type of dataset which is one of H5, TFRECORE, MINDRECORD (default=TFRECORD).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
if data_type == DataType.MINDRECORD:
return _get_mindrecord_dataset(directory, train_mode, epochs,
batch_size, line_per_sample,
rank_size, rank_id)
if data_type == DataType.TFRECORD:
return _get_tf_dataset(directory, train_mode, epochs, batch_size,
line_per_sample, rank_size=rank_size, rank_id=rank_id)
if rank_size is not None and rank_size > 1:
raise ValueError('Please use mindrecord dataset.')
return _get_h5_dataset(directory, train_mode, epochs, batch_size)

View File

@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_criteo."""
import os
import sys
import argparse
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.common import set_seed
#from mindspore.profiler import Profiler
from src.autodis import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset, DataType
from src.callback import EvalCallBack, LossCallBack
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
parser.add_argument('--eval_file_name', type=str, default="./auc.log",
help='Auc log file path. Default: "./auc.log"')
parser.add_argument('--loss_file_name', type=str, default="./loss.log",
help='Loss log file path. Default: "./loss.log"')
parser.add_argument('--do_eval', type=str, default='True', choices=["True", "False"],
help='Do evaluation or not, only support "True" or "False". Default: "True"')
parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend"],
help='Default: Ascend')
args_opt, _ = parser.parse_known_args()
args_opt.do_eval = args_opt.do_eval == 'True'
rank_size = int(os.environ.get("RANK_SIZE", 1))
set_seed(1)
if __name__ == '__main__':
data_config = DataConfig()
model_config = ModelConfig()
train_config = TrainConfig()
if rank_size > 1:
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
init()
rank_id = int(os.environ.get('RANK_ID'))
else:
print("Unsupported device_target ", args_opt.device_target)
exit()
else:
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
else:
print("Unsupported device_target ", args_opt.device_target)
exit()
rank_size = None
rank_id = None
# Init Profiler
#profiler = Profiler(output_path='./data', is_detail=True, is_show_op_path=False, subgraph='all')
ds_train = create_dataset(args_opt.dataset_path,
train_mode=True,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format),
rank_size=rank_size,
rank_id=rank_id)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
steps_size = ds_train.get_dataset_size()
model_builder = ModelBuilder(ModelConfig, TrainConfig)
train_net, eval_net = model_builder.get_train_eval_net()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
callback_list = [time_callback, loss_callback]
if train_config.save_checkpoint:
if rank_size:
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
keep_checkpoint_max=train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
directory=args_opt.ckpt_path,
config=config_ck)
callback_list.append(ckpt_cb)
if args_opt.do_eval:
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
eval_file_path=args_opt.eval_file_name)
callback_list.append(eval_callback)
model.train(train_config.train_epochs, ds_train, callbacks=callback_list)