!7599 new add ncf network.

Merge pull request !7599 from linqingke/ncf
This commit is contained in:
mindspore-ci-bot 2020-10-26 11:03:54 +08:00 committed by Gitee
commit 6406cd2a3c
18 changed files with 2219 additions and 0 deletions

View File

@ -0,0 +1,303 @@
# Contents
- [NCF Description](#NCF-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [How to use](#how-to-use)
- [Inference](#inference)
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [NCF Description](#contents)
NCF is a general framework for collaborative filtering of recommendations in which a neural network architecture is used to model user-item interactions. Unlike traditional models, NCF does not resort to Matrix Factorization (MF) with an inner product on latent features of users and items. It replaces the inner product with a multi-layer perceptron that can learn an arbitrary function from data.
[Paper](https://arxiv.org/abs/1708.05031): He X, Liao L, Zhang H, et al. Neural collaborative filtering[C]//Proceedings of the 26th international conference on world wide web. 2017: 173-182.
# [Model Architecture](#contents)
Two instantiations of NCF are Generalized Matrix Factorization (GMF) and Multi-Layer Perceptron (MLP). GMF applies a linear kernel to model the latent feature interactions, and and MLP uses a nonlinear kernel to learn the interaction function from data. NeuMF is a fused model of GMF and MLP to better model the complex user-item interactions, and unifies the strengths of linearity of MF and non-linearity of MLP for modeling the user-item latent structures. NeuMF allows GMF and MLP to learn separate embeddings, and combines the two models by concatenating their last hidden layer. [neumf_model.py](neumf_model.py) defines the architecture details.
# [Dataset](#contents)
The [MovieLens datasets](http://files.grouplens.org/datasets/movielens/) are used for model training and evaluation. Specifically, we use two datasets: **ml-1m** (short for MovieLens 1 million) and **ml-20m** (short for MovieLens 20 million).
### ml-1m
ml-1m dataset contains 1,000,209 anonymous ratings of approximately 3,706 movies made by 6,040 users who joined MovieLens in 2000. All ratings are contained in the file "ratings.dat" without header row, and are in the following format:
```
UserID::MovieID::Rating::Timestamp
```
- UserIDs range between 1 and 6040.
- MovieIDs range between 1 and 3952.
- Ratings are made on a 5-star scale (whole-star ratings only).
### ml-20m
ml-20m dataset contains 20,000,263 ratings of 26,744 movies by 138493 users. All ratings are contained in the file "ratings.csv". Each line of this file after the header row represents one rating of one movie by one user, and has the following format:
```
userId,movieId,rating,timestamp
```
- The lines within this file are ordered first by userId, then, within user, by movieId.
- Ratings are made on a 5-star scale, with half-star increments (0.5 stars - 5.0 stars).
In both datasets, the timestamp is represented in seconds since midnight Coordinated Universal Time (UTC) of January 1, 1970. Each user has at least 20 ratings.
# [Features](#contents)
## Mixed Precision
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching reduce precision.
# [Environment Requirements](#contents)
- HardwareAscend/GPU
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```python
#run data process
bash scripts/run_download_dataset.sh
# run training example
bash scripts/run_train.sh
# run distributed training example
sh scripts/run_train.sh rank_table.json
# run evaluation example
sh run_eval.sh
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```
├── ModelZoo_NCF_ME
├── README.md // descriptions about NCF
├── scripts
│ ├──run_train.sh // shell script for train
│ ├──run_distribute_train.sh // shell script for distribute train
│ ├──run_eval.sh // shell script for evaluation
│ ├──run_download_dataset.sh // shell script for dataget and process
│ ├──run_transfer_ckpt_to_air.sh // shell script for transfer model style
├── src
│ ├──dataset.py // creating dataset
│ ├──ncf.py // ncf architecture
│ ├──config.py // parameter configuration
│ ├──movielens.py // data download file
│ ├──callbacks.py // model loss and eval callback file
│ ├──constants.py // the constants of model
│ ├──export.py // export checkpoint files into geir/onnx
│ ├──metrics.py // the file for auc compute
│ ├──stat_utils.py // the file for data process functions
├── train.py // training script
├── eval.py // evaluation script
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py.
- config for NCF, ml-1m dataset
```python
* `--data_path`: This should be set to the same directory given to the data_download data_dir argument.
* `--dataset`: The dataset name to be downloaded and preprocessed. By default, it is ml-1m.
* `--train_epochs`: Total train epochs.
* `--batch_size`: Training batch size.
* `--eval_batch_size`: Eval batch size.
* `--num_neg`: The Number of negative instances to pair with a positive instance.
* `--layers` The sizes of hidden layers for MLP.
* `--num_factors`The Embedding size of MF model.
* `--output_path`The location of the output file.
* `--eval_file_name` : Eval output file.
* `--loss_file_name` : Loss output file.
```
## [Training Process](#contents)
### Training
```python
bash scripts/run_train.sh
```
The python command above will run in the background, you can view the results through the file `train.log`. After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows:
```python
# grep "loss is " train.log
ds_train.size: 95
epoch: 1 step: 95, loss is 0.25074288
epoch: 2 step: 95, loss is 0.23324402
epoch: 3 step: 95, loss is 0.18286772
...
```
The model checkpoint will be saved in the current directory.
## [Evaluation Process](#contents)
### Evaluation
- evaluation on ml-1m dataset when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "checkpoint/ncf-125_390.ckpt".
```python
sh scripts/run_eval.sh
```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
```python
# grep "accuracy: " eval.log
HR:0.6846,NDCG:0.410
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | NCF |
| Resource | Ascend 910 CPU 2.60GHz56coresMemory314G |
| uploaded Date | 10/23/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | ml-1m |
| Training Parameters | epoch=25, steps=19418, batch_size = 256, lr=0.00382059 |
| Optimizer | GradOperation |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Speed | 1pc: 0.575 ms/step |
| Total time | 1pc: 5 mins |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | NCF |
| Resource | Ascend 910 |
| Uploaded Date | 10/23/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | ml-1m |
| batch_size | 256 |
| outputs | probability |
| Accuracy | HR:0.6846,NDCG:0.410 |
## [How to use](#contents)
### Inference
If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example:
https://www.mindspore.cn/tutorial/zh-CN/master/use/multi_platform_inference.html
```
# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)
# Define model
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean',
is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
### Continue Training on the Pretrained Model
```
# Load dataset
dataset = create_dataset(cfg.data_path, cfg.epoch_size)
batch_num = dataset.get_dataset_size()
# Define model
net = GoogleNet(num_classes=cfg.num_classes)
# Continue training if set pre_trained to be True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
config=config_ck)
loss_cb = LossMonitor()
# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,91 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Using for eval the model checkpoint"""
import os
import argparse
from absl import logging
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context, Model
import src.constants as rconst
from src.dataset import create_dataset
from src.metrics import NCFMetric
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap, PredictWithSigmoid
from src.config import cfg
logging.set_verbosity(logging.INFO)
parser = argparse.ArgumentParser(description='NCF')
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
parser.add_argument("--checkpoint_file_path", type=str, default="./checkpoint/NCF-14_19418.ckpt") # The location of the checkpoint file.
args, _ = parser.parse_known_args()
def test_eval():
"""eval method"""
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
layers = cfg.layers
num_factors = cfg.num_factors
topk = rconst.TOP_K
num_eval_neg = rconst.NUM_EVAL_NEGATIVES
ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=args.data_path,
dataset=args.dataset, train_epochs=0,
eval_batch_size=cfg.eval_batch_size)
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
ncf_net = NCFModel(num_users=num_eval_users,
num_items=num_eval_items,
num_factors=num_factors,
model_layers=layers,
mf_regularization=0,
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
mf_dim=16)
param_dict = load_checkpoint(args.checkpoint_file_path)
load_param_into_net(ncf_net, param_dict)
loss_net = NetWithLossClass(ncf_net)
train_net = TrainStepWrap(loss_net)
# train_net.set_train()
eval_net = PredictWithSigmoid(ncf_net, topk, num_eval_neg)
ncf_metric = NCFMetric()
model = Model(train_net, eval_network=eval_net, metrics={"ncf": ncf_metric})
ncf_metric.clear()
out = model.eval(ds_eval)
eval_file_path = os.path.join(args.output_path, args.eval_file_name)
eval_file = open(eval_file_path, "a+")
eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1]))
eval_file.close()
print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1]))
if __name__ == '__main__':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Davinci",
save_graphs=True,
device_id=devid)
test_eval()

View File

@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH RANK_TABLE_FILE"
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_SIZE=$1
data_path=$2
export RANK_TABLE_FILE=$3
for((i=0;i<=RANK_SIZE;i++));
do
rm ${current_exec_path}/device_$i/ -rf
mkdir ${current_exec_path}/device_$i
cd ${current_exec_path}/device_$i || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -u ${current_exec_path}/train.py \
--data_path $data_path \
--dataset 'ml-1m' \
--train_epochs 50 \
--output_path './output/' \
--eval_file_name 'eval.log' \
--loss_file_name 'loss.log' \
--checkpoint_path './checkpoint/' \
--device_target="Ascend" \
--device_id=$i \
--is_distributed=1 \
>log_$i.log 2>&1 &
done

View File

@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_download_dataset.sh DATASET_PATH"
echo "for example: sh scripts/run_download_dataset.sh /dataset_path"
data_path=$1
python ./src/movielens.py --data_path $data_path --dataset 'ml-1m'

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_eval.sh DATASET_PATH CKPT_FILE"
echo "for example: sh scripts/run_eval.sh /dataset_path /ncf.ckpt"
data_path=$1
ckpt_file=$2
python ./eval.py --data_path $data_path --dataset 'ml-1m' --eval_batch_size 160000 --output_path './output/' --eval_file_name 'eval.log' --checkpoint_file_path $ckpt_file

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_train.sh DATASET_PATH CKPT_FILE"
echo "for example: sh scripts/run_train.sh /dataset_path /ncf.ckpt"
data_path=$1
ckpt_file=$2
python ./train.py --data_path $data_path --dataset 'ml-1m' --train_epochs 20 --batch_size 256 --output_path './output/' --loss_file_name 'loss.log' --checkpoint_path $ckpt_file

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_transfer_ckpt_to_air.sh DATASET_PATH CKPT_FILE"
echo "for example: sh scripts/run_transfer_ckpt_to_air.sh /dataset_path /ncf.ckpt"
data_path=$1
ckpt_file=$2
python ./src/export.py --data_path $data_path --dataset 'ml-1m' --eval_batch_size 160000 --output_path './output/' --eval_file_name 'eval.log' --checkpoint_file_path $ckpt_file

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callbacks file"""
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Monitor the loss in evaluate.
"""
def __init__(self, model, eval_dataset, metric, eval_file_path="./eval.log"):
super(EvalCallBack, self).__init__()
self.model = model
self.eval_dataset = eval_dataset
self.metric = metric
self.metric.clear()
self.eval_file_path = eval_file_path
self.run_context = None
def epoch_end(self, run_context):
self.run_context = run_context
self.metric.clear()
out = self.model.eval(self.eval_dataset)
eval_file = open(self.eval_file_path, "a+")
eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1]))
eval_file.close()
print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1]))

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
cfg = edict({
'dataset': 'ml-1m', # Dataset to be trained and evaluated, choice: ["ml-1m", "ml-20m"]
'data_dir': '../dataset', # The location of the input data.
'train_epochs': 14, # The number of epochs used to train.
'batch_size': 256, # Batch size for training and evaluation
'eval_batch_size': 160000, # The batch size used for evaluation.
'num_neg': 4, # The Number of negative instances to pair with a positive instance.
'layers': [64, 32, 16], # The sizes of hidden layers for MLP
'num_factors': 16 # The Embedding size of MF model.
})

View File

@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Central location for NCF specific values."""
import sys
import os
import numpy as np
import src.movielens as movielens
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# ==============================================================================
# == Main Thread Data Processing ===============================================
# ==============================================================================
# Keys for data shards
TRAIN_USER_KEY = "train_{}".format(movielens.USER_COLUMN)
TRAIN_ITEM_KEY = "train_{}".format(movielens.ITEM_COLUMN)
TRAIN_LABEL_KEY = "train_labels"
MASK_START_INDEX = "mask_start_index"
VALID_POINT_MASK = "valid_point_mask"
EVAL_USER_KEY = "eval_{}".format(movielens.USER_COLUMN)
EVAL_ITEM_KEY = "eval_{}".format(movielens.ITEM_COLUMN)
USER_MAP = "user_map"
ITEM_MAP = "item_map"
USER_DTYPE = np.int32
ITEM_DTYPE = np.int32
# In both datasets, each user has at least 20 ratings.
MIN_NUM_RATINGS = 20
# The number of negative examples attached with a positive example
# when performing evaluation.
NUM_EVAL_NEGATIVES = 99
# keys for evaluation metrics
TOP_K = 10 # Top-k list for evaluation
HR_KEY = "HR"
NDCG_KEY = "NDCG"
DUPLICATE_MASK = "duplicate_mask"
# Metric names
HR_METRIC_NAME = "HR_METRIC"
NDCG_METRIC_NAME = "NDCG_METRIC"
# Trying to load a cache created in py2 when running in py3 will cause an
# error due to differences in unicode handling.
RAW_CACHE_FILE = "raw_data_cache_py{}.pickle".format(sys.version_info[0])
CACHE_INVALIDATION_SEC = 3600 * 24
# ==============================================================================
# == Data Generation ===========================================================
# ==============================================================================
CYCLES_TO_BUFFER = 3 # The number of train cycles worth of data to "run ahead"
# of the main training loop.
# Number of batches to run per epoch when using synthetic data. At high batch
# sizes, we run for more batches than with real data, which is good since
# running more batches reduces noise when measuring the average batches/second.
SYNTHETIC_BATCHES_PER_EPOCH = 2000
# Only used when StreamingFilesDataset is used.
NUM_FILE_SHARDS = 16
TRAIN_FOLDER_TEMPLATE = "training_cycle_{}"
EVAL_FOLDER = "eval_data"
SHARD_TEMPLATE = "shard_{}.tfrecords"

View File

@ -0,0 +1,592 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset loading, creation and processing"""
import logging
import math
import os
import time
import timeit
import pickle
import numpy as np
import pandas as pd
from mindspore.dataset.engine import GeneratorDataset
import src.constants as rconst
import src.movielens as movielens
import src.stat_utils as stat_utils
DATASET_TO_NUM_USERS_AND_ITEMS = {
"ml-1m": (6040, 3706),
"ml-20m": (138493, 26744)
}
_EXPECTED_CACHE_KEYS = (
rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, rconst.EVAL_USER_KEY,
rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP)
def load_data(data_dir, dataset):
"""
Load data in .csv format and output structured data.
This function reads in the raw CSV of positive items, and performs three
preprocessing transformations:
1) Filter out all users who have not rated at least a certain number
of items. (Typically 20 items)
2) Zero index the users and items such that the largest user_id is
`num_users - 1` and the largest item_id is `num_items - 1`
3) Sort the dataframe by user_id, with timestamp as a secondary sort key.
This allows the dataframe to be sliced by user in-place, and for the last
item to be selected simply by calling the `-1` index of a user's slice.
While all of these transformations are performed by Pandas (and are therefore
single-threaded), they only take ~2 minutes, and the overhead to apply a
MapReduce pattern to parallel process the dataset adds significant complexity
for no computational gain. For a larger dataset parallelizing this
preprocessing could yield speedups. (Also, this preprocessing step is only
performed once for an entire run.
"""
logging.info("Beginning loading data...")
raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)
valid_cache = os.path.exists(cache_path)
if valid_cache:
with open(cache_path, 'rb') as f:
cached_data = pickle.load(f)
for key in _EXPECTED_CACHE_KEYS:
if key not in cached_data:
valid_cache = False
if not valid_cache:
logging.info("Removing stale raw data cache file.")
os.remove(cache_path)
if valid_cache:
data = cached_data
else:
# process data and save to .csv
with open(raw_rating_path) as f:
df = pd.read_csv(f)
# Get the info of users who have more than 20 ratings on items
grouped = df.groupby(movielens.USER_COLUMN)
df = grouped.filter(lambda x: len(x) >= rconst.MIN_NUM_RATINGS)
original_users = df[movielens.USER_COLUMN].unique()
original_items = df[movielens.ITEM_COLUMN].unique()
# Map the ids of user and item to 0 based index for following processing
logging.info("Generating user_map and item_map...")
user_map = {user: index for index, user in enumerate(original_users)}
item_map = {item: index for index, item in enumerate(original_items)}
df[movielens.USER_COLUMN] = df[movielens.USER_COLUMN].apply(
lambda user: user_map[user])
df[movielens.ITEM_COLUMN] = df[movielens.ITEM_COLUMN].apply(
lambda item: item_map[item])
num_users = len(original_users)
num_items = len(original_items)
assert num_users <= np.iinfo(rconst.USER_DTYPE).max
assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
assert df[movielens.USER_COLUMN].max() == num_users - 1
assert df[movielens.ITEM_COLUMN].max() == num_items - 1
# This sort is used to shard the dataframe by user, and later to select
# the last item for a user to be used in validation.
logging.info("Sorting by user, timestamp...")
# This sort is equivalent to
# df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
# inplace=True)
# except that the order of items with the same user and timestamp are
# sometimes different. For some reason, this sort results in a better
# hit-rate during evaluation, matching the performance of the MLPerf
# reference implementation.
df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
inplace=True, kind="mergesort")
# The dataframe does not reconstruct indices in the sort or filter steps.
df = df.reset_index()
grouped = df.groupby(movielens.USER_COLUMN, group_keys=False)
eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])
data = {
rconst.TRAIN_USER_KEY:
train_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
rconst.TRAIN_ITEM_KEY:
train_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
rconst.EVAL_USER_KEY:
eval_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
rconst.EVAL_ITEM_KEY:
eval_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
rconst.USER_MAP: user_map,
rconst.ITEM_MAP: item_map,
"create_time": time.time(),
}
logging.info("Writing raw data cache.")
with open(cache_path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
if num_users != len(data[rconst.USER_MAP]):
raise ValueError("Expected to find {} users, but found {}".format(
num_users, len(data[rconst.USER_MAP])))
if num_items != len(data[rconst.ITEM_MAP]):
raise ValueError("Expected to find {} items, but found {}".format(
num_items, len(data[rconst.ITEM_MAP])))
return data, num_users, num_items
def construct_lookup_variables(train_pos_users, train_pos_items, num_users):
"""Lookup variables"""
index_bounds = None
sorted_train_pos_items = None
def index_segment(user):
lower, upper = index_bounds[user:user + 2]
items = sorted_train_pos_items[lower:upper]
negatives_since_last_positive = np.concatenate(
[items[0][np.newaxis], items[1:] - items[:-1] - 1])
return np.cumsum(negatives_since_last_positive)
start_time = timeit.default_timer()
inner_bounds = np.argwhere(train_pos_users[1:] -
train_pos_users[:-1])[:, 0] + 1
(upper_bound,) = train_pos_users.shape
index_bounds = np.array([0] + inner_bounds.tolist() + [upper_bound])
# Later logic will assume that the users are in sequential ascending order.
assert np.array_equal(train_pos_users[index_bounds[:-1]], np.arange(num_users))
sorted_train_pos_items = train_pos_items.copy()
for i in range(num_users):
lower, upper = index_bounds[i:i + 2]
sorted_train_pos_items[lower:upper].sort()
total_negatives = np.concatenate([
index_segment(i) for i in range(num_users)])
logging.info("Negative total vector built. Time: {:.1f} seconds".format(
timeit.default_timer() - start_time))
return total_negatives, index_bounds, sorted_train_pos_items
class NCFDataset:
"""
A dataset for NCF network.
"""
def __init__(self,
pos_users,
pos_items,
num_users,
num_items,
batch_size,
total_negatives,
index_bounds,
sorted_train_pos_items,
is_training=True):
self._pos_users = pos_users
self._pos_items = pos_items
self._num_users = num_users
self._num_items = num_items
self._batch_size = batch_size
self._total_negatives = total_negatives
self._index_bounds = index_bounds
self._sorted_train_pos_items = sorted_train_pos_items
self._is_training = is_training
if self._is_training:
self._train_pos_count = self._pos_users.shape[0]
else:
self._eval_users_per_batch = int(
batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
def lookup_negative_items(self, negative_users):
"""Lookup negative items"""
output = np.zeros(shape=negative_users.shape, dtype=rconst.ITEM_DTYPE) - 1
left_index = self._index_bounds[negative_users]
right_index = self._index_bounds[negative_users + 1] - 1
num_positives = right_index - left_index + 1
num_negatives = self._num_items - num_positives
neg_item_choice = stat_utils.very_slightly_biased_randint(num_negatives)
# Shortcuts:
# For points where the negative is greater than or equal to the tally before
# the last positive point there is no need to bisect. Instead the item id
# corresponding to the negative item choice is simply:
# last_postive_index + 1 + (neg_choice - last_negative_tally)
# Similarly, if the selection is less than the tally at the first positive
# then the item_id is simply the selection.
#
# Because MovieLens organizes popular movies into low integers (which is
# preserved through the preprocessing), the first shortcut is very
# efficient, allowing ~60% of samples to bypass the bisection. For the same
# reason, the second shortcut is rarely triggered (<0.02%) and is therefore
# not worth implementing.
use_shortcut = neg_item_choice >= self._total_negatives[right_index]
output[use_shortcut] = (
self._sorted_train_pos_items[right_index] + 1 +
(neg_item_choice - self._total_negatives[right_index])
)[use_shortcut]
if np.all(use_shortcut):
# The bisection code is ill-posed when there are no elements.
return output
not_use_shortcut = np.logical_not(use_shortcut)
left_index = left_index[not_use_shortcut]
right_index = right_index[not_use_shortcut]
neg_item_choice = neg_item_choice[not_use_shortcut]
num_loops = np.max(
np.ceil(np.log2(num_positives[not_use_shortcut])).astype(np.int32))
for _ in range(num_loops):
mid_index = (left_index + right_index) // 2
right_criteria = self._total_negatives[mid_index] > neg_item_choice
left_criteria = np.logical_not(right_criteria)
right_index[right_criteria] = mid_index[right_criteria]
left_index[left_criteria] = mid_index[left_criteria]
# Expected state after bisection pass:
# The right index is the smallest index whose tally is greater than the
# negative item choice index.
assert np.all((right_index - left_index) <= 1)
output[not_use_shortcut] = (
self._sorted_train_pos_items[right_index] - (self._total_negatives[right_index] - neg_item_choice)
)
assert np.all(output >= 0)
return output
def _get_train_item(self, index):
"""Get train item"""
(mask_start_index,) = index.shape
index_mod = np.mod(index, self._train_pos_count)
# get batch of users
users = self._pos_users[index_mod]
# get batch of items
negative_indices = np.greater_equal(index, self._train_pos_count)
negative_users = users[negative_indices]
negative_items = self.lookup_negative_items(negative_users=negative_users)
items = self._pos_items[index_mod]
items[negative_indices] = negative_items
# get batch of labels
labels = np.logical_not(negative_indices)
# pad last partial batch
pad_length = self._batch_size - index.shape[0]
if pad_length:
user_pad = np.arange(pad_length, dtype=users.dtype) % self._num_users
item_pad = np.arange(pad_length, dtype=items.dtype) % self._num_items
label_pad = np.zeros(shape=(pad_length,), dtype=labels.dtype)
users = np.concatenate([users, user_pad])
items = np.concatenate([items, item_pad])
labels = np.concatenate([labels, label_pad])
users = np.reshape(users, (self._batch_size, 1)) # (_batch_size, 1), int32
items = np.reshape(items, (self._batch_size, 1)) # (_batch_size, 1), int32
mask_start_index = np.array(mask_start_index, dtype=np.int32) # (_batch_size, 1), int32
valid_pt_mask = np.expand_dims(
np.less(np.arange(self._batch_size), mask_start_index), -1).astype(np.float32) # (_batch_size, 1), bool
labels = np.reshape(labels, (self._batch_size, 1)).astype(np.int32) # (_batch_size, 1), bool
return users, items, labels, valid_pt_mask
@staticmethod
def _assemble_eval_batch(users, positive_items, negative_items,
users_per_batch):
"""Construct duplicate_mask and structure data accordingly.
The positive items should be last so that they lose ties. However, they
should not be masked out if the true eval positive happens to be
selected as a negative. So instead, the positive is placed in the first
position, and then switched with the last element after the duplicate
mask has been computed.
Args:
users: An array of users in a batch. (should be identical along axis 1)
positive_items: An array (batch_size x 1) of positive item indices.
negative_items: An array of negative item indices.
users_per_batch: How many users should be in the batch. This is passed
as an argument so that ncf_test.py can use this method.
Returns:
User, item, and duplicate_mask arrays.
"""
items = np.concatenate([positive_items, negative_items], axis=1)
# We pad the users and items here so that the duplicate mask calculation
# will include padding. The metric function relies on all padded elements
# except the positive being marked as duplicate to mask out padded points.
if users.shape[0] < users_per_batch:
pad_rows = users_per_batch - users.shape[0]
padding = np.zeros(shape=(pad_rows, users.shape[1]), dtype=np.int32)
users = np.concatenate([users, padding.astype(users.dtype)], axis=0)
items = np.concatenate([items, padding.astype(items.dtype)], axis=0)
duplicate_mask = stat_utils.mask_duplicates(items, axis=1).astype(np.float32)
items[:, (0, -1)] = items[:, (-1, 0)]
duplicate_mask[:, (0, -1)] = duplicate_mask[:, (-1, 0)]
assert users.shape == items.shape == duplicate_mask.shape
return users, items, duplicate_mask
def _get_eval_item(self, index):
"""Get eval item"""
low_index, high_index = index
users = np.repeat(self._pos_users[low_index:high_index, np.newaxis],
1 + rconst.NUM_EVAL_NEGATIVES, axis=1)
positive_items = self._pos_items[low_index:high_index, np.newaxis]
negative_items = (self.lookup_negative_items(negative_users=users[:, :-1])
.reshape(-1, rconst.NUM_EVAL_NEGATIVES))
users, items, duplicate_mask = self._assemble_eval_batch(
users, positive_items, negative_items, self._eval_users_per_batch)
users = np.reshape(users.flatten(), (self._batch_size, 1)) # (self._batch_size, 1), int32
items = np.reshape(items.flatten(), (self._batch_size, 1)) # (self._batch_size, 1), int32
duplicate_mask = np.reshape(duplicate_mask.flatten(), (self._batch_size, 1)) # (self._batch_size, 1), bool
return users, items, duplicate_mask
def __getitem__(self, index):
"""
Get a batch of samples.
"""
if self._is_training:
return self._get_train_item(index)
return self._get_eval_item(index)
class RandomSampler:
"""
A random sampler for dataset.
"""
def __init__(self, pos_count, num_train_negatives, batch_size):
self.pos_count = pos_count
self._num_samples = (1 + num_train_negatives) * self.pos_count
self._batch_size = batch_size
self._num_batches = math.ceil(self._num_samples / self._batch_size)
def __iter__(self):
"""
Return indices of all batches within an epoch.
"""
indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._num_batches)]
return iter(batch_indices)
def __len__(self):
"""
Return length of the sampler, i.e., the number of batches for an epoch.
"""
return self._num_batches
class DistributedSamplerOfTrain:
"""
A distributed sampler for dataset.
"""
def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size):
"""
Distributed sampler of training dataset.
"""
self._num_samples = (1 + num_train_negatives) * pos_count
self._rank_id = rank_id
self._rank_size = rank_size
self._batch_size = batch_size
self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size))
self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
self._total_num_samples = self._samples_per_rank * self._rank_size
def __iter__(self):
"""
Returns the data after each sampling.
"""
indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
indices = indices.tolist()
indices.extend(indices[:self._total_num_samples-len(indices)])
indices = indices[self._rank_id:self._total_num_samples:self._rank_size]
batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)]
return iter(np.array(batch_indices))
def __len__(self):
"""
Returns the length after each sampling.
"""
return self._batchs_per_rank
class SequenceSampler:
"""
A sequence sampler for dataset.
"""
def __init__(self, eval_batch_size, num_users):
self._eval_users_per_batch = int(
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
self._eval_batches_per_epoch = self.count_batches(
self._eval_elements_in_epoch, eval_batch_size)
def __iter__(self):
indices = [(x * self._eval_users_per_batch, (x + 1) * self._eval_users_per_batch)
for x in range(self._eval_batches_per_epoch)]
return iter(indices)
@staticmethod
def count_batches(example_count, batch_size, batches_per_step=1):
"""Determine the number of batches, rounding up to fill all devices."""
x = (example_count + batch_size - 1) // batch_size
return (x + batches_per_step - 1) // batches_per_step * batches_per_step
def __len__(self):
"""
Return the length of the sampler, i,e, the number of batches in an epoch.
"""
return self._eval_batches_per_epoch
class DistributedSamplerOfEval:
"""
A distributed sampler for eval dataset.
"""
def __init__(self, eval_batch_size, num_users, rank_id, rank_size):
self._eval_users_per_batch = int(
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
self._eval_batches_per_epoch = self.count_batches(
self._eval_elements_in_epoch, eval_batch_size)
self._rank_id = rank_id
self._rank_size = rank_size
self._eval_batch_size = eval_batch_size
self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size))
#self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
#self._total_num_samples = self._samples_per_rank * self._rank_size
def __iter__(self):
indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch)
for x in range(self._batchs_per_rank)]
return iter(np.array(indices))
@staticmethod
def count_batches(example_count, batch_size, batches_per_step=1):
"""Determine the number of batches, rounding up to fill all devices."""
x = (example_count + batch_size - 1) // batch_size
return (x + batches_per_step - 1) // batches_per_step * batches_per_step
def __len__(self):
return self._batchs_per_rank
def parse_eval_batch_size(eval_batch_size):
"""
Parse eval batch size.
"""
if eval_batch_size % (1 + rconst.NUM_EVAL_NEGATIVES):
raise ValueError("Eval batch size {} is not divisible by {}".format(
eval_batch_size, 1 + rconst.NUM_EVAL_NEGATIVES))
return eval_batch_size
def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', train_epochs=14, batch_size=256,
eval_batch_size=160000, num_neg=4, rank_id=None, rank_size=None):
"""
Create NCF dataset.
"""
data, num_users, num_items = load_data(data_dir, dataset)
train_pos_users = data[rconst.TRAIN_USER_KEY]
train_pos_items = data[rconst.TRAIN_ITEM_KEY]
eval_pos_users = data[rconst.EVAL_USER_KEY]
eval_pos_items = data[rconst.EVAL_ITEM_KEY]
total_negatives, index_bounds, sorted_train_pos_items = \
construct_lookup_variables(train_pos_users, train_pos_items, num_users)
if test_train:
print(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives, index_bounds,
sorted_train_pos_items)
dataset = NCFDataset(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives,
index_bounds, sorted_train_pos_items)
sampler = RandomSampler(train_pos_users.shape[0], num_neg, batch_size)
if rank_id is not None and rank_size is not None:
sampler = DistributedSamplerOfTrain(train_pos_users.shape[0], num_neg, batch_size, rank_id, rank_size)
if dataset == 'ml-20m':
ds = GeneratorDataset(dataset,
column_names=[movielens.USER_COLUMN,
movielens.ITEM_COLUMN,
"labels",
rconst.VALID_POINT_MASK],
sampler=sampler, num_parallel_workers=32, python_multiprocessing=False)
else:
ds = GeneratorDataset(dataset,
column_names=[movielens.USER_COLUMN,
movielens.ITEM_COLUMN,
"labels",
rconst.VALID_POINT_MASK],
sampler=sampler)
else:
eval_batch_size = parse_eval_batch_size(eval_batch_size=eval_batch_size)
dataset = NCFDataset(eval_pos_users, eval_pos_items, num_users, num_items,
eval_batch_size, total_negatives, index_bounds,
sorted_train_pos_items, is_training=False)
sampler = SequenceSampler(eval_batch_size, num_users)
ds = GeneratorDataset(dataset,
column_names=[movielens.USER_COLUMN,
movielens.ITEM_COLUMN,
rconst.DUPLICATE_MASK],
sampler=sampler)
repeat_count = train_epochs if test_train else train_epochs + 1
ds = ds.repeat(repeat_count)
return ds, num_users, num_items

View File

@ -0,0 +1,112 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export NCF air file."""
import os
import argparse
from absl import logging
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, Model
import constants as rconst
from dataset import create_dataset
from metrics import NCFMetric
from ncf import NCFModel, NetWithLossClass, TrainStepWrap, PredictWithSigmoid
logging.set_verbosity(logging.INFO)
def argparse_init():
"""Argparse init method"""
parser = argparse.ArgumentParser(description='NCF')
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
parser.add_argument("--eval_batch_size", type=int, default=160000) # The batch size used for evaluation.
parser.add_argument("--layers", type=int, default=[64, 32, 16]) # The sizes of hidden layers for MLP
parser.add_argument("--num_factors", type=int, default=16) # The Embedding size of MF model.
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
parser.add_argument("--checkpoint_file_path", type=str, default="./checkpoint/NCF.ckpt") # The location of the checkpoint file.
return parser
def export_air_file():
""""Export file for eval"""
parser = argparse_init()
args, _ = parser.parse_known_args()
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
layers = args.layers
num_factors = args.num_factors
topk = rconst.TOP_K
num_eval_neg = rconst.NUM_EVAL_NEGATIVES
ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=args.data_path,
dataset=args.dataset, train_epochs=0,
eval_batch_size=args.eval_batch_size)
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
ncf_net = NCFModel(num_users=num_eval_users,
num_items=num_eval_items,
num_factors=num_factors,
model_layers=layers,
mf_regularization=0,
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
mf_dim=16)
param_dict = load_checkpoint(args.checkpoint_file_path)
load_param_into_net(ncf_net, param_dict)
loss_net = NetWithLossClass(ncf_net)
train_net = TrainStepWrap(loss_net)
train_net.set_train()
eval_net = PredictWithSigmoid(ncf_net, topk, num_eval_neg)
ncf_metric = NCFMetric()
model = Model(train_net, eval_network=eval_net, metrics={"ncf": ncf_metric})
ncf_metric.clear()
out = model.eval(ds_eval)
eval_file_path = os.path.join(args.output_path, args.eval_file_name)
eval_file = open(eval_file_path, "a+")
eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1]))
eval_file.close()
print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1]))
param_dict = load_checkpoint(args.checkpoint_file_path)
# load the parameter into net
load_param_into_net(eval_net, param_dict)
input_tensor_list = []
for data in ds_eval:
for j in data:
input_tensor_list.append(Tensor(j))
print(len(a))
break
print(input_tensor_list)
export(eval_net, *input_tensor_list, file_name='NCF.air', file_format='AIR')
if __name__ == '__main__':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Davinci",
save_graphs=True,
device_id=devid)
export_air_file()

View File

@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NCF metrics calculation"""
import numpy as np
from mindspore.nn.metrics import Metric
class NCFMetric(Metric):
"""NCF metrics method"""
def __init__(self):
super(NCFMetric, self).__init__()
self.hr = []
self.ndcg = []
self.weights = []
def clear(self):
"""Clear the internal evaluation result."""
self.hr = []
self.ndcg = []
self.weights = []
def hit(self, gt_item, pred_items):
if gt_item in pred_items:
return 1
return 0
def ndcg_function(self, gt_item, pred_items):
if gt_item in pred_items:
index = pred_items.index(gt_item)
return np.reciprocal(np.log2(index + 2))
return 0
def update(self, batch_indices, batch_items, metric_weights):
"""Update hr and ndcg"""
batch_indices = batch_indices.asnumpy() # (num_user, topk)
batch_items = batch_items.asnumpy() # (num_user, 100)
metric_weights = metric_weights.asnumpy() # (num_user,)
num_user = batch_items.shape[0]
for user in range(num_user):
if metric_weights[user]:
recommends = batch_items[user][batch_indices[user]].tolist()
items = batch_items[user].tolist()[-1]
self.hr.append(self.hit(items, recommends))
self.ndcg.append(self.ndcg_function(items, recommends))
def eval(self):
return np.mean(self.hr), np.mean(self.ndcg)

View File

@ -0,0 +1,287 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Download and extract the MovieLens dataset from GroupLens website.
Download the dataset, and perform basic preprocessing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import zipfile
import argparse
import six
from six.moves import urllib
import numpy as np
import pandas as pd
from absl import logging
import tensorflow as tf
ML_1M = "ml-1m"
ML_20M = "ml-20m"
DATASETS = [ML_1M, ML_20M]
RATINGS_FILE = "ratings.csv"
MOVIES_FILE = "movies.csv"
# URL to download dataset
_DATA_URL = "http://files.grouplens.org/datasets/movielens/"
GENRE_COLUMN = "genres"
ITEM_COLUMN = "item_id" # movies
RATING_COLUMN = "rating"
TIMESTAMP_COLUMN = "timestamp"
TITLE_COLUMN = "titles"
USER_COLUMN = "user_id"
GENRES = [
'Action', 'Adventure', 'Animation', "Children", 'Comedy', 'Crime',
'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', "IMAX", 'Musical',
'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western'
]
N_GENRE = len(GENRES)
RATING_COLUMNS = [USER_COLUMN, ITEM_COLUMN, RATING_COLUMN, TIMESTAMP_COLUMN]
MOVIE_COLUMNS = [ITEM_COLUMN, TITLE_COLUMN, GENRE_COLUMN]
# Note: Users are indexed [1, k], not [0, k-1]
NUM_USER_IDS = {
ML_1M: 6040,
ML_20M: 138493,
}
# Note: Movies are indexed [1, k], not [0, k-1]
# Both the 1m and 20m datasets use the same movie set.
NUM_ITEM_IDS = 3952
MAX_RATING = 5
NUM_RATINGS = {
ML_1M: 1000209,
ML_20M: 20000263
}
arg_parser = argparse.ArgumentParser(description='movielens dataset')
arg_parser.add_argument("--data_path", type=str, default="./dataset/")
arg_parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"])
args, _ = arg_parser.parse_known_args()
def _download_and_clean(dataset, data_dir):
"""Download MovieLens dataset in a standard format.
This function downloads the specified MovieLens format and coerces it into a
standard format. The only difference between the ml-1m and ml-20m datasets
after this point (other than size, of course) is that the 1m dataset uses
whole number ratings while the 20m dataset allows half integer ratings.
"""
if dataset not in DATASETS:
raise ValueError("dataset {} is not in {{{}}}".format(
dataset, ",".join(DATASETS)))
data_subdir = os.path.join(data_dir, dataset)
expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE]
tf.io.gfile.makedirs(data_subdir)
if set(expected_files).intersection(
tf.io.gfile.listdir(data_subdir)) == set(expected_files):
logging.info("Dataset {} has already been downloaded".format(dataset))
return
url = "{}{}.zip".format(_DATA_URL, dataset)
temp_dir = tempfile.mkdtemp()
try:
zip_path = os.path.join(temp_dir, "{}.zip".format(dataset))
zip_path, _ = urllib.request.urlretrieve(url, zip_path)
statinfo = os.stat(zip_path)
# A new line to clear the carriage return from download progress
# logging.info is not applicable here
print()
logging.info(
"Successfully downloaded {} {} bytes".format(
zip_path, statinfo.st_size))
zipfile.ZipFile(zip_path, "r").extractall(temp_dir)
if dataset == ML_1M:
_regularize_1m_dataset(temp_dir)
else:
_regularize_20m_dataset(temp_dir)
for fname in tf.io.gfile.listdir(temp_dir):
if not tf.io.gfile.exists(os.path.join(data_subdir, fname)):
tf.io.gfile.copy(os.path.join(temp_dir, fname),
os.path.join(data_subdir, fname))
else:
logging.info("Skipping copy of {}, as it already exists in the "
"destination folder.".format(fname))
finally:
tf.io.gfile.rmtree(temp_dir)
def _transform_csv(input_path, output_path, names, skip_first, separator=","):
"""Transform csv to a regularized format.
Args:
input_path: The path of the raw csv.
output_path: The path of the cleaned csv.
names: The csv column names.
skip_first: Boolean of whether to skip the first line of the raw csv.
separator: Character used to separate fields in the raw csv.
"""
if six.PY2:
names = [six.ensure_text(n, "utf-8") for n in names]
with tf.io.gfile.GFile(output_path, "wb") as f_out, \
tf.io.gfile.GFile(input_path, "rb") as f_in:
# Write column names to the csv.
f_out.write(",".join(names).encode("utf-8"))
f_out.write(b"\n")
for i, line in enumerate(f_in):
if i == 0 and skip_first:
continue # ignore existing labels in the csv
line = six.ensure_text(line, "utf-8", errors="ignore")
fields = line.split(separator)
if separator != ",":
fields = ['"{}"'.format(field) if "," in field else field
for field in fields]
f_out.write(",".join(fields).encode("utf-8"))
def _regularize_1m_dataset(temp_dir):
"""
ratings.dat
The file has no header row, and each line is in the following format:
UserID::MovieID::Rating::Timestamp
- UserIDs range from 1 and 6040
- MovieIDs range from 1 and 3952
- Ratings are made on a 5-star scale (whole-star ratings only)
- Timestamp is represented in seconds since midnight Coordinated Universal
Time (UTC) of January 1, 1970.
- Each user has at least 20 ratings
movies.dat
Each line has the following format:
MovieID::Title::Genres
- MovieIDs range from 1 and 3952
"""
working_dir = os.path.join(temp_dir, ML_1M)
_transform_csv(
input_path=os.path.join(working_dir, "ratings.dat"),
output_path=os.path.join(temp_dir, RATINGS_FILE),
names=RATING_COLUMNS, skip_first=False, separator="::")
_transform_csv(
input_path=os.path.join(working_dir, "movies.dat"),
output_path=os.path.join(temp_dir, MOVIES_FILE),
names=MOVIE_COLUMNS, skip_first=False, separator="::")
tf.io.gfile.rmtree(working_dir)
def _regularize_20m_dataset(temp_dir):
"""
ratings.csv
Each line of this file after the header row represents one rating of one
movie by one user, and has the following format:
userId,movieId,rating,timestamp
- The lines within this file are ordered first by userId, then, within user,
by movieId.
- Ratings are made on a 5-star scale, with half-star increments
(0.5 stars - 5.0 stars).
- Timestamps represent seconds since midnight Coordinated Universal Time
(UTC) of January 1, 1970.
- All the users had rated at least 20 movies.
movies.csv
Each line has the following format:
MovieID,Title,Genres
- MovieIDs range from 1 and 3952
"""
working_dir = os.path.join(temp_dir, ML_20M)
_transform_csv(
input_path=os.path.join(working_dir, "ratings.csv"),
output_path=os.path.join(temp_dir, RATINGS_FILE),
names=RATING_COLUMNS, skip_first=True, separator=",")
_transform_csv(
input_path=os.path.join(working_dir, "movies.csv"),
output_path=os.path.join(temp_dir, MOVIES_FILE),
names=MOVIE_COLUMNS, skip_first=True, separator=",")
tf.io.gfile.rmtree(working_dir)
def download(dataset, data_dir):
if dataset:
_download_and_clean(dataset, data_dir)
else:
_ = [_download_and_clean(d, data_dir) for d in DATASETS]
def ratings_csv_to_dataframe(data_dir, dataset):
with tf.io.gfile.GFile(os.path.join(data_dir, dataset, RATINGS_FILE)) as f:
return pd.read_csv(f, encoding="utf-8")
def csv_to_joint_dataframe(data_dir, dataset):
ratings = ratings_csv_to_dataframe(data_dir, dataset)
with tf.io.gfile.GFile(os.path.join(data_dir, dataset, MOVIES_FILE)) as f:
movies = pd.read_csv(f, encoding="utf-8")
df = ratings.merge(movies, on=ITEM_COLUMN)
df[RATING_COLUMN] = df[RATING_COLUMN].astype(np.float32)
return df
def integerize_genres(dataframe):
"""Replace genre string with a binary vector.
Args:
dataframe: a pandas dataframe of movie data.
Returns:
The transformed dataframe.
"""
def _map_fn(entry):
entry.replace("Children's", "Children") # naming difference.
movie_genres = entry.split("|")
output = np.zeros((len(GENRES),), dtype=np.int64)
for i, genre in enumerate(GENRES):
if genre in movie_genres:
output[i] = 1
return output
dataframe[GENRE_COLUMN] = dataframe[GENRE_COLUMN].apply(_map_fn)
return dataframe
if __name__ == "__main__":
download(args.dataset, args.data_path)

View File

@ -0,0 +1,290 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Neural Collaborative Filtering Model"""
from mindspore import nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore._checkparam import Validator as validator
from mindspore.nn.layer.activation import get_activation
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
class DenseLayer(nn.Cell):
"""
Dense layer definition
"""
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None):
super(DenseLayer, self).__init__()
self.in_channels = validator.check_positive_int(in_channels)
self.out_channels = validator.check_positive_int(out_channels)
self.has_bias = validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape()[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
def construct(self, x):
"""
dense layer construct method
"""
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
bias = self.cast(self.bias, mstype.float16)
output = self.matmul(x, weight)
if self.has_bias:
output = self.bias_add(output, bias)
if self.activation_flag:
output = self.activation(output)
output = self.cast(output, mstype.float32)
return output
def extend_repr(self):
"""A pretty print for Dense layer."""
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
if self.has_bias:
str_info = str_info + ', bias={}'.format(self.bias)
if self.activation_flag:
str_info = str_info + ', activation={}'.format(self.activation)
return str_info
class NCFModel(nn.Cell):
"""
Class for Neural Collaborative Filtering Model from paper " Neural Collaborative Filtering".
"""
def __init__(self,
num_users,
num_items,
num_factors,
model_layers,
mf_regularization,
mlp_reg_layers,
mf_dim):
super(NCFModel, self).__init__()
self.data_path = ""
self.model_path = ""
self.num_users = num_users
self.num_items = num_items
self.num_factors = num_factors
self.model_layers = model_layers
self.mf_regularization = mf_regularization
self.mlp_reg_layers = mlp_reg_layers
self.mf_dim = mf_dim
self.num_layers = len(self.model_layers) # Number of layers in the MLP
if self.model_layers[0] % 2 != 0:
raise ValueError("The first layer size should be multiple of 2!")
# Initializer for embedding layers
self.embedding_initializer = "normal"
self.embedding_user = nn.Embedding(
self.num_users,
self.num_factors + self.model_layers[0] // 2,
embedding_table=self.embedding_initializer
)
self.embedding_item = nn.Embedding(
self.num_items,
self.num_factors + self.model_layers[0] // 2,
embedding_table=self.embedding_initializer
)
self.mlp_dense1 = DenseLayer(in_channels=self.model_layers[0],
out_channels=self.model_layers[1],
activation="relu")
self.mlp_dense2 = DenseLayer(in_channels=self.model_layers[1],
out_channels=self.model_layers[2],
activation="relu")
# Logit dense layer
self.logits_dense = DenseLayer(in_channels=self.model_layers[1],
out_channels=1,
weight_init="normal",
activation=None)
# ops definition
self.mul = P.Mul()
self.squeeze = P.Squeeze(axis=1)
self.concat = P.Concat(axis=1)
def construct(self, user_input, item_input):
"""
NCF construct method.
"""
# GMF part
# embedding_layers
embedding_user = self.embedding_user(user_input) # input: (256, 1) output: (256, 1, 16 + 32)
embedding_item = self.embedding_item(item_input) # input: (256, 1) output: (256, 1, 16 + 32)
mf_user_latent = self.squeeze(embedding_user)[:, :self.num_factors] # input: (256, 1, 16 + 32) output: (256, 16)
mf_item_latent = self.squeeze(embedding_item)[:, :self.num_factors] # input: (256, 1, 16 + 32) output: (256, 16)
# MLP part
mlp_user_latent = self.squeeze(embedding_user)[:, self.mf_dim:] # input: (256, 1, 16 + 32) output: (256, 32)
mlp_item_latent = self.squeeze(embedding_item)[:, self.mf_dim:] # input: (256, 1, 16 + 32) output: (256, 32)
# Element-wise multiply
mf_vector = self.mul(mf_user_latent, mf_item_latent) # input: (256, 16), (256, 16) output: (256, 16)
# Concatenation of two latent features
mlp_vector = self.concat((mlp_user_latent, mlp_item_latent)) # input: (256, 32), (256, 32) output: (256, 64)
# MLP dense layers
mlp_vector = self.mlp_dense1(mlp_vector) # input: (256, 64) output: (256, 32)
mlp_vector = self.mlp_dense2(mlp_vector) # input: (256, 32) output: (256, 16)
# # Concatenate GMF and MLP parts
predict_vector = self.concat((mf_vector, mlp_vector)) # input: (256, 16), (256, 16) output: (256, 32)
# Final prediction layer
logits = self.logits_dense(predict_vector) # input: (256, 32) output: (256, 1)
# Print model topology.
return logits
class NetWithLossClass(nn.Cell):
"""
NetWithLossClass definition
"""
def __init__(self, network):
super(NetWithLossClass, self).__init__(auto_prefix=False)
#self.loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
self.network = network
self.reducesum = P.ReduceSum(keep_dims=False)
self.mul = P.Mul()
self.squeeze = P.Squeeze(axis=1)
self.zeroslike = P.ZerosLike()
self.concat = P.Concat(axis=1)
self.reciprocal = P.Reciprocal()
def construct(self, batch_users, batch_items, labels, valid_pt_mask):
predict = self.network(batch_users, batch_items)
predict = self.concat((self.zeroslike(predict), predict))
labels = self.squeeze(labels)
loss = self.loss(predict, labels)
loss = self.mul(loss, self.squeeze(valid_pt_mask))
mean_loss = self.mul(self.reducesum(loss), self.reciprocal(self.reducesum(valid_pt_mask)))
return mean_loss
class TrainStepWrap(nn.Cell):
"""
TrainStepWrap definition
"""
def __init__(self, network, sens=16384.0):
super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network
self.network.set_train()
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = nn.Adam(self.weights,
learning_rate=0.00382059,
beta1=0.9,
beta2=0.999,
eps=1e-8,
loss_scale=sens)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
def construct(self, batch_users, batch_items, labels, valid_pt_mask):
weights = self.weights
loss = self.network(batch_users, batch_items, labels, valid_pt_mask)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
grads = self.grad(self.network, weights)(batch_users, batch_items, labels, valid_pt_mask, sens)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
class PredictWithSigmoid(nn.Cell):
"""
Predict definition
"""
def __init__(self, network, k, num_eval_neg):
super(PredictWithSigmoid, self).__init__()
self.network = network
self.topk = P.TopK(sorted=True)
self.squeeze = P.Squeeze()
self.k = k
self.num_eval_neg = num_eval_neg
self.gather = P.GatherV2()
self.reshape = P.Reshape()
self.reducesum = P.ReduceSum(keep_dims=False)
self.notequal = P.NotEqual()
def construct(self, batch_users, batch_items, duplicated_masks):
predicts = self.network(batch_users, batch_items) # (bs, 1)
predicts = self.reshape(predicts, (-1, self.num_eval_neg + 1)) # (num_user, 100)
batch_items = self.reshape(batch_items, (-1, self.num_eval_neg + 1)) # (num_user, 100)
duplicated_masks = self.reshape(duplicated_masks, (-1, self.num_eval_neg + 1)) # (num_user, 100)
masks_sum = self.reducesum(duplicated_masks, 1)
metric_weights = self.notequal(masks_sum, self.num_eval_neg) # (num_user)
_, indices = self.topk(predicts, self.k) # (num_user, k)
return indices, batch_items, metric_weights

View File

@ -0,0 +1,90 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Statistics utility functions of NCF."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def random_int32():
return np.random.randint(low=0, high=np.iinfo(np.int32).max, dtype=np.int32)
def permutation(args):
"""Fork safe permutation function.
This function can be called within a multiprocessing worker and give
appropriately random results.
Args:
args: A size two tuple that will unpacked into the size of the permutation
and the random seed. This form is used because starmap is not universally
available.
returns:
A NumPy array containing a random permutation.
"""
x, seed = args
# If seed is None NumPy will seed randomly.
state = np.random.RandomState(seed=seed) # pylint: disable=no-member
output = np.arange(x, dtype=np.int32)
state.shuffle(output)
return output
def very_slightly_biased_randint(max_val_vector):
sample_dtype = np.uint64
out_dtype = max_val_vector.dtype
samples = np.random.randint(low=0, high=np.iinfo(sample_dtype).max,
size=max_val_vector.shape, dtype=sample_dtype)
return np.mod(samples, max_val_vector.astype(sample_dtype)).astype(out_dtype)
def mask_duplicates(x, axis=1): # type: (np.ndarray, int) -> np.ndarray
"""Identify duplicates from sampling with replacement.
Args:
x: A 2D NumPy array of samples
axis: The axis along which to de-dupe.
Returns:
A NumPy array with the same shape as x with one if an element appeared
previously along axis 1, else zero.
"""
if axis != 1:
raise NotImplementedError
x_sort_ind = np.argsort(x, axis=1, kind="mergesort")
sorted_x = x[np.arange(x.shape[0])[:, np.newaxis], x_sort_ind]
# compute the indices needed to map values back to their original position.
inv_x_sort_ind = np.argsort(x_sort_ind, axis=1, kind="mergesort")
# Compute the difference of adjacent sorted elements.
diffs = sorted_x[:, :-1] - sorted_x[:, 1:]
# We are only interested in whether an element is zero. Therefore left padding
# with ones to restore the original shape is sufficient.
diffs = np.concatenate(
[np.ones((diffs.shape[0], 1), dtype=diffs.dtype), diffs], axis=1)
# Duplicate values will have a difference of zero. By definition the first
# element is never a duplicate.
return np.where(diffs[np.arange(x.shape[0])[:, np.newaxis],
inv_x_sort_ind], 0, 1)

View File

@ -0,0 +1,105 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training entry file"""
import os
import argparse
from absl import logging
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore import context, Model
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from src.dataset import create_dataset
from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap
from config import cfg
logging.set_verbosity(logging.INFO)
parser = argparse.ArgumentParser(description='NCF')
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
parser.add_argument("--train_epochs", type=int, default=14) # The number of epochs used to train.
parser.add_argument("--batch_size", type=int, default=256) # Batch size for training and evaluation
parser.add_argument("--num_neg", type=int, default=4) # The Number of negative instances to pair with a positive instance.
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file.
parser.add_argument("--checkpoint_path", type=str, default="./checkpoint/") # The location of the checkpoint file.
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args = parser.parse_args()
def test_train():
"""train entry method"""
if args.is_distributed:
if args.device_target == "Ascend":
init()
context.set_context(device_id=args.device_id)
elif args.device_target == "GPU":
init()
args.rank = get_rank()
args.group_size = get_group_size()
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, gradients_mean=True)
else:
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
layers = cfg.layers
num_factors = cfg.num_factors
epochs = args.train_epochs
ds_train, num_train_users, num_train_items = create_dataset(test_train=True, data_dir=args.data_path,
dataset=args.dataset, train_epochs=1,
batch_size=args.batch_size, num_neg=args.num_neg)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
ncf_net = NCFModel(num_users=num_train_users,
num_items=num_train_items,
num_factors=num_factors,
model_layers=layers,
mf_regularization=0,
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
mf_dim=16)
loss_net = NetWithLossClass(ncf_net)
train_net = TrainStepWrap(loss_net)
train_net.set_train()
model = Model(train_net)
callback = LossMonitor(per_print_times=ds_train.get_dataset_size())
ckpt_config = CheckpointConfig(save_checkpoint_steps=(4970845+args.batch_size-1)//(args.batch_size),
keep_checkpoint_max=100)
ckpoint_cb = ModelCheckpoint(prefix='NCF', directory=args.checkpoint_path, config=ckpt_config)
model.train(epochs,
ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb],
dataset_sink_mode=True)
if __name__ == '__main__':
test_train()