forked from mindspore-Ecosystem/mindspore
new add ncf network
This commit is contained in:
parent
ca36c7494a
commit
fe7767aa3c
|
@ -0,0 +1,303 @@
|
||||||
|
# Contents
|
||||||
|
|
||||||
|
- [NCF Description](#NCF-description)
|
||||||
|
- [Model Architecture](#model-architecture)
|
||||||
|
- [Dataset](#dataset)
|
||||||
|
- [Features](#features)
|
||||||
|
- [Mixed Precision](#mixed-precision)
|
||||||
|
- [Environment Requirements](#environment-requirements)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Script Description](#script-description)
|
||||||
|
- [Script and Sample Code](#script-and-sample-code)
|
||||||
|
- [Script Parameters](#script-parameters)
|
||||||
|
- [Training Process](#training-process)
|
||||||
|
- [Training](#training)
|
||||||
|
- [Distributed Training](#distributed-training)
|
||||||
|
- [Evaluation Process](#evaluation-process)
|
||||||
|
- [Evaluation](#evaluation)
|
||||||
|
- [Model Description](#model-description)
|
||||||
|
- [Performance](#performance)
|
||||||
|
- [Evaluation Performance](#evaluation-performance)
|
||||||
|
- [Inference Performance](#evaluation-performance)
|
||||||
|
- [How to use](#how-to-use)
|
||||||
|
- [Inference](#inference)
|
||||||
|
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
||||||
|
- [Transfer Learning](#transfer-learning)
|
||||||
|
- [Description of Random Situation](#description-of-random-situation)
|
||||||
|
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||||
|
|
||||||
|
|
||||||
|
# [NCF Description](#contents)
|
||||||
|
|
||||||
|
NCF is a general framework for collaborative filtering of recommendations in which a neural network architecture is used to model user-item interactions. Unlike traditional models, NCF does not resort to Matrix Factorization (MF) with an inner product on latent features of users and items. It replaces the inner product with a multi-layer perceptron that can learn an arbitrary function from data.
|
||||||
|
|
||||||
|
[Paper](https://arxiv.org/abs/1708.05031): He X, Liao L, Zhang H, et al. Neural collaborative filtering[C]//Proceedings of the 26th international conference on world wide web. 2017: 173-182.
|
||||||
|
|
||||||
|
|
||||||
|
# [Model Architecture](#contents)
|
||||||
|
|
||||||
|
Two instantiations of NCF are Generalized Matrix Factorization (GMF) and Multi-Layer Perceptron (MLP). GMF applies a linear kernel to model the latent feature interactions, and and MLP uses a nonlinear kernel to learn the interaction function from data. NeuMF is a fused model of GMF and MLP to better model the complex user-item interactions, and unifies the strengths of linearity of MF and non-linearity of MLP for modeling the user-item latent structures. NeuMF allows GMF and MLP to learn separate embeddings, and combines the two models by concatenating their last hidden layer. [neumf_model.py](neumf_model.py) defines the architecture details.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# [Dataset](#contents)
|
||||||
|
|
||||||
|
The [MovieLens datasets](http://files.grouplens.org/datasets/movielens/) are used for model training and evaluation. Specifically, we use two datasets: **ml-1m** (short for MovieLens 1 million) and **ml-20m** (short for MovieLens 20 million).
|
||||||
|
|
||||||
|
### ml-1m
|
||||||
|
ml-1m dataset contains 1,000,209 anonymous ratings of approximately 3,706 movies made by 6,040 users who joined MovieLens in 2000. All ratings are contained in the file "ratings.dat" without header row, and are in the following format:
|
||||||
|
```
|
||||||
|
UserID::MovieID::Rating::Timestamp
|
||||||
|
```
|
||||||
|
- UserIDs range between 1 and 6040.
|
||||||
|
- MovieIDs range between 1 and 3952.
|
||||||
|
- Ratings are made on a 5-star scale (whole-star ratings only).
|
||||||
|
|
||||||
|
### ml-20m
|
||||||
|
ml-20m dataset contains 20,000,263 ratings of 26,744 movies by 138493 users. All ratings are contained in the file "ratings.csv". Each line of this file after the header row represents one rating of one movie by one user, and has the following format:
|
||||||
|
```
|
||||||
|
userId,movieId,rating,timestamp
|
||||||
|
```
|
||||||
|
- The lines within this file are ordered first by userId, then, within user, by movieId.
|
||||||
|
- Ratings are made on a 5-star scale, with half-star increments (0.5 stars - 5.0 stars).
|
||||||
|
|
||||||
|
In both datasets, the timestamp is represented in seconds since midnight Coordinated Universal Time (UTC) of January 1, 1970. Each user has at least 20 ratings.
|
||||||
|
|
||||||
|
# [Features](#contents)
|
||||||
|
|
||||||
|
## Mixed Precision
|
||||||
|
|
||||||
|
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
|
||||||
|
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# [Environment Requirements](#contents)
|
||||||
|
|
||||||
|
- Hardware(Ascend/GPU)
|
||||||
|
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||||
|
- Framework
|
||||||
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
|
- For more information, please check the resources below:
|
||||||
|
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||||
|
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# [Quick Start](#contents)
|
||||||
|
|
||||||
|
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
#run data process
|
||||||
|
bash scripts/run_download_dataset.sh
|
||||||
|
|
||||||
|
# run training example
|
||||||
|
bash scripts/run_train.sh
|
||||||
|
|
||||||
|
# run distributed training example
|
||||||
|
sh scripts/run_train.sh rank_table.json
|
||||||
|
|
||||||
|
# run evaluation example
|
||||||
|
sh run_eval.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# [Script Description](#contents)
|
||||||
|
|
||||||
|
## [Script and Sample Code](#contents)
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
├── ModelZoo_NCF_ME
|
||||||
|
├── README.md // descriptions about NCF
|
||||||
|
├── scripts
|
||||||
|
│ ├──run_train.sh // shell script for train
|
||||||
|
│ ├──run_distribute_train.sh // shell script for distribute train
|
||||||
|
│ ├──run_eval.sh // shell script for evaluation
|
||||||
|
│ ├──run_download_dataset.sh // shell script for dataget and process
|
||||||
|
│ ├──run_transfer_ckpt_to_air.sh // shell script for transfer model style
|
||||||
|
├── src
|
||||||
|
│ ├──dataset.py // creating dataset
|
||||||
|
│ ├──ncf.py // ncf architecture
|
||||||
|
│ ├──config.py // parameter configuration
|
||||||
|
│ ├──movielens.py // data download file
|
||||||
|
│ ├──callbacks.py // model loss and eval callback file
|
||||||
|
│ ├──constants.py // the constants of model
|
||||||
|
│ ├──export.py // export checkpoint files into geir/onnx
|
||||||
|
│ ├──metrics.py // the file for auc compute
|
||||||
|
│ ├──stat_utils.py // the file for data process functions
|
||||||
|
├── train.py // training script
|
||||||
|
├── eval.py // evaluation script
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Script Parameters](#contents)
|
||||||
|
|
||||||
|
Parameters for both training and evaluation can be set in config.py.
|
||||||
|
|
||||||
|
- config for NCF, ml-1m dataset
|
||||||
|
|
||||||
|
```python
|
||||||
|
* `--data_path`: This should be set to the same directory given to the data_download data_dir argument.
|
||||||
|
* `--dataset`: The dataset name to be downloaded and preprocessed. By default, it is ml-1m.
|
||||||
|
* `--train_epochs`: Total train epochs.
|
||||||
|
* `--batch_size`: Training batch size.
|
||||||
|
* `--eval_batch_size`: Eval batch size.
|
||||||
|
* `--num_neg`: The Number of negative instances to pair with a positive instance.
|
||||||
|
* `--layers`: The sizes of hidden layers for MLP.
|
||||||
|
* `--num_factors`:The Embedding size of MF model.
|
||||||
|
* `--output_path`:The location of the output file.
|
||||||
|
* `--eval_file_name` : Eval output file.
|
||||||
|
* `--loss_file_name` : Loss output file.
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Training Process](#contents)
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
```python
|
||||||
|
bash scripts/run_train.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
The python command above will run in the background, you can view the results through the file `train.log`. After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# grep "loss is " train.log
|
||||||
|
ds_train.size: 95
|
||||||
|
epoch: 1 step: 95, loss is 0.25074288
|
||||||
|
epoch: 2 step: 95, loss is 0.23324402
|
||||||
|
epoch: 3 step: 95, loss is 0.18286772
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
The model checkpoint will be saved in the current directory.
|
||||||
|
|
||||||
|
## [Evaluation Process](#contents)
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
- evaluation on ml-1m dataset when running on Ascend
|
||||||
|
|
||||||
|
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "checkpoint/ncf-125_390.ckpt".
|
||||||
|
|
||||||
|
```python
|
||||||
|
sh scripts/run_eval.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# grep "accuracy: " eval.log
|
||||||
|
HR:0.6846,NDCG:0.410
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# [Model Description](#contents)
|
||||||
|
## [Performance](#contents)
|
||||||
|
|
||||||
|
### Evaluation Performance
|
||||||
|
|
||||||
|
| Parameters | Ascend |
|
||||||
|
| -------------------------- | ------------------------------------------------------------ |
|
||||||
|
| Model Version | NCF |
|
||||||
|
| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G |
|
||||||
|
| uploaded Date | 10/23/2020 (month/day/year) |
|
||||||
|
| MindSpore Version | 1.0.0 |
|
||||||
|
| Dataset | ml-1m |
|
||||||
|
| Training Parameters | epoch=25, steps=19418, batch_size = 256, lr=0.00382059 |
|
||||||
|
| Optimizer | GradOperation |
|
||||||
|
| Loss Function | Softmax Cross Entropy |
|
||||||
|
| outputs | probability |
|
||||||
|
| Speed | 1pc: 0.575 ms/step |
|
||||||
|
| Total time | 1pc: 5 mins |
|
||||||
|
|
||||||
|
|
||||||
|
### Inference Performance
|
||||||
|
|
||||||
|
| Parameters | Ascend |
|
||||||
|
| ------------------- | --------------------------- |
|
||||||
|
| Model Version | NCF |
|
||||||
|
| Resource | Ascend 910 |
|
||||||
|
| Uploaded Date | 10/23/2020 (month/day/year) |
|
||||||
|
| MindSpore Version | 1.0.0 |
|
||||||
|
| Dataset | ml-1m |
|
||||||
|
| batch_size | 256 |
|
||||||
|
| outputs | probability |
|
||||||
|
| Accuracy | HR:0.6846,NDCG:0.410 |
|
||||||
|
|
||||||
|
## [How to use](#contents)
|
||||||
|
### Inference
|
||||||
|
|
||||||
|
If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example:
|
||||||
|
|
||||||
|
https://www.mindspore.cn/tutorial/zh-CN/master/use/multi_platform_inference.html
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
# Load unseen dataset for inference
|
||||||
|
dataset = dataset.create_dataset(cfg.data_path, 1, False)
|
||||||
|
|
||||||
|
# Define model
|
||||||
|
net = GoogleNet(num_classes=cfg.num_classes)
|
||||||
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
|
||||||
|
cfg.momentum, weight_decay=cfg.weight_decay)
|
||||||
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean',
|
||||||
|
is_grad=False)
|
||||||
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||||
|
|
||||||
|
# Load pre-trained model
|
||||||
|
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
net.set_train(False)
|
||||||
|
|
||||||
|
# Make predictions on the unseen dataset
|
||||||
|
acc = model.eval(dataset)
|
||||||
|
print("accuracy: ", acc)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Continue Training on the Pretrained Model
|
||||||
|
|
||||||
|
```
|
||||||
|
# Load dataset
|
||||||
|
dataset = create_dataset(cfg.data_path, cfg.epoch_size)
|
||||||
|
batch_num = dataset.get_dataset_size()
|
||||||
|
|
||||||
|
# Define model
|
||||||
|
net = GoogleNet(num_classes=cfg.num_classes)
|
||||||
|
# Continue training if set pre_trained to be True
|
||||||
|
if cfg.pre_trained:
|
||||||
|
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
|
||||||
|
steps_per_epoch=batch_num)
|
||||||
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
|
||||||
|
Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
|
||||||
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
|
||||||
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
||||||
|
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
|
||||||
|
|
||||||
|
# Set callbacks
|
||||||
|
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
|
||||||
|
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||||
|
time_cb = TimeMonitor(data_size=batch_num)
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
|
||||||
|
config=config_ck)
|
||||||
|
loss_cb = LossMonitor()
|
||||||
|
|
||||||
|
# Start training
|
||||||
|
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
|
||||||
|
print("train success")
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
# [Description of Random Situation](#contents)
|
||||||
|
|
||||||
|
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
|
||||||
|
|
||||||
|
|
||||||
|
# [ModelZoo Homepage](#contents)
|
||||||
|
|
||||||
|
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,91 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Using for eval the model checkpoint"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore import context, Model
|
||||||
|
|
||||||
|
import src.constants as rconst
|
||||||
|
from src.dataset import create_dataset
|
||||||
|
from src.metrics import NCFMetric
|
||||||
|
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap, PredictWithSigmoid
|
||||||
|
|
||||||
|
from src.config import cfg
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='NCF')
|
||||||
|
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
|
||||||
|
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
|
||||||
|
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
||||||
|
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
|
||||||
|
parser.add_argument("--checkpoint_file_path", type=str, default="./checkpoint/NCF-14_19418.ckpt") # The location of the checkpoint file.
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
def test_eval():
|
||||||
|
"""eval method"""
|
||||||
|
if not os.path.exists(args.output_path):
|
||||||
|
os.makedirs(args.output_path)
|
||||||
|
|
||||||
|
layers = cfg.layers
|
||||||
|
num_factors = cfg.num_factors
|
||||||
|
topk = rconst.TOP_K
|
||||||
|
num_eval_neg = rconst.NUM_EVAL_NEGATIVES
|
||||||
|
|
||||||
|
ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=args.data_path,
|
||||||
|
dataset=args.dataset, train_epochs=0,
|
||||||
|
eval_batch_size=cfg.eval_batch_size)
|
||||||
|
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||||
|
|
||||||
|
ncf_net = NCFModel(num_users=num_eval_users,
|
||||||
|
num_items=num_eval_items,
|
||||||
|
num_factors=num_factors,
|
||||||
|
model_layers=layers,
|
||||||
|
mf_regularization=0,
|
||||||
|
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
||||||
|
mf_dim=16)
|
||||||
|
param_dict = load_checkpoint(args.checkpoint_file_path)
|
||||||
|
load_param_into_net(ncf_net, param_dict)
|
||||||
|
|
||||||
|
loss_net = NetWithLossClass(ncf_net)
|
||||||
|
train_net = TrainStepWrap(loss_net)
|
||||||
|
# train_net.set_train()
|
||||||
|
eval_net = PredictWithSigmoid(ncf_net, topk, num_eval_neg)
|
||||||
|
|
||||||
|
ncf_metric = NCFMetric()
|
||||||
|
model = Model(train_net, eval_network=eval_net, metrics={"ncf": ncf_metric})
|
||||||
|
|
||||||
|
ncf_metric.clear()
|
||||||
|
out = model.eval(ds_eval)
|
||||||
|
|
||||||
|
eval_file_path = os.path.join(args.output_path, args.eval_file_name)
|
||||||
|
eval_file = open(eval_file_path, "a+")
|
||||||
|
eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1]))
|
||||||
|
eval_file.close()
|
||||||
|
print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
devid = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target="Davinci",
|
||||||
|
save_graphs=True,
|
||||||
|
device_id=devid)
|
||||||
|
|
||||||
|
test_eval()
|
|
@ -0,0 +1,47 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH RANK_TABLE_FILE"
|
||||||
|
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
|
||||||
|
|
||||||
|
current_exec_path=$(pwd)
|
||||||
|
echo ${current_exec_path}
|
||||||
|
|
||||||
|
export RANK_SIZE=$1
|
||||||
|
data_path=$2
|
||||||
|
export RANK_TABLE_FILE=$3
|
||||||
|
|
||||||
|
for((i=0;i<=RANK_SIZE;i++));
|
||||||
|
do
|
||||||
|
rm ${current_exec_path}/device_$i/ -rf
|
||||||
|
mkdir ${current_exec_path}/device_$i
|
||||||
|
cd ${current_exec_path}/device_$i || exit
|
||||||
|
export RANK_ID=$i
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
python -u ${current_exec_path}/train.py \
|
||||||
|
--data_path $data_path \
|
||||||
|
--dataset 'ml-1m' \
|
||||||
|
--train_epochs 50 \
|
||||||
|
--output_path './output/' \
|
||||||
|
--eval_file_name 'eval.log' \
|
||||||
|
--loss_file_name 'loss.log' \
|
||||||
|
--checkpoint_path './checkpoint/' \
|
||||||
|
--device_target="Ascend" \
|
||||||
|
--device_id=$i \
|
||||||
|
--is_distributed=1 \
|
||||||
|
>log_$i.log 2>&1 &
|
||||||
|
done
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "sh scripts/run_download_dataset.sh DATASET_PATH"
|
||||||
|
echo "for example: sh scripts/run_download_dataset.sh /dataset_path"
|
||||||
|
|
||||||
|
data_path=$1
|
||||||
|
python ./src/movielens.py --data_path $data_path --dataset 'ml-1m'
|
|
@ -0,0 +1,22 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "sh scripts/run_eval.sh DATASET_PATH CKPT_FILE"
|
||||||
|
echo "for example: sh scripts/run_eval.sh /dataset_path /ncf.ckpt"
|
||||||
|
|
||||||
|
data_path=$1
|
||||||
|
ckpt_file=$2
|
||||||
|
python ./eval.py --data_path $data_path --dataset 'ml-1m' --eval_batch_size 160000 --output_path './output/' --eval_file_name 'eval.log' --checkpoint_file_path $ckpt_file
|
|
@ -0,0 +1,22 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "sh scripts/run_train.sh DATASET_PATH CKPT_FILE"
|
||||||
|
echo "for example: sh scripts/run_train.sh /dataset_path /ncf.ckpt"
|
||||||
|
|
||||||
|
data_path=$1
|
||||||
|
ckpt_file=$2
|
||||||
|
python ./train.py --data_path $data_path --dataset 'ml-1m' --train_epochs 20 --batch_size 256 --output_path './output/' --loss_file_name 'loss.log' --checkpoint_path $ckpt_file
|
|
@ -0,0 +1,22 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "sh scripts/run_transfer_ckpt_to_air.sh DATASET_PATH CKPT_FILE"
|
||||||
|
echo "for example: sh scripts/run_transfer_ckpt_to_air.sh /dataset_path /ncf.ckpt"
|
||||||
|
|
||||||
|
data_path=$1
|
||||||
|
ckpt_file=$2
|
||||||
|
python ./src/export.py --data_path $data_path --dataset 'ml-1m' --eval_batch_size 160000 --output_path './output/' --eval_file_name 'eval.log' --checkpoint_file_path $ckpt_file
|
|
@ -0,0 +1 @@
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Callbacks file"""
|
||||||
|
from mindspore.train.callback import Callback
|
||||||
|
|
||||||
|
class EvalCallBack(Callback):
|
||||||
|
"""
|
||||||
|
Monitor the loss in evaluate.
|
||||||
|
"""
|
||||||
|
def __init__(self, model, eval_dataset, metric, eval_file_path="./eval.log"):
|
||||||
|
super(EvalCallBack, self).__init__()
|
||||||
|
self.model = model
|
||||||
|
self.eval_dataset = eval_dataset
|
||||||
|
self.metric = metric
|
||||||
|
self.metric.clear()
|
||||||
|
self.eval_file_path = eval_file_path
|
||||||
|
self.run_context = None
|
||||||
|
|
||||||
|
def epoch_end(self, run_context):
|
||||||
|
self.run_context = run_context
|
||||||
|
self.metric.clear()
|
||||||
|
out = self.model.eval(self.eval_dataset)
|
||||||
|
|
||||||
|
eval_file = open(self.eval_file_path, "a+")
|
||||||
|
eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1]))
|
||||||
|
eval_file.close()
|
||||||
|
print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1]))
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
network config setting, will be used in main.py
|
||||||
|
"""
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
|
|
||||||
|
cfg = edict({
|
||||||
|
'dataset': 'ml-1m', # Dataset to be trained and evaluated, choice: ["ml-1m", "ml-20m"]
|
||||||
|
|
||||||
|
'data_dir': '../dataset', # The location of the input data.
|
||||||
|
|
||||||
|
'train_epochs': 14, # The number of epochs used to train.
|
||||||
|
|
||||||
|
'batch_size': 256, # Batch size for training and evaluation
|
||||||
|
|
||||||
|
'eval_batch_size': 160000, # The batch size used for evaluation.
|
||||||
|
|
||||||
|
'num_neg': 4, # The Number of negative instances to pair with a positive instance.
|
||||||
|
|
||||||
|
'layers': [64, 32, 16], # The sizes of hidden layers for MLP
|
||||||
|
|
||||||
|
'num_factors': 16 # The Embedding size of MF model.
|
||||||
|
|
||||||
|
})
|
|
@ -0,0 +1,78 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Central location for NCF specific values."""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import src.movielens as movielens
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
# ==============================================================================
|
||||||
|
# == Main Thread Data Processing ===============================================
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# Keys for data shards
|
||||||
|
TRAIN_USER_KEY = "train_{}".format(movielens.USER_COLUMN)
|
||||||
|
TRAIN_ITEM_KEY = "train_{}".format(movielens.ITEM_COLUMN)
|
||||||
|
TRAIN_LABEL_KEY = "train_labels"
|
||||||
|
MASK_START_INDEX = "mask_start_index"
|
||||||
|
VALID_POINT_MASK = "valid_point_mask"
|
||||||
|
EVAL_USER_KEY = "eval_{}".format(movielens.USER_COLUMN)
|
||||||
|
EVAL_ITEM_KEY = "eval_{}".format(movielens.ITEM_COLUMN)
|
||||||
|
|
||||||
|
USER_MAP = "user_map"
|
||||||
|
ITEM_MAP = "item_map"
|
||||||
|
|
||||||
|
USER_DTYPE = np.int32
|
||||||
|
ITEM_DTYPE = np.int32
|
||||||
|
|
||||||
|
# In both datasets, each user has at least 20 ratings.
|
||||||
|
MIN_NUM_RATINGS = 20
|
||||||
|
|
||||||
|
# The number of negative examples attached with a positive example
|
||||||
|
# when performing evaluation.
|
||||||
|
NUM_EVAL_NEGATIVES = 99
|
||||||
|
|
||||||
|
# keys for evaluation metrics
|
||||||
|
TOP_K = 10 # Top-k list for evaluation
|
||||||
|
HR_KEY = "HR"
|
||||||
|
NDCG_KEY = "NDCG"
|
||||||
|
DUPLICATE_MASK = "duplicate_mask"
|
||||||
|
|
||||||
|
# Metric names
|
||||||
|
HR_METRIC_NAME = "HR_METRIC"
|
||||||
|
NDCG_METRIC_NAME = "NDCG_METRIC"
|
||||||
|
|
||||||
|
# Trying to load a cache created in py2 when running in py3 will cause an
|
||||||
|
# error due to differences in unicode handling.
|
||||||
|
RAW_CACHE_FILE = "raw_data_cache_py{}.pickle".format(sys.version_info[0])
|
||||||
|
CACHE_INVALIDATION_SEC = 3600 * 24
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# == Data Generation ===========================================================
|
||||||
|
# ==============================================================================
|
||||||
|
CYCLES_TO_BUFFER = 3 # The number of train cycles worth of data to "run ahead"
|
||||||
|
# of the main training loop.
|
||||||
|
|
||||||
|
# Number of batches to run per epoch when using synthetic data. At high batch
|
||||||
|
# sizes, we run for more batches than with real data, which is good since
|
||||||
|
# running more batches reduces noise when measuring the average batches/second.
|
||||||
|
SYNTHETIC_BATCHES_PER_EPOCH = 2000
|
||||||
|
|
||||||
|
# Only used when StreamingFilesDataset is used.
|
||||||
|
NUM_FILE_SHARDS = 16
|
||||||
|
TRAIN_FOLDER_TEMPLATE = "training_cycle_{}"
|
||||||
|
EVAL_FOLDER = "eval_data"
|
||||||
|
SHARD_TEMPLATE = "shard_{}.tfrecords"
|
|
@ -0,0 +1,592 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Dataset loading, creation and processing"""
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import timeit
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from mindspore.dataset.engine import GeneratorDataset
|
||||||
|
|
||||||
|
import src.constants as rconst
|
||||||
|
import src.movielens as movielens
|
||||||
|
import src.stat_utils as stat_utils
|
||||||
|
|
||||||
|
|
||||||
|
DATASET_TO_NUM_USERS_AND_ITEMS = {
|
||||||
|
"ml-1m": (6040, 3706),
|
||||||
|
"ml-20m": (138493, 26744)
|
||||||
|
}
|
||||||
|
|
||||||
|
_EXPECTED_CACHE_KEYS = (
|
||||||
|
rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, rconst.EVAL_USER_KEY,
|
||||||
|
rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP)
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(data_dir, dataset):
|
||||||
|
"""
|
||||||
|
Load data in .csv format and output structured data.
|
||||||
|
|
||||||
|
This function reads in the raw CSV of positive items, and performs three
|
||||||
|
preprocessing transformations:
|
||||||
|
|
||||||
|
1) Filter out all users who have not rated at least a certain number
|
||||||
|
of items. (Typically 20 items)
|
||||||
|
|
||||||
|
2) Zero index the users and items such that the largest user_id is
|
||||||
|
`num_users - 1` and the largest item_id is `num_items - 1`
|
||||||
|
|
||||||
|
3) Sort the dataframe by user_id, with timestamp as a secondary sort key.
|
||||||
|
This allows the dataframe to be sliced by user in-place, and for the last
|
||||||
|
item to be selected simply by calling the `-1` index of a user's slice.
|
||||||
|
|
||||||
|
While all of these transformations are performed by Pandas (and are therefore
|
||||||
|
single-threaded), they only take ~2 minutes, and the overhead to apply a
|
||||||
|
MapReduce pattern to parallel process the dataset adds significant complexity
|
||||||
|
for no computational gain. For a larger dataset parallelizing this
|
||||||
|
preprocessing could yield speedups. (Also, this preprocessing step is only
|
||||||
|
performed once for an entire run.
|
||||||
|
"""
|
||||||
|
logging.info("Beginning loading data...")
|
||||||
|
|
||||||
|
raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
|
||||||
|
cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)
|
||||||
|
|
||||||
|
valid_cache = os.path.exists(cache_path)
|
||||||
|
if valid_cache:
|
||||||
|
with open(cache_path, 'rb') as f:
|
||||||
|
cached_data = pickle.load(f)
|
||||||
|
|
||||||
|
for key in _EXPECTED_CACHE_KEYS:
|
||||||
|
if key not in cached_data:
|
||||||
|
valid_cache = False
|
||||||
|
|
||||||
|
if not valid_cache:
|
||||||
|
logging.info("Removing stale raw data cache file.")
|
||||||
|
os.remove(cache_path)
|
||||||
|
|
||||||
|
if valid_cache:
|
||||||
|
data = cached_data
|
||||||
|
else:
|
||||||
|
# process data and save to .csv
|
||||||
|
with open(raw_rating_path) as f:
|
||||||
|
df = pd.read_csv(f)
|
||||||
|
|
||||||
|
# Get the info of users who have more than 20 ratings on items
|
||||||
|
grouped = df.groupby(movielens.USER_COLUMN)
|
||||||
|
df = grouped.filter(lambda x: len(x) >= rconst.MIN_NUM_RATINGS)
|
||||||
|
|
||||||
|
original_users = df[movielens.USER_COLUMN].unique()
|
||||||
|
original_items = df[movielens.ITEM_COLUMN].unique()
|
||||||
|
|
||||||
|
# Map the ids of user and item to 0 based index for following processing
|
||||||
|
logging.info("Generating user_map and item_map...")
|
||||||
|
user_map = {user: index for index, user in enumerate(original_users)}
|
||||||
|
item_map = {item: index for index, item in enumerate(original_items)}
|
||||||
|
|
||||||
|
df[movielens.USER_COLUMN] = df[movielens.USER_COLUMN].apply(
|
||||||
|
lambda user: user_map[user])
|
||||||
|
df[movielens.ITEM_COLUMN] = df[movielens.ITEM_COLUMN].apply(
|
||||||
|
lambda item: item_map[item])
|
||||||
|
|
||||||
|
num_users = len(original_users)
|
||||||
|
num_items = len(original_items)
|
||||||
|
|
||||||
|
assert num_users <= np.iinfo(rconst.USER_DTYPE).max
|
||||||
|
assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
|
||||||
|
assert df[movielens.USER_COLUMN].max() == num_users - 1
|
||||||
|
assert df[movielens.ITEM_COLUMN].max() == num_items - 1
|
||||||
|
|
||||||
|
# This sort is used to shard the dataframe by user, and later to select
|
||||||
|
# the last item for a user to be used in validation.
|
||||||
|
logging.info("Sorting by user, timestamp...")
|
||||||
|
|
||||||
|
# This sort is equivalent to
|
||||||
|
# df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
|
||||||
|
# inplace=True)
|
||||||
|
# except that the order of items with the same user and timestamp are
|
||||||
|
# sometimes different. For some reason, this sort results in a better
|
||||||
|
# hit-rate during evaluation, matching the performance of the MLPerf
|
||||||
|
# reference implementation.
|
||||||
|
df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
|
||||||
|
df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
|
||||||
|
inplace=True, kind="mergesort")
|
||||||
|
|
||||||
|
# The dataframe does not reconstruct indices in the sort or filter steps.
|
||||||
|
df = df.reset_index()
|
||||||
|
|
||||||
|
grouped = df.groupby(movielens.USER_COLUMN, group_keys=False)
|
||||||
|
eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])
|
||||||
|
|
||||||
|
data = {
|
||||||
|
rconst.TRAIN_USER_KEY:
|
||||||
|
train_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
|
||||||
|
rconst.TRAIN_ITEM_KEY:
|
||||||
|
train_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
|
||||||
|
rconst.EVAL_USER_KEY:
|
||||||
|
eval_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
|
||||||
|
rconst.EVAL_ITEM_KEY:
|
||||||
|
eval_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
|
||||||
|
rconst.USER_MAP: user_map,
|
||||||
|
rconst.ITEM_MAP: item_map,
|
||||||
|
"create_time": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
logging.info("Writing raw data cache.")
|
||||||
|
with open(cache_path, "wb") as f:
|
||||||
|
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
|
||||||
|
if num_users != len(data[rconst.USER_MAP]):
|
||||||
|
raise ValueError("Expected to find {} users, but found {}".format(
|
||||||
|
num_users, len(data[rconst.USER_MAP])))
|
||||||
|
if num_items != len(data[rconst.ITEM_MAP]):
|
||||||
|
raise ValueError("Expected to find {} items, but found {}".format(
|
||||||
|
num_items, len(data[rconst.ITEM_MAP])))
|
||||||
|
|
||||||
|
return data, num_users, num_items
|
||||||
|
|
||||||
|
|
||||||
|
def construct_lookup_variables(train_pos_users, train_pos_items, num_users):
|
||||||
|
"""Lookup variables"""
|
||||||
|
index_bounds = None
|
||||||
|
sorted_train_pos_items = None
|
||||||
|
|
||||||
|
def index_segment(user):
|
||||||
|
lower, upper = index_bounds[user:user + 2]
|
||||||
|
items = sorted_train_pos_items[lower:upper]
|
||||||
|
|
||||||
|
negatives_since_last_positive = np.concatenate(
|
||||||
|
[items[0][np.newaxis], items[1:] - items[:-1] - 1])
|
||||||
|
|
||||||
|
return np.cumsum(negatives_since_last_positive)
|
||||||
|
|
||||||
|
start_time = timeit.default_timer()
|
||||||
|
inner_bounds = np.argwhere(train_pos_users[1:] -
|
||||||
|
train_pos_users[:-1])[:, 0] + 1
|
||||||
|
(upper_bound,) = train_pos_users.shape
|
||||||
|
index_bounds = np.array([0] + inner_bounds.tolist() + [upper_bound])
|
||||||
|
|
||||||
|
# Later logic will assume that the users are in sequential ascending order.
|
||||||
|
assert np.array_equal(train_pos_users[index_bounds[:-1]], np.arange(num_users))
|
||||||
|
|
||||||
|
sorted_train_pos_items = train_pos_items.copy()
|
||||||
|
|
||||||
|
for i in range(num_users):
|
||||||
|
lower, upper = index_bounds[i:i + 2]
|
||||||
|
sorted_train_pos_items[lower:upper].sort()
|
||||||
|
|
||||||
|
total_negatives = np.concatenate([
|
||||||
|
index_segment(i) for i in range(num_users)])
|
||||||
|
|
||||||
|
logging.info("Negative total vector built. Time: {:.1f} seconds".format(
|
||||||
|
timeit.default_timer() - start_time))
|
||||||
|
|
||||||
|
return total_negatives, index_bounds, sorted_train_pos_items
|
||||||
|
|
||||||
|
|
||||||
|
class NCFDataset:
|
||||||
|
"""
|
||||||
|
A dataset for NCF network.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
pos_users,
|
||||||
|
pos_items,
|
||||||
|
num_users,
|
||||||
|
num_items,
|
||||||
|
batch_size,
|
||||||
|
total_negatives,
|
||||||
|
index_bounds,
|
||||||
|
sorted_train_pos_items,
|
||||||
|
is_training=True):
|
||||||
|
self._pos_users = pos_users
|
||||||
|
self._pos_items = pos_items
|
||||||
|
self._num_users = num_users
|
||||||
|
self._num_items = num_items
|
||||||
|
|
||||||
|
self._batch_size = batch_size
|
||||||
|
|
||||||
|
self._total_negatives = total_negatives
|
||||||
|
self._index_bounds = index_bounds
|
||||||
|
self._sorted_train_pos_items = sorted_train_pos_items
|
||||||
|
|
||||||
|
self._is_training = is_training
|
||||||
|
|
||||||
|
if self._is_training:
|
||||||
|
self._train_pos_count = self._pos_users.shape[0]
|
||||||
|
else:
|
||||||
|
self._eval_users_per_batch = int(
|
||||||
|
batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
|
||||||
|
|
||||||
|
def lookup_negative_items(self, negative_users):
|
||||||
|
"""Lookup negative items"""
|
||||||
|
output = np.zeros(shape=negative_users.shape, dtype=rconst.ITEM_DTYPE) - 1
|
||||||
|
|
||||||
|
left_index = self._index_bounds[negative_users]
|
||||||
|
right_index = self._index_bounds[negative_users + 1] - 1
|
||||||
|
|
||||||
|
num_positives = right_index - left_index + 1
|
||||||
|
num_negatives = self._num_items - num_positives
|
||||||
|
neg_item_choice = stat_utils.very_slightly_biased_randint(num_negatives)
|
||||||
|
|
||||||
|
# Shortcuts:
|
||||||
|
# For points where the negative is greater than or equal to the tally before
|
||||||
|
# the last positive point there is no need to bisect. Instead the item id
|
||||||
|
# corresponding to the negative item choice is simply:
|
||||||
|
# last_postive_index + 1 + (neg_choice - last_negative_tally)
|
||||||
|
# Similarly, if the selection is less than the tally at the first positive
|
||||||
|
# then the item_id is simply the selection.
|
||||||
|
#
|
||||||
|
# Because MovieLens organizes popular movies into low integers (which is
|
||||||
|
# preserved through the preprocessing), the first shortcut is very
|
||||||
|
# efficient, allowing ~60% of samples to bypass the bisection. For the same
|
||||||
|
# reason, the second shortcut is rarely triggered (<0.02%) and is therefore
|
||||||
|
# not worth implementing.
|
||||||
|
use_shortcut = neg_item_choice >= self._total_negatives[right_index]
|
||||||
|
output[use_shortcut] = (
|
||||||
|
self._sorted_train_pos_items[right_index] + 1 +
|
||||||
|
(neg_item_choice - self._total_negatives[right_index])
|
||||||
|
)[use_shortcut]
|
||||||
|
|
||||||
|
if np.all(use_shortcut):
|
||||||
|
# The bisection code is ill-posed when there are no elements.
|
||||||
|
return output
|
||||||
|
|
||||||
|
not_use_shortcut = np.logical_not(use_shortcut)
|
||||||
|
left_index = left_index[not_use_shortcut]
|
||||||
|
right_index = right_index[not_use_shortcut]
|
||||||
|
neg_item_choice = neg_item_choice[not_use_shortcut]
|
||||||
|
|
||||||
|
num_loops = np.max(
|
||||||
|
np.ceil(np.log2(num_positives[not_use_shortcut])).astype(np.int32))
|
||||||
|
|
||||||
|
for _ in range(num_loops):
|
||||||
|
mid_index = (left_index + right_index) // 2
|
||||||
|
right_criteria = self._total_negatives[mid_index] > neg_item_choice
|
||||||
|
left_criteria = np.logical_not(right_criteria)
|
||||||
|
|
||||||
|
right_index[right_criteria] = mid_index[right_criteria]
|
||||||
|
left_index[left_criteria] = mid_index[left_criteria]
|
||||||
|
|
||||||
|
# Expected state after bisection pass:
|
||||||
|
# The right index is the smallest index whose tally is greater than the
|
||||||
|
# negative item choice index.
|
||||||
|
|
||||||
|
assert np.all((right_index - left_index) <= 1)
|
||||||
|
|
||||||
|
output[not_use_shortcut] = (
|
||||||
|
self._sorted_train_pos_items[right_index] - (self._total_negatives[right_index] - neg_item_choice)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.all(output >= 0)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _get_train_item(self, index):
|
||||||
|
"""Get train item"""
|
||||||
|
(mask_start_index,) = index.shape
|
||||||
|
index_mod = np.mod(index, self._train_pos_count)
|
||||||
|
|
||||||
|
# get batch of users
|
||||||
|
users = self._pos_users[index_mod]
|
||||||
|
|
||||||
|
# get batch of items
|
||||||
|
negative_indices = np.greater_equal(index, self._train_pos_count)
|
||||||
|
negative_users = users[negative_indices]
|
||||||
|
negative_items = self.lookup_negative_items(negative_users=negative_users)
|
||||||
|
items = self._pos_items[index_mod]
|
||||||
|
items[negative_indices] = negative_items
|
||||||
|
|
||||||
|
# get batch of labels
|
||||||
|
labels = np.logical_not(negative_indices)
|
||||||
|
|
||||||
|
# pad last partial batch
|
||||||
|
pad_length = self._batch_size - index.shape[0]
|
||||||
|
if pad_length:
|
||||||
|
user_pad = np.arange(pad_length, dtype=users.dtype) % self._num_users
|
||||||
|
item_pad = np.arange(pad_length, dtype=items.dtype) % self._num_items
|
||||||
|
label_pad = np.zeros(shape=(pad_length,), dtype=labels.dtype)
|
||||||
|
users = np.concatenate([users, user_pad])
|
||||||
|
items = np.concatenate([items, item_pad])
|
||||||
|
labels = np.concatenate([labels, label_pad])
|
||||||
|
|
||||||
|
users = np.reshape(users, (self._batch_size, 1)) # (_batch_size, 1), int32
|
||||||
|
items = np.reshape(items, (self._batch_size, 1)) # (_batch_size, 1), int32
|
||||||
|
mask_start_index = np.array(mask_start_index, dtype=np.int32) # (_batch_size, 1), int32
|
||||||
|
valid_pt_mask = np.expand_dims(
|
||||||
|
np.less(np.arange(self._batch_size), mask_start_index), -1).astype(np.float32) # (_batch_size, 1), bool
|
||||||
|
labels = np.reshape(labels, (self._batch_size, 1)).astype(np.int32) # (_batch_size, 1), bool
|
||||||
|
|
||||||
|
return users, items, labels, valid_pt_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _assemble_eval_batch(users, positive_items, negative_items,
|
||||||
|
users_per_batch):
|
||||||
|
"""Construct duplicate_mask and structure data accordingly.
|
||||||
|
|
||||||
|
The positive items should be last so that they lose ties. However, they
|
||||||
|
should not be masked out if the true eval positive happens to be
|
||||||
|
selected as a negative. So instead, the positive is placed in the first
|
||||||
|
position, and then switched with the last element after the duplicate
|
||||||
|
mask has been computed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
users: An array of users in a batch. (should be identical along axis 1)
|
||||||
|
positive_items: An array (batch_size x 1) of positive item indices.
|
||||||
|
negative_items: An array of negative item indices.
|
||||||
|
users_per_batch: How many users should be in the batch. This is passed
|
||||||
|
as an argument so that ncf_test.py can use this method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User, item, and duplicate_mask arrays.
|
||||||
|
"""
|
||||||
|
items = np.concatenate([positive_items, negative_items], axis=1)
|
||||||
|
|
||||||
|
# We pad the users and items here so that the duplicate mask calculation
|
||||||
|
# will include padding. The metric function relies on all padded elements
|
||||||
|
# except the positive being marked as duplicate to mask out padded points.
|
||||||
|
if users.shape[0] < users_per_batch:
|
||||||
|
pad_rows = users_per_batch - users.shape[0]
|
||||||
|
padding = np.zeros(shape=(pad_rows, users.shape[1]), dtype=np.int32)
|
||||||
|
users = np.concatenate([users, padding.astype(users.dtype)], axis=0)
|
||||||
|
items = np.concatenate([items, padding.astype(items.dtype)], axis=0)
|
||||||
|
|
||||||
|
duplicate_mask = stat_utils.mask_duplicates(items, axis=1).astype(np.float32)
|
||||||
|
|
||||||
|
items[:, (0, -1)] = items[:, (-1, 0)]
|
||||||
|
duplicate_mask[:, (0, -1)] = duplicate_mask[:, (-1, 0)]
|
||||||
|
|
||||||
|
assert users.shape == items.shape == duplicate_mask.shape
|
||||||
|
return users, items, duplicate_mask
|
||||||
|
|
||||||
|
def _get_eval_item(self, index):
|
||||||
|
"""Get eval item"""
|
||||||
|
low_index, high_index = index
|
||||||
|
users = np.repeat(self._pos_users[low_index:high_index, np.newaxis],
|
||||||
|
1 + rconst.NUM_EVAL_NEGATIVES, axis=1)
|
||||||
|
positive_items = self._pos_items[low_index:high_index, np.newaxis]
|
||||||
|
negative_items = (self.lookup_negative_items(negative_users=users[:, :-1])
|
||||||
|
.reshape(-1, rconst.NUM_EVAL_NEGATIVES))
|
||||||
|
|
||||||
|
users, items, duplicate_mask = self._assemble_eval_batch(
|
||||||
|
users, positive_items, negative_items, self._eval_users_per_batch)
|
||||||
|
|
||||||
|
users = np.reshape(users.flatten(), (self._batch_size, 1)) # (self._batch_size, 1), int32
|
||||||
|
items = np.reshape(items.flatten(), (self._batch_size, 1)) # (self._batch_size, 1), int32
|
||||||
|
duplicate_mask = np.reshape(duplicate_mask.flatten(), (self._batch_size, 1)) # (self._batch_size, 1), bool
|
||||||
|
|
||||||
|
return users, items, duplicate_mask
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""
|
||||||
|
Get a batch of samples.
|
||||||
|
"""
|
||||||
|
if self._is_training:
|
||||||
|
return self._get_train_item(index)
|
||||||
|
|
||||||
|
return self._get_eval_item(index)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSampler:
|
||||||
|
"""
|
||||||
|
A random sampler for dataset.
|
||||||
|
"""
|
||||||
|
def __init__(self, pos_count, num_train_negatives, batch_size):
|
||||||
|
self.pos_count = pos_count
|
||||||
|
self._num_samples = (1 + num_train_negatives) * self.pos_count
|
||||||
|
self._batch_size = batch_size
|
||||||
|
self._num_batches = math.ceil(self._num_samples / self._batch_size)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
"""
|
||||||
|
Return indices of all batches within an epoch.
|
||||||
|
"""
|
||||||
|
indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
|
||||||
|
|
||||||
|
batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._num_batches)]
|
||||||
|
return iter(batch_indices)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Return length of the sampler, i.e., the number of batches for an epoch.
|
||||||
|
"""
|
||||||
|
return self._num_batches
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedSamplerOfTrain:
|
||||||
|
"""
|
||||||
|
A distributed sampler for dataset.
|
||||||
|
"""
|
||||||
|
def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size):
|
||||||
|
"""
|
||||||
|
Distributed sampler of training dataset.
|
||||||
|
"""
|
||||||
|
self._num_samples = (1 + num_train_negatives) * pos_count
|
||||||
|
self._rank_id = rank_id
|
||||||
|
self._rank_size = rank_size
|
||||||
|
self._batch_size = batch_size
|
||||||
|
|
||||||
|
self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size))
|
||||||
|
self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
|
||||||
|
self._total_num_samples = self._samples_per_rank * self._rank_size
|
||||||
|
def __iter__(self):
|
||||||
|
"""
|
||||||
|
Returns the data after each sampling.
|
||||||
|
"""
|
||||||
|
indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
|
||||||
|
indices = indices.tolist()
|
||||||
|
indices.extend(indices[:self._total_num_samples-len(indices)])
|
||||||
|
indices = indices[self._rank_id:self._total_num_samples:self._rank_size]
|
||||||
|
batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)]
|
||||||
|
|
||||||
|
return iter(np.array(batch_indices))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Returns the length after each sampling.
|
||||||
|
"""
|
||||||
|
return self._batchs_per_rank
|
||||||
|
|
||||||
|
class SequenceSampler:
|
||||||
|
"""
|
||||||
|
A sequence sampler for dataset.
|
||||||
|
"""
|
||||||
|
def __init__(self, eval_batch_size, num_users):
|
||||||
|
self._eval_users_per_batch = int(
|
||||||
|
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
|
||||||
|
self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
|
||||||
|
self._eval_batches_per_epoch = self.count_batches(
|
||||||
|
self._eval_elements_in_epoch, eval_batch_size)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
indices = [(x * self._eval_users_per_batch, (x + 1) * self._eval_users_per_batch)
|
||||||
|
for x in range(self._eval_batches_per_epoch)]
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def count_batches(example_count, batch_size, batches_per_step=1):
|
||||||
|
"""Determine the number of batches, rounding up to fill all devices."""
|
||||||
|
x = (example_count + batch_size - 1) // batch_size
|
||||||
|
return (x + batches_per_step - 1) // batches_per_step * batches_per_step
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Return the length of the sampler, i,e, the number of batches in an epoch.
|
||||||
|
"""
|
||||||
|
return self._eval_batches_per_epoch
|
||||||
|
|
||||||
|
class DistributedSamplerOfEval:
|
||||||
|
"""
|
||||||
|
A distributed sampler for eval dataset.
|
||||||
|
"""
|
||||||
|
def __init__(self, eval_batch_size, num_users, rank_id, rank_size):
|
||||||
|
self._eval_users_per_batch = int(
|
||||||
|
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
|
||||||
|
self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
|
||||||
|
self._eval_batches_per_epoch = self.count_batches(
|
||||||
|
self._eval_elements_in_epoch, eval_batch_size)
|
||||||
|
|
||||||
|
self._rank_id = rank_id
|
||||||
|
self._rank_size = rank_size
|
||||||
|
self._eval_batch_size = eval_batch_size
|
||||||
|
|
||||||
|
self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size))
|
||||||
|
#self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
|
||||||
|
#self._total_num_samples = self._samples_per_rank * self._rank_size
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch)
|
||||||
|
for x in range(self._batchs_per_rank)]
|
||||||
|
|
||||||
|
return iter(np.array(indices))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def count_batches(example_count, batch_size, batches_per_step=1):
|
||||||
|
"""Determine the number of batches, rounding up to fill all devices."""
|
||||||
|
x = (example_count + batch_size - 1) // batch_size
|
||||||
|
return (x + batches_per_step - 1) // batches_per_step * batches_per_step
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._batchs_per_rank
|
||||||
|
|
||||||
|
def parse_eval_batch_size(eval_batch_size):
|
||||||
|
"""
|
||||||
|
Parse eval batch size.
|
||||||
|
"""
|
||||||
|
if eval_batch_size % (1 + rconst.NUM_EVAL_NEGATIVES):
|
||||||
|
raise ValueError("Eval batch size {} is not divisible by {}".format(
|
||||||
|
eval_batch_size, 1 + rconst.NUM_EVAL_NEGATIVES))
|
||||||
|
return eval_batch_size
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', train_epochs=14, batch_size=256,
|
||||||
|
eval_batch_size=160000, num_neg=4, rank_id=None, rank_size=None):
|
||||||
|
"""
|
||||||
|
Create NCF dataset.
|
||||||
|
"""
|
||||||
|
data, num_users, num_items = load_data(data_dir, dataset)
|
||||||
|
|
||||||
|
train_pos_users = data[rconst.TRAIN_USER_KEY]
|
||||||
|
train_pos_items = data[rconst.TRAIN_ITEM_KEY]
|
||||||
|
eval_pos_users = data[rconst.EVAL_USER_KEY]
|
||||||
|
eval_pos_items = data[rconst.EVAL_ITEM_KEY]
|
||||||
|
|
||||||
|
total_negatives, index_bounds, sorted_train_pos_items = \
|
||||||
|
construct_lookup_variables(train_pos_users, train_pos_items, num_users)
|
||||||
|
|
||||||
|
if test_train:
|
||||||
|
print(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives, index_bounds,
|
||||||
|
sorted_train_pos_items)
|
||||||
|
dataset = NCFDataset(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives,
|
||||||
|
index_bounds, sorted_train_pos_items)
|
||||||
|
sampler = RandomSampler(train_pos_users.shape[0], num_neg, batch_size)
|
||||||
|
if rank_id is not None and rank_size is not None:
|
||||||
|
sampler = DistributedSamplerOfTrain(train_pos_users.shape[0], num_neg, batch_size, rank_id, rank_size)
|
||||||
|
if dataset == 'ml-20m':
|
||||||
|
ds = GeneratorDataset(dataset,
|
||||||
|
column_names=[movielens.USER_COLUMN,
|
||||||
|
movielens.ITEM_COLUMN,
|
||||||
|
"labels",
|
||||||
|
rconst.VALID_POINT_MASK],
|
||||||
|
sampler=sampler, num_parallel_workers=32, python_multiprocessing=False)
|
||||||
|
else:
|
||||||
|
ds = GeneratorDataset(dataset,
|
||||||
|
column_names=[movielens.USER_COLUMN,
|
||||||
|
movielens.ITEM_COLUMN,
|
||||||
|
"labels",
|
||||||
|
rconst.VALID_POINT_MASK],
|
||||||
|
sampler=sampler)
|
||||||
|
|
||||||
|
else:
|
||||||
|
eval_batch_size = parse_eval_batch_size(eval_batch_size=eval_batch_size)
|
||||||
|
dataset = NCFDataset(eval_pos_users, eval_pos_items, num_users, num_items,
|
||||||
|
eval_batch_size, total_negatives, index_bounds,
|
||||||
|
sorted_train_pos_items, is_training=False)
|
||||||
|
sampler = SequenceSampler(eval_batch_size, num_users)
|
||||||
|
|
||||||
|
ds = GeneratorDataset(dataset,
|
||||||
|
column_names=[movielens.USER_COLUMN,
|
||||||
|
movielens.ITEM_COLUMN,
|
||||||
|
rconst.DUPLICATE_MASK],
|
||||||
|
sampler=sampler)
|
||||||
|
|
||||||
|
repeat_count = train_epochs if test_train else train_epochs + 1
|
||||||
|
ds = ds.repeat(repeat_count)
|
||||||
|
|
||||||
|
return ds, num_users, num_items
|
|
@ -0,0 +1,112 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Export NCF air file."""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||||
|
from mindspore import Tensor, context, Model
|
||||||
|
|
||||||
|
import constants as rconst
|
||||||
|
from dataset import create_dataset
|
||||||
|
from metrics import NCFMetric
|
||||||
|
from ncf import NCFModel, NetWithLossClass, TrainStepWrap, PredictWithSigmoid
|
||||||
|
|
||||||
|
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def argparse_init():
|
||||||
|
"""Argparse init method"""
|
||||||
|
parser = argparse.ArgumentParser(description='NCF')
|
||||||
|
|
||||||
|
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
|
||||||
|
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
|
||||||
|
parser.add_argument("--eval_batch_size", type=int, default=160000) # The batch size used for evaluation.
|
||||||
|
parser.add_argument("--layers", type=int, default=[64, 32, 16]) # The sizes of hidden layers for MLP
|
||||||
|
parser.add_argument("--num_factors", type=int, default=16) # The Embedding size of MF model.
|
||||||
|
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
||||||
|
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
|
||||||
|
parser.add_argument("--checkpoint_file_path", type=str, default="./checkpoint/NCF.ckpt") # The location of the checkpoint file.
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def export_air_file():
|
||||||
|
""""Export file for eval"""
|
||||||
|
parser = argparse_init()
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
if not os.path.exists(args.output_path):
|
||||||
|
os.makedirs(args.output_path)
|
||||||
|
|
||||||
|
layers = args.layers
|
||||||
|
num_factors = args.num_factors
|
||||||
|
topk = rconst.TOP_K
|
||||||
|
num_eval_neg = rconst.NUM_EVAL_NEGATIVES
|
||||||
|
|
||||||
|
ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=args.data_path,
|
||||||
|
dataset=args.dataset, train_epochs=0,
|
||||||
|
eval_batch_size=args.eval_batch_size)
|
||||||
|
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||||
|
|
||||||
|
ncf_net = NCFModel(num_users=num_eval_users,
|
||||||
|
num_items=num_eval_items,
|
||||||
|
num_factors=num_factors,
|
||||||
|
model_layers=layers,
|
||||||
|
mf_regularization=0,
|
||||||
|
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
||||||
|
mf_dim=16)
|
||||||
|
param_dict = load_checkpoint(args.checkpoint_file_path)
|
||||||
|
load_param_into_net(ncf_net, param_dict)
|
||||||
|
|
||||||
|
loss_net = NetWithLossClass(ncf_net)
|
||||||
|
train_net = TrainStepWrap(loss_net)
|
||||||
|
train_net.set_train()
|
||||||
|
eval_net = PredictWithSigmoid(ncf_net, topk, num_eval_neg)
|
||||||
|
|
||||||
|
ncf_metric = NCFMetric()
|
||||||
|
model = Model(train_net, eval_network=eval_net, metrics={"ncf": ncf_metric})
|
||||||
|
|
||||||
|
ncf_metric.clear()
|
||||||
|
out = model.eval(ds_eval)
|
||||||
|
|
||||||
|
eval_file_path = os.path.join(args.output_path, args.eval_file_name)
|
||||||
|
eval_file = open(eval_file_path, "a+")
|
||||||
|
eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1]))
|
||||||
|
eval_file.close()
|
||||||
|
print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1]))
|
||||||
|
|
||||||
|
param_dict = load_checkpoint(args.checkpoint_file_path)
|
||||||
|
# load the parameter into net
|
||||||
|
load_param_into_net(eval_net, param_dict)
|
||||||
|
|
||||||
|
input_tensor_list = []
|
||||||
|
for data in ds_eval:
|
||||||
|
for j in data:
|
||||||
|
input_tensor_list.append(Tensor(j))
|
||||||
|
print(len(a))
|
||||||
|
break
|
||||||
|
print(input_tensor_list)
|
||||||
|
export(eval_net, *input_tensor_list, file_name='NCF.air', file_format='AIR')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
devid = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target="Davinci",
|
||||||
|
save_graphs=True,
|
||||||
|
device_id=devid)
|
||||||
|
|
||||||
|
export_air_file()
|
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""NCF metrics calculation"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore.nn.metrics import Metric
|
||||||
|
|
||||||
|
class NCFMetric(Metric):
|
||||||
|
"""NCF metrics method"""
|
||||||
|
def __init__(self):
|
||||||
|
super(NCFMetric, self).__init__()
|
||||||
|
self.hr = []
|
||||||
|
self.ndcg = []
|
||||||
|
self.weights = []
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clear the internal evaluation result."""
|
||||||
|
self.hr = []
|
||||||
|
self.ndcg = []
|
||||||
|
self.weights = []
|
||||||
|
|
||||||
|
def hit(self, gt_item, pred_items):
|
||||||
|
if gt_item in pred_items:
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def ndcg_function(self, gt_item, pred_items):
|
||||||
|
if gt_item in pred_items:
|
||||||
|
index = pred_items.index(gt_item)
|
||||||
|
return np.reciprocal(np.log2(index + 2))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def update(self, batch_indices, batch_items, metric_weights):
|
||||||
|
"""Update hr and ndcg"""
|
||||||
|
batch_indices = batch_indices.asnumpy() # (num_user, topk)
|
||||||
|
batch_items = batch_items.asnumpy() # (num_user, 100)
|
||||||
|
metric_weights = metric_weights.asnumpy() # (num_user,)
|
||||||
|
num_user = batch_items.shape[0]
|
||||||
|
for user in range(num_user):
|
||||||
|
if metric_weights[user]:
|
||||||
|
recommends = batch_items[user][batch_indices[user]].tolist()
|
||||||
|
items = batch_items[user].tolist()[-1]
|
||||||
|
self.hr.append(self.hit(items, recommends))
|
||||||
|
self.ndcg.append(self.ndcg_function(items, recommends))
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
return np.mean(self.hr), np.mean(self.ndcg)
|
|
@ -0,0 +1,287 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Download and extract the MovieLens dataset from GroupLens website.
|
||||||
|
|
||||||
|
Download the dataset, and perform basic preprocessing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import six
|
||||||
|
from six.moves import urllib
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from absl import logging
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
ML_1M = "ml-1m"
|
||||||
|
ML_20M = "ml-20m"
|
||||||
|
DATASETS = [ML_1M, ML_20M]
|
||||||
|
|
||||||
|
RATINGS_FILE = "ratings.csv"
|
||||||
|
MOVIES_FILE = "movies.csv"
|
||||||
|
|
||||||
|
# URL to download dataset
|
||||||
|
_DATA_URL = "http://files.grouplens.org/datasets/movielens/"
|
||||||
|
|
||||||
|
GENRE_COLUMN = "genres"
|
||||||
|
ITEM_COLUMN = "item_id" # movies
|
||||||
|
RATING_COLUMN = "rating"
|
||||||
|
TIMESTAMP_COLUMN = "timestamp"
|
||||||
|
TITLE_COLUMN = "titles"
|
||||||
|
USER_COLUMN = "user_id"
|
||||||
|
|
||||||
|
GENRES = [
|
||||||
|
'Action', 'Adventure', 'Animation', "Children", 'Comedy', 'Crime',
|
||||||
|
'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', "IMAX", 'Musical',
|
||||||
|
'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western'
|
||||||
|
]
|
||||||
|
N_GENRE = len(GENRES)
|
||||||
|
|
||||||
|
RATING_COLUMNS = [USER_COLUMN, ITEM_COLUMN, RATING_COLUMN, TIMESTAMP_COLUMN]
|
||||||
|
MOVIE_COLUMNS = [ITEM_COLUMN, TITLE_COLUMN, GENRE_COLUMN]
|
||||||
|
|
||||||
|
# Note: Users are indexed [1, k], not [0, k-1]
|
||||||
|
NUM_USER_IDS = {
|
||||||
|
ML_1M: 6040,
|
||||||
|
ML_20M: 138493,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Note: Movies are indexed [1, k], not [0, k-1]
|
||||||
|
# Both the 1m and 20m datasets use the same movie set.
|
||||||
|
NUM_ITEM_IDS = 3952
|
||||||
|
|
||||||
|
MAX_RATING = 5
|
||||||
|
|
||||||
|
NUM_RATINGS = {
|
||||||
|
ML_1M: 1000209,
|
||||||
|
ML_20M: 20000263
|
||||||
|
}
|
||||||
|
|
||||||
|
arg_parser = argparse.ArgumentParser(description='movielens dataset')
|
||||||
|
arg_parser.add_argument("--data_path", type=str, default="./dataset/")
|
||||||
|
arg_parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"])
|
||||||
|
args, _ = arg_parser.parse_known_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _download_and_clean(dataset, data_dir):
|
||||||
|
"""Download MovieLens dataset in a standard format.
|
||||||
|
|
||||||
|
This function downloads the specified MovieLens format and coerces it into a
|
||||||
|
standard format. The only difference between the ml-1m and ml-20m datasets
|
||||||
|
after this point (other than size, of course) is that the 1m dataset uses
|
||||||
|
whole number ratings while the 20m dataset allows half integer ratings.
|
||||||
|
"""
|
||||||
|
if dataset not in DATASETS:
|
||||||
|
raise ValueError("dataset {} is not in {{{}}}".format(
|
||||||
|
dataset, ",".join(DATASETS)))
|
||||||
|
|
||||||
|
data_subdir = os.path.join(data_dir, dataset)
|
||||||
|
|
||||||
|
expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE]
|
||||||
|
|
||||||
|
tf.io.gfile.makedirs(data_subdir)
|
||||||
|
if set(expected_files).intersection(
|
||||||
|
tf.io.gfile.listdir(data_subdir)) == set(expected_files):
|
||||||
|
logging.info("Dataset {} has already been downloaded".format(dataset))
|
||||||
|
return
|
||||||
|
|
||||||
|
url = "{}{}.zip".format(_DATA_URL, dataset)
|
||||||
|
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
zip_path = os.path.join(temp_dir, "{}.zip".format(dataset))
|
||||||
|
zip_path, _ = urllib.request.urlretrieve(url, zip_path)
|
||||||
|
statinfo = os.stat(zip_path)
|
||||||
|
# A new line to clear the carriage return from download progress
|
||||||
|
# logging.info is not applicable here
|
||||||
|
print()
|
||||||
|
logging.info(
|
||||||
|
"Successfully downloaded {} {} bytes".format(
|
||||||
|
zip_path, statinfo.st_size))
|
||||||
|
|
||||||
|
zipfile.ZipFile(zip_path, "r").extractall(temp_dir)
|
||||||
|
|
||||||
|
if dataset == ML_1M:
|
||||||
|
_regularize_1m_dataset(temp_dir)
|
||||||
|
else:
|
||||||
|
_regularize_20m_dataset(temp_dir)
|
||||||
|
|
||||||
|
for fname in tf.io.gfile.listdir(temp_dir):
|
||||||
|
if not tf.io.gfile.exists(os.path.join(data_subdir, fname)):
|
||||||
|
tf.io.gfile.copy(os.path.join(temp_dir, fname),
|
||||||
|
os.path.join(data_subdir, fname))
|
||||||
|
else:
|
||||||
|
logging.info("Skipping copy of {}, as it already exists in the "
|
||||||
|
"destination folder.".format(fname))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
tf.io.gfile.rmtree(temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_csv(input_path, output_path, names, skip_first, separator=","):
|
||||||
|
"""Transform csv to a regularized format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: The path of the raw csv.
|
||||||
|
output_path: The path of the cleaned csv.
|
||||||
|
names: The csv column names.
|
||||||
|
skip_first: Boolean of whether to skip the first line of the raw csv.
|
||||||
|
separator: Character used to separate fields in the raw csv.
|
||||||
|
"""
|
||||||
|
if six.PY2:
|
||||||
|
names = [six.ensure_text(n, "utf-8") for n in names]
|
||||||
|
|
||||||
|
with tf.io.gfile.GFile(output_path, "wb") as f_out, \
|
||||||
|
tf.io.gfile.GFile(input_path, "rb") as f_in:
|
||||||
|
|
||||||
|
# Write column names to the csv.
|
||||||
|
f_out.write(",".join(names).encode("utf-8"))
|
||||||
|
f_out.write(b"\n")
|
||||||
|
for i, line in enumerate(f_in):
|
||||||
|
if i == 0 and skip_first:
|
||||||
|
continue # ignore existing labels in the csv
|
||||||
|
|
||||||
|
line = six.ensure_text(line, "utf-8", errors="ignore")
|
||||||
|
fields = line.split(separator)
|
||||||
|
if separator != ",":
|
||||||
|
fields = ['"{}"'.format(field) if "," in field else field
|
||||||
|
for field in fields]
|
||||||
|
f_out.write(",".join(fields).encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _regularize_1m_dataset(temp_dir):
|
||||||
|
"""
|
||||||
|
ratings.dat
|
||||||
|
The file has no header row, and each line is in the following format:
|
||||||
|
UserID::MovieID::Rating::Timestamp
|
||||||
|
- UserIDs range from 1 and 6040
|
||||||
|
- MovieIDs range from 1 and 3952
|
||||||
|
- Ratings are made on a 5-star scale (whole-star ratings only)
|
||||||
|
- Timestamp is represented in seconds since midnight Coordinated Universal
|
||||||
|
Time (UTC) of January 1, 1970.
|
||||||
|
- Each user has at least 20 ratings
|
||||||
|
|
||||||
|
movies.dat
|
||||||
|
Each line has the following format:
|
||||||
|
MovieID::Title::Genres
|
||||||
|
- MovieIDs range from 1 and 3952
|
||||||
|
"""
|
||||||
|
working_dir = os.path.join(temp_dir, ML_1M)
|
||||||
|
|
||||||
|
_transform_csv(
|
||||||
|
input_path=os.path.join(working_dir, "ratings.dat"),
|
||||||
|
output_path=os.path.join(temp_dir, RATINGS_FILE),
|
||||||
|
names=RATING_COLUMNS, skip_first=False, separator="::")
|
||||||
|
|
||||||
|
_transform_csv(
|
||||||
|
input_path=os.path.join(working_dir, "movies.dat"),
|
||||||
|
output_path=os.path.join(temp_dir, MOVIES_FILE),
|
||||||
|
names=MOVIE_COLUMNS, skip_first=False, separator="::")
|
||||||
|
|
||||||
|
tf.io.gfile.rmtree(working_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def _regularize_20m_dataset(temp_dir):
|
||||||
|
"""
|
||||||
|
ratings.csv
|
||||||
|
Each line of this file after the header row represents one rating of one
|
||||||
|
movie by one user, and has the following format:
|
||||||
|
userId,movieId,rating,timestamp
|
||||||
|
- The lines within this file are ordered first by userId, then, within user,
|
||||||
|
by movieId.
|
||||||
|
- Ratings are made on a 5-star scale, with half-star increments
|
||||||
|
(0.5 stars - 5.0 stars).
|
||||||
|
- Timestamps represent seconds since midnight Coordinated Universal Time
|
||||||
|
(UTC) of January 1, 1970.
|
||||||
|
- All the users had rated at least 20 movies.
|
||||||
|
|
||||||
|
movies.csv
|
||||||
|
Each line has the following format:
|
||||||
|
MovieID,Title,Genres
|
||||||
|
- MovieIDs range from 1 and 3952
|
||||||
|
"""
|
||||||
|
working_dir = os.path.join(temp_dir, ML_20M)
|
||||||
|
|
||||||
|
_transform_csv(
|
||||||
|
input_path=os.path.join(working_dir, "ratings.csv"),
|
||||||
|
output_path=os.path.join(temp_dir, RATINGS_FILE),
|
||||||
|
names=RATING_COLUMNS, skip_first=True, separator=",")
|
||||||
|
|
||||||
|
_transform_csv(
|
||||||
|
input_path=os.path.join(working_dir, "movies.csv"),
|
||||||
|
output_path=os.path.join(temp_dir, MOVIES_FILE),
|
||||||
|
names=MOVIE_COLUMNS, skip_first=True, separator=",")
|
||||||
|
|
||||||
|
tf.io.gfile.rmtree(working_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def download(dataset, data_dir):
|
||||||
|
if dataset:
|
||||||
|
_download_and_clean(dataset, data_dir)
|
||||||
|
else:
|
||||||
|
_ = [_download_and_clean(d, data_dir) for d in DATASETS]
|
||||||
|
|
||||||
|
|
||||||
|
def ratings_csv_to_dataframe(data_dir, dataset):
|
||||||
|
with tf.io.gfile.GFile(os.path.join(data_dir, dataset, RATINGS_FILE)) as f:
|
||||||
|
return pd.read_csv(f, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def csv_to_joint_dataframe(data_dir, dataset):
|
||||||
|
ratings = ratings_csv_to_dataframe(data_dir, dataset)
|
||||||
|
|
||||||
|
with tf.io.gfile.GFile(os.path.join(data_dir, dataset, MOVIES_FILE)) as f:
|
||||||
|
movies = pd.read_csv(f, encoding="utf-8")
|
||||||
|
|
||||||
|
df = ratings.merge(movies, on=ITEM_COLUMN)
|
||||||
|
df[RATING_COLUMN] = df[RATING_COLUMN].astype(np.float32)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def integerize_genres(dataframe):
|
||||||
|
"""Replace genre string with a binary vector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataframe: a pandas dataframe of movie data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The transformed dataframe.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _map_fn(entry):
|
||||||
|
entry.replace("Children's", "Children") # naming difference.
|
||||||
|
movie_genres = entry.split("|")
|
||||||
|
output = np.zeros((len(GENRES),), dtype=np.int64)
|
||||||
|
for i, genre in enumerate(GENRES):
|
||||||
|
if genre in movie_genres:
|
||||||
|
output[i] = 1
|
||||||
|
return output
|
||||||
|
|
||||||
|
dataframe[GENRE_COLUMN] = dataframe[GENRE_COLUMN].apply(_map_fn)
|
||||||
|
|
||||||
|
return dataframe
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
download(args.dataset, args.data_path)
|
|
@ -0,0 +1,290 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Neural Collaborative Filtering Model"""
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore import Tensor, Parameter, ParameterTuple
|
||||||
|
from mindspore._checkparam import Validator as validator
|
||||||
|
from mindspore.nn.layer.activation import get_activation
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||||
|
|
||||||
|
class DenseLayer(nn.Cell):
|
||||||
|
"""
|
||||||
|
Dense layer definition
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
weight_init='normal',
|
||||||
|
bias_init='zeros',
|
||||||
|
has_bias=True,
|
||||||
|
activation=None):
|
||||||
|
super(DenseLayer, self).__init__()
|
||||||
|
self.in_channels = validator.check_positive_int(in_channels)
|
||||||
|
self.out_channels = validator.check_positive_int(out_channels)
|
||||||
|
self.has_bias = validator.check_bool(has_bias)
|
||||||
|
|
||||||
|
if isinstance(weight_init, Tensor):
|
||||||
|
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
||||||
|
weight_init.shape()[1] != in_channels:
|
||||||
|
raise ValueError("weight_init shape error")
|
||||||
|
|
||||||
|
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||||
|
|
||||||
|
if self.has_bias:
|
||||||
|
if isinstance(bias_init, Tensor):
|
||||||
|
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
|
||||||
|
raise ValueError("bias_init shape error")
|
||||||
|
|
||||||
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
||||||
|
|
||||||
|
self.matmul = P.MatMul(transpose_b=True)
|
||||||
|
self.bias_add = P.BiasAdd()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
|
||||||
|
self.activation = get_activation(activation)
|
||||||
|
self.activation_flag = self.activation is not None
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""
|
||||||
|
dense layer construct method
|
||||||
|
"""
|
||||||
|
x = self.cast(x, mstype.float16)
|
||||||
|
weight = self.cast(self.weight, mstype.float16)
|
||||||
|
bias = self.cast(self.bias, mstype.float16)
|
||||||
|
|
||||||
|
output = self.matmul(x, weight)
|
||||||
|
if self.has_bias:
|
||||||
|
output = self.bias_add(output, bias)
|
||||||
|
if self.activation_flag:
|
||||||
|
output = self.activation(output)
|
||||||
|
output = self.cast(output, mstype.float32)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def extend_repr(self):
|
||||||
|
"""A pretty print for Dense layer."""
|
||||||
|
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
|
||||||
|
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
|
||||||
|
if self.has_bias:
|
||||||
|
str_info = str_info + ', bias={}'.format(self.bias)
|
||||||
|
|
||||||
|
if self.activation_flag:
|
||||||
|
str_info = str_info + ', activation={}'.format(self.activation)
|
||||||
|
|
||||||
|
return str_info
|
||||||
|
|
||||||
|
|
||||||
|
class NCFModel(nn.Cell):
|
||||||
|
"""
|
||||||
|
Class for Neural Collaborative Filtering Model from paper " Neural Collaborative Filtering".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_users,
|
||||||
|
num_items,
|
||||||
|
num_factors,
|
||||||
|
model_layers,
|
||||||
|
mf_regularization,
|
||||||
|
mlp_reg_layers,
|
||||||
|
mf_dim):
|
||||||
|
super(NCFModel, self).__init__()
|
||||||
|
|
||||||
|
self.data_path = ""
|
||||||
|
self.model_path = ""
|
||||||
|
|
||||||
|
self.num_users = num_users
|
||||||
|
self.num_items = num_items
|
||||||
|
self.num_factors = num_factors
|
||||||
|
self.model_layers = model_layers
|
||||||
|
|
||||||
|
self.mf_regularization = mf_regularization
|
||||||
|
self.mlp_reg_layers = mlp_reg_layers
|
||||||
|
|
||||||
|
self.mf_dim = mf_dim
|
||||||
|
|
||||||
|
self.num_layers = len(self.model_layers) # Number of layers in the MLP
|
||||||
|
|
||||||
|
if self.model_layers[0] % 2 != 0:
|
||||||
|
raise ValueError("The first layer size should be multiple of 2!")
|
||||||
|
|
||||||
|
# Initializer for embedding layers
|
||||||
|
self.embedding_initializer = "normal"
|
||||||
|
|
||||||
|
self.embedding_user = nn.Embedding(
|
||||||
|
self.num_users,
|
||||||
|
self.num_factors + self.model_layers[0] // 2,
|
||||||
|
embedding_table=self.embedding_initializer
|
||||||
|
)
|
||||||
|
self.embedding_item = nn.Embedding(
|
||||||
|
self.num_items,
|
||||||
|
self.num_factors + self.model_layers[0] // 2,
|
||||||
|
embedding_table=self.embedding_initializer
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mlp_dense1 = DenseLayer(in_channels=self.model_layers[0],
|
||||||
|
out_channels=self.model_layers[1],
|
||||||
|
activation="relu")
|
||||||
|
self.mlp_dense2 = DenseLayer(in_channels=self.model_layers[1],
|
||||||
|
out_channels=self.model_layers[2],
|
||||||
|
activation="relu")
|
||||||
|
|
||||||
|
# Logit dense layer
|
||||||
|
self.logits_dense = DenseLayer(in_channels=self.model_layers[1],
|
||||||
|
out_channels=1,
|
||||||
|
weight_init="normal",
|
||||||
|
activation=None)
|
||||||
|
|
||||||
|
# ops definition
|
||||||
|
self.mul = P.Mul()
|
||||||
|
self.squeeze = P.Squeeze(axis=1)
|
||||||
|
self.concat = P.Concat(axis=1)
|
||||||
|
|
||||||
|
def construct(self, user_input, item_input):
|
||||||
|
"""
|
||||||
|
NCF construct method.
|
||||||
|
"""
|
||||||
|
# GMF part
|
||||||
|
# embedding_layers
|
||||||
|
embedding_user = self.embedding_user(user_input) # input: (256, 1) output: (256, 1, 16 + 32)
|
||||||
|
embedding_item = self.embedding_item(item_input) # input: (256, 1) output: (256, 1, 16 + 32)
|
||||||
|
|
||||||
|
mf_user_latent = self.squeeze(embedding_user)[:, :self.num_factors] # input: (256, 1, 16 + 32) output: (256, 16)
|
||||||
|
mf_item_latent = self.squeeze(embedding_item)[:, :self.num_factors] # input: (256, 1, 16 + 32) output: (256, 16)
|
||||||
|
|
||||||
|
# MLP part
|
||||||
|
mlp_user_latent = self.squeeze(embedding_user)[:, self.mf_dim:] # input: (256, 1, 16 + 32) output: (256, 32)
|
||||||
|
mlp_item_latent = self.squeeze(embedding_item)[:, self.mf_dim:] # input: (256, 1, 16 + 32) output: (256, 32)
|
||||||
|
|
||||||
|
# Element-wise multiply
|
||||||
|
mf_vector = self.mul(mf_user_latent, mf_item_latent) # input: (256, 16), (256, 16) output: (256, 16)
|
||||||
|
|
||||||
|
# Concatenation of two latent features
|
||||||
|
mlp_vector = self.concat((mlp_user_latent, mlp_item_latent)) # input: (256, 32), (256, 32) output: (256, 64)
|
||||||
|
|
||||||
|
# MLP dense layers
|
||||||
|
mlp_vector = self.mlp_dense1(mlp_vector) # input: (256, 64) output: (256, 32)
|
||||||
|
mlp_vector = self.mlp_dense2(mlp_vector) # input: (256, 32) output: (256, 16)
|
||||||
|
|
||||||
|
# # Concatenate GMF and MLP parts
|
||||||
|
predict_vector = self.concat((mf_vector, mlp_vector)) # input: (256, 16), (256, 16) output: (256, 32)
|
||||||
|
|
||||||
|
# Final prediction layer
|
||||||
|
logits = self.logits_dense(predict_vector) # input: (256, 32) output: (256, 1)
|
||||||
|
|
||||||
|
# Print model topology.
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class NetWithLossClass(nn.Cell):
|
||||||
|
"""
|
||||||
|
NetWithLossClass definition
|
||||||
|
"""
|
||||||
|
def __init__(self, network):
|
||||||
|
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||||
|
#self.loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||||
|
self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||||
|
self.network = network
|
||||||
|
self.reducesum = P.ReduceSum(keep_dims=False)
|
||||||
|
self.mul = P.Mul()
|
||||||
|
self.squeeze = P.Squeeze(axis=1)
|
||||||
|
self.zeroslike = P.ZerosLike()
|
||||||
|
self.concat = P.Concat(axis=1)
|
||||||
|
self.reciprocal = P.Reciprocal()
|
||||||
|
|
||||||
|
def construct(self, batch_users, batch_items, labels, valid_pt_mask):
|
||||||
|
predict = self.network(batch_users, batch_items)
|
||||||
|
predict = self.concat((self.zeroslike(predict), predict))
|
||||||
|
labels = self.squeeze(labels)
|
||||||
|
loss = self.loss(predict, labels)
|
||||||
|
loss = self.mul(loss, self.squeeze(valid_pt_mask))
|
||||||
|
mean_loss = self.mul(self.reducesum(loss), self.reciprocal(self.reducesum(valid_pt_mask)))
|
||||||
|
return mean_loss
|
||||||
|
|
||||||
|
|
||||||
|
class TrainStepWrap(nn.Cell):
|
||||||
|
"""
|
||||||
|
TrainStepWrap definition
|
||||||
|
"""
|
||||||
|
def __init__(self, network, sens=16384.0):
|
||||||
|
super(TrainStepWrap, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.network.set_train()
|
||||||
|
self.network.add_flags(defer_inline=True)
|
||||||
|
self.weights = ParameterTuple(network.trainable_params())
|
||||||
|
self.optimizer = nn.Adam(self.weights,
|
||||||
|
learning_rate=0.00382059,
|
||||||
|
beta1=0.9,
|
||||||
|
beta2=0.999,
|
||||||
|
eps=1e-8,
|
||||||
|
loss_scale=sens)
|
||||||
|
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||||
|
self.sens = sens
|
||||||
|
|
||||||
|
self.reducer_flag = False
|
||||||
|
self.grad_reducer = None
|
||||||
|
parallel_mode = _get_parallel_mode()
|
||||||
|
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||||
|
self.reducer_flag = True
|
||||||
|
if self.reducer_flag:
|
||||||
|
mean = _get_gradients_mean()
|
||||||
|
degree = _get_device_num()
|
||||||
|
self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
|
||||||
|
|
||||||
|
|
||||||
|
def construct(self, batch_users, batch_items, labels, valid_pt_mask):
|
||||||
|
weights = self.weights
|
||||||
|
loss = self.network(batch_users, batch_items, labels, valid_pt_mask)
|
||||||
|
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
|
||||||
|
grads = self.grad(self.network, weights)(batch_users, batch_items, labels, valid_pt_mask, sens)
|
||||||
|
if self.reducer_flag:
|
||||||
|
# apply grad reducer on grads
|
||||||
|
grads = self.grad_reducer(grads)
|
||||||
|
return F.depend(loss, self.optimizer(grads))
|
||||||
|
|
||||||
|
|
||||||
|
class PredictWithSigmoid(nn.Cell):
|
||||||
|
"""
|
||||||
|
Predict definition
|
||||||
|
"""
|
||||||
|
def __init__(self, network, k, num_eval_neg):
|
||||||
|
super(PredictWithSigmoid, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
self.topk = P.TopK(sorted=True)
|
||||||
|
self.squeeze = P.Squeeze()
|
||||||
|
self.k = k
|
||||||
|
self.num_eval_neg = num_eval_neg
|
||||||
|
self.gather = P.GatherV2()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.reducesum = P.ReduceSum(keep_dims=False)
|
||||||
|
self.notequal = P.NotEqual()
|
||||||
|
|
||||||
|
def construct(self, batch_users, batch_items, duplicated_masks):
|
||||||
|
predicts = self.network(batch_users, batch_items) # (bs, 1)
|
||||||
|
predicts = self.reshape(predicts, (-1, self.num_eval_neg + 1)) # (num_user, 100)
|
||||||
|
batch_items = self.reshape(batch_items, (-1, self.num_eval_neg + 1)) # (num_user, 100)
|
||||||
|
duplicated_masks = self.reshape(duplicated_masks, (-1, self.num_eval_neg + 1)) # (num_user, 100)
|
||||||
|
masks_sum = self.reducesum(duplicated_masks, 1)
|
||||||
|
metric_weights = self.notequal(masks_sum, self.num_eval_neg) # (num_user)
|
||||||
|
_, indices = self.topk(predicts, self.k) # (num_user, k)
|
||||||
|
|
||||||
|
return indices, batch_items, metric_weights
|
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Statistics utility functions of NCF."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def random_int32():
|
||||||
|
return np.random.randint(low=0, high=np.iinfo(np.int32).max, dtype=np.int32)
|
||||||
|
|
||||||
|
|
||||||
|
def permutation(args):
|
||||||
|
"""Fork safe permutation function.
|
||||||
|
|
||||||
|
This function can be called within a multiprocessing worker and give
|
||||||
|
appropriately random results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: A size two tuple that will unpacked into the size of the permutation
|
||||||
|
and the random seed. This form is used because starmap is not universally
|
||||||
|
available.
|
||||||
|
|
||||||
|
returns:
|
||||||
|
A NumPy array containing a random permutation.
|
||||||
|
"""
|
||||||
|
x, seed = args
|
||||||
|
|
||||||
|
# If seed is None NumPy will seed randomly.
|
||||||
|
state = np.random.RandomState(seed=seed) # pylint: disable=no-member
|
||||||
|
output = np.arange(x, dtype=np.int32)
|
||||||
|
state.shuffle(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def very_slightly_biased_randint(max_val_vector):
|
||||||
|
sample_dtype = np.uint64
|
||||||
|
out_dtype = max_val_vector.dtype
|
||||||
|
samples = np.random.randint(low=0, high=np.iinfo(sample_dtype).max,
|
||||||
|
size=max_val_vector.shape, dtype=sample_dtype)
|
||||||
|
return np.mod(samples, max_val_vector.astype(sample_dtype)).astype(out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_duplicates(x, axis=1): # type: (np.ndarray, int) -> np.ndarray
|
||||||
|
"""Identify duplicates from sampling with replacement.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A 2D NumPy array of samples
|
||||||
|
axis: The axis along which to de-dupe.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A NumPy array with the same shape as x with one if an element appeared
|
||||||
|
previously along axis 1, else zero.
|
||||||
|
"""
|
||||||
|
if axis != 1:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
x_sort_ind = np.argsort(x, axis=1, kind="mergesort")
|
||||||
|
sorted_x = x[np.arange(x.shape[0])[:, np.newaxis], x_sort_ind]
|
||||||
|
|
||||||
|
# compute the indices needed to map values back to their original position.
|
||||||
|
inv_x_sort_ind = np.argsort(x_sort_ind, axis=1, kind="mergesort")
|
||||||
|
|
||||||
|
# Compute the difference of adjacent sorted elements.
|
||||||
|
diffs = sorted_x[:, :-1] - sorted_x[:, 1:]
|
||||||
|
|
||||||
|
# We are only interested in whether an element is zero. Therefore left padding
|
||||||
|
# with ones to restore the original shape is sufficient.
|
||||||
|
diffs = np.concatenate(
|
||||||
|
[np.ones((diffs.shape[0], 1), dtype=diffs.dtype), diffs], axis=1)
|
||||||
|
|
||||||
|
# Duplicate values will have a difference of zero. By definition the first
|
||||||
|
# element is never a duplicate.
|
||||||
|
return np.where(diffs[np.arange(x.shape[0])[:, np.newaxis],
|
||||||
|
inv_x_sort_ind], 0, 1)
|
|
@ -0,0 +1,105 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Training entry file"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
from mindspore import context, Model
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.communication.management import get_rank, get_group_size, init
|
||||||
|
|
||||||
|
from src.dataset import create_dataset
|
||||||
|
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap
|
||||||
|
|
||||||
|
from config import cfg
|
||||||
|
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='NCF')
|
||||||
|
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
|
||||||
|
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
|
||||||
|
parser.add_argument("--train_epochs", type=int, default=14) # The number of epochs used to train.
|
||||||
|
parser.add_argument("--batch_size", type=int, default=256) # Batch size for training and evaluation
|
||||||
|
parser.add_argument("--num_neg", type=int, default=4) # The Number of negative instances to pair with a positive instance.
|
||||||
|
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
||||||
|
parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file.
|
||||||
|
parser.add_argument("--checkpoint_path", type=str, default="./checkpoint/") # The location of the checkpoint file.
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)')
|
||||||
|
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
|
||||||
|
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
||||||
|
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
def test_train():
|
||||||
|
"""train entry method"""
|
||||||
|
if args.is_distributed:
|
||||||
|
if args.device_target == "Ascend":
|
||||||
|
init()
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
elif args.device_target == "GPU":
|
||||||
|
init()
|
||||||
|
|
||||||
|
args.rank = get_rank()
|
||||||
|
args.group_size = get_group_size()
|
||||||
|
device_num = args.group_size
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
parameter_broadcast=True, gradients_mean=True)
|
||||||
|
else:
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
|
||||||
|
if not os.path.exists(args.output_path):
|
||||||
|
os.makedirs(args.output_path)
|
||||||
|
|
||||||
|
layers = cfg.layers
|
||||||
|
num_factors = cfg.num_factors
|
||||||
|
epochs = args.train_epochs
|
||||||
|
|
||||||
|
ds_train, num_train_users, num_train_items = create_dataset(test_train=True, data_dir=args.data_path,
|
||||||
|
dataset=args.dataset, train_epochs=1,
|
||||||
|
batch_size=args.batch_size, num_neg=args.num_neg)
|
||||||
|
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||||
|
|
||||||
|
ncf_net = NCFModel(num_users=num_train_users,
|
||||||
|
num_items=num_train_items,
|
||||||
|
num_factors=num_factors,
|
||||||
|
model_layers=layers,
|
||||||
|
mf_regularization=0,
|
||||||
|
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
|
||||||
|
mf_dim=16)
|
||||||
|
loss_net = NetWithLossClass(ncf_net)
|
||||||
|
train_net = TrainStepWrap(loss_net)
|
||||||
|
|
||||||
|
train_net.set_train()
|
||||||
|
|
||||||
|
model = Model(train_net)
|
||||||
|
callback = LossMonitor(per_print_times=ds_train.get_dataset_size())
|
||||||
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=(4970845+args.batch_size-1)//(args.batch_size),
|
||||||
|
keep_checkpoint_max=100)
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix='NCF', directory=args.checkpoint_path, config=ckpt_config)
|
||||||
|
model.train(epochs,
|
||||||
|
ds_train,
|
||||||
|
callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb],
|
||||||
|
dataset_sink_mode=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_train()
|
Loading…
Reference in New Issue