forked from OSSInnovation/mindspore
!5788 [ModelZoo]Add bgcf to model zoo
Merge pull request !5788 from zhanke/bgcf
commit c65f32b27f
@ -0,0 +1,187 @@
<!--TOC -->

- [Bayesian Graph Collaborative Filtering](#bayesian-graph-collaborative-filtering)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
- [Model Description](#model-description)
    - [Performance](#performance)
- [Description of random situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!--TOC -->

# [Bayesian Graph Collaborative Filtering](#contents)

Bayesian Graph Collaborative Filtering (BGCF) was proposed in 2020 by Sun J, Guo W, Zhang D, et al. By naturally incorporating the uncertainty of the user-item interaction graph, it achieves excellent performance on the Amazon recommendation dataset. This is an example of training BGCF with the Amazon-Beauty dataset in MindSpore. More importantly, this is the first open-source version of BGCF.

[Paper](https://dl.acm.org/doi/pdf/10.1145/3394486.3403254): Sun J, Guo W, Zhang D, et al. A Framework for Recommending Accurate and Diverse Items Using Bayesian Graph Convolutional Neural Networks[C]//Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2020: 2030-2039.

# [Model Architecture](#contents)

Specifically, BGCF contains two main modules. The first is sampling, which produces sample graphs based on node copying. The other module aggregates the neighbors sampled from each node, using a mean aggregator and an attention aggregator, as sketched below.
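
The following minimal NumPy sketch is illustrative only: the helper names and the plain dot-product attention are simplifications (the real MindSpore cells, with learned weight matrices, live in `src/bgcf.py`). It shows how a user representation is assembled from the raw graph and one node-copying sample graph.

```python
import numpy as np

rng = np.random.default_rng(0)

def mean_aggregate(self_feat, neigh_feats):
    # Mean aggregator: average the neighbors, then concatenate with the node itself.
    return np.concatenate([self_feat, neigh_feats.mean(axis=0)])

def attention_aggregate(self_feat, neigh_feats):
    # Attention aggregator: softmax over dot-product scores, weighted sum of neighbors.
    scores = neigh_feats @ self_feat
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return np.concatenate([weights @ neigh_feats, self_feat])

dim = 8
user = rng.normal(size=dim)                # embedding of one user
raw_neighs = rng.normal(size=(40, dim))    # sampled neighbors in the raw graph
gnew_neighs = rng.normal(size=(20, dim))   # neighbors in one node-copying sample graph

# Final representation: raw-graph mean, sample-graph mean, attention over both (cf. BGCF.construct).
user_rep = np.tanh(np.concatenate([
    mean_aggregate(user, raw_neighs),
    mean_aggregate(user, gnew_neighs),
    attention_aggregate(user, np.concatenate([raw_neighs, gnew_neighs])),
]))
print(user_rep.shape)
```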

# [Dataset](#contents)

- Dataset size:

  Statistics of the dataset used are summarized below:

  |                  |            Amazon-Beauty |
  | ---------------- | ------------------------:|
  | Task             |           Recommendation |
  | # Users          |           7068 (1 graph) |
  | # Items          |                     3570 |
  | # Interactions   |                    79506 |
  | # Training Data  |                    60818 |
  | # Test Data      |                    18688 |
  | Density          |                   0.315% |

- Data Preparation

  - Place the dataset in any path you want; the folder should include the following files (we use the Amazon-Beauty dataset as an example):

    ```
    .
    └─data
        ├─ratings_Beauty.csv
    ```

  - Generate the dataset in MindRecord format for Amazon-Beauty:

    ```shell
    cd ./scripts
    # SRC_PATH is the path of the dataset file you downloaded.
    sh run_process_data_ascend.sh [SRC_PATH]
    ```

- Launch

  ```shell
  # Generate dataset in MindRecord format for Amazon-Beauty.
  sh ./run_process_data_ascend.sh ./data
  ```

# [Features](#contents)

## Mixed Precision

To utilize the strong computation power of the Ascend chip and accelerate the training process, mixed-precision training is used. MindSpore can cope with FP32 inputs and FP16 operators. In the BGCF example, the model is set to FP16 mode except for the loss calculation part.
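
As a minimal sketch of the pattern used in `src/bgcf.py`, where each aggregation cell is switched to FP16 with `Cell.to_float` while the loss cell stays in FP32, the toy cell below (its name is ours, not part of the model) runs its computation in FP16:

```python
import mindspore.nn as nn
from mindspore.common import dtype as mstype

class FP16Block(nn.Cell):
    """Toy cell whose compute runs in FP16, mirroring the to_float calls in src/bgcf.py."""
    def __init__(self):
        super(FP16Block, self).__init__()
        self.dense = nn.Dense(16, 16)
        # Cast this cell's operators to FP16; inputs and weights can remain FP32.
        self.dense.to_float(mstype.float16)

    def construct(self, x):
        return self.dense(x)
```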

# [Environment Requirements](#contents)

- Hardware (Ascend)
- Framework
  - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
  - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
  - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website and generating the dataset correctly, you can start training and evaluation as follows.

- running on Ascend

  ```shell
  # run the training example with the Amazon-Beauty dataset
  sh run_train_ascend.sh
  ```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```shell
.
└─bgcf
  ├─README.md
  ├─scripts
  | ├─run_process_data_ascend.sh  # Generate dataset in MindRecord format
  | └─run_train_ascend.sh         # Launch training
  |
  ├─src
  | ├─bgcf.py                     # BGCF model
  | ├─callback.py                 # Loss and training wrappers
  | ├─config.py                   # Training configurations
  | ├─dataset.py                  # Data preprocessing
  | ├─metrics.py                  # Recommendation metrics
  | └─utils.py                    # Utils for training BGCF
  |
  └─train.py                      # Train net
```

## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for BGCF dataset

  ```python
  "learning_rate": 0.001,              # Learning rate
  "num_epochs": 600,                   # Number of training epochs
  "num_neg": 10,                       # Negative sampling rate
  "raw_neighs": 40,                    # Number of neighbors sampled in the raw graph
  "gnew_neighs": 20,                   # Number of neighbors sampled in each sample graph
  "input_dim": 64,                     # User and item embedding dimension
  "l2_coeff": 0.03,                    # L2 coefficient
  "neighbor_dropout": [0.0, 0.2, 0.3], # Dropout ratios for the aggregation layers
  "num_graphs": 5                      # Number of sample graphs
  ```
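
These keys mirror the argparse flags defined in `src/config.py`. A minimal sketch of reading them programmatically (assuming, as the provided scripts suggest, that `train.py` consumes `parser_args()`; flags such as `--num_epoch` or `--learning_rate` override the defaults listed above):

```python
from src.config import parser_args

# Parse the BGCF hyper-parameters from the command line.
config = parser_args()
print(config.num_epoch, config.learning_rate, config.num_neg, config.neighbor_dropout)
```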

## [Training Process](#contents)

### Training

- running on Ascend

  ```shell
  sh run_train_ascend.sh
  ```

Training results will be stored in the scripts path, in a folder whose name begins with "train". You can find results like the following in the log.

```
Epoch 001 iter 12 loss 34696.242
Epoch 002 iter 12 loss 34275.508
Epoch 003 iter 12 loss 30620.635
Epoch 004 iter 12 loss 21628.908

...
Epoch 597 iter 12 loss 3662.3152
Epoch 598 iter 12 loss 3640.7612
Epoch 599 iter 12 loss 3654.9087
Epoch 600 iter 12 loss 3632.4585
epoch:600, recall_@10:0.10393, recall_@20:0.15669, ndcg_@10:0.07564, ndcg_@20:0.09343,
sedp_@10:0.01936, sedp_@20:0.01544, nov_@10:7.58599, nov_@20:7.79782
...
```
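
The `recall_@K` and `ndcg_@K` values in the log are computed as in `src/metrics.py`; the minimal re-implementation below shows the underlying formulas.

```python
import math
import numpy as np

def recall_at_k(hits, k, num_pos):
    # Fraction of a user's test items that appear in the top-k list
    # (cf. recall_at_k_2 in src/metrics.py).
    return np.asarray(hits, dtype=float)[:k].sum() / num_pos

def ndcg_at_k(actual, predicted, k):
    # DCG over the top-k predictions divided by the ideal DCG
    # (cf. ndcg_k / idcg_k in src/metrics.py).
    idcg = sum(1.0 / math.log(i + 2, 2) for i in range(min(k, len(actual)))) or 1.0
    dcg = sum(int(predicted[j] in set(actual)) / math.log(j + 2, 2) for j in range(k))
    return dcg / idcg

# Tiny example: two relevant items, one of them ranked first.
print(recall_at_k([1, 0, 0, 1, 0], k=5, num_pos=2))              # 1.0
print(ndcg_at_k(actual=[3, 7], predicted=[3, 5, 9, 7, 2], k=5))  # ~0.88
```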

# [Model Description](#contents)

## [Performance](#contents)

| Parameter            | BGCF                                                                           |
| -------------------- | ------------------------------------------------------------------------------ |
| Resource             | Ascend 910                                                                      |
| Uploaded Date        | 09/04/2020 (month/day/year)                                                     |
| MindSpore Version    | 1.0                                                                             |
| Dataset              | Amazon-Beauty                                                                   |
| Training Parameters  | epoch=600                                                                       |
| Optimizer            | Adam                                                                            |
| Loss Function        | BPR loss                                                                        |
| Recall@20            | 0.1534                                                                          |
| NDCG@20              | 0.0912                                                                          |
| Total time           | 30 min                                                                          |
| Scripts              | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf   |

# [Description of random situation](#contents)

The BGCF model contains many dropout operations. If you want to disable dropout, set neighbor_dropout to [0.0, 0.0, 0.0] in src/config.py, as illustrated below.
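
Each of the three ratios feeds one aggregation layer, which builds its dropout cell from `1 - ratio` (the pattern used in `src/bgcf.py`; the snippet below is a sketch, not an exact excerpt):

```python
import mindspore.nn as nn

neighbor_dropout = [0.0, 0.0, 0.0]  # all three aggregation layers become deterministic
# Each layer keeps (1 - ratio) of the neighbor features, so a ratio of 0.0 disables dropout.
dropouts = [nn.Dropout(keep_prob=1 - p) for p in neighbor_dropout]
```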

# [ModelZoo Homepage](#contents)

Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,76 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 1 ]
then
    echo "Usage: sh run_process_data_ascend.sh [SRC_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
SRC_PATH=$(get_real_path $1)
echo $SRC_PATH

if [ ! -d data_mr ]; then
    mkdir data_mr
else
    echo "data_mr exists"
fi
MINDRECORD_PATH=`pwd`/data_mr

rm -rf ${MINDRECORD_PATH:?}/*
INTER_FILE_DIR=$MINDRECORD_PATH/InterFile
mkdir -p $INTER_FILE_DIR

cd ../../../../utils/graph_to_mindrecord || exit

echo "Start converting data."
python amazon_beauty/converting_data.py --src_path $SRC_PATH --out_path $INTER_FILE_DIR

echo "Start generating train_mr."
python writer.py --mindrecord_script amazon_beauty \
    --mindrecord_file "$MINDRECORD_PATH/train_mr" \
    --mindrecord_partitions 1 \
    --mindrecord_header_size_by_bit 18 \
    --mindrecord_page_size_by_bit 20 \
    --graph_api_args "$INTER_FILE_DIR/user.csv:$INTER_FILE_DIR/item.csv:$INTER_FILE_DIR/rating_train.csv"

echo "Start generating test_mr."
python writer.py --mindrecord_script amazon_beauty \
    --mindrecord_file "$MINDRECORD_PATH/test_mr" \
    --mindrecord_partitions 1 \
    --mindrecord_header_size_by_bit 18 \
    --mindrecord_page_size_by_bit 20 \
    --graph_api_args "$INTER_FILE_DIR/user.csv:$INTER_FILE_DIR/item.csv:$INTER_FILE_DIR/rating_test.csv"

for id in {0..4}
do
    echo "Start generating sampled${id}_mr."
    python writer.py --mindrecord_script amazon_beauty \
        --mindrecord_file "${MINDRECORD_PATH}/sampled${id}_mr" \
        --mindrecord_partitions 1 \
        --mindrecord_header_size_by_bit 18 \
        --mindrecord_page_size_by_bit 20 \
        --graph_api_args "$INTER_FILE_DIR/user.csv:$INTER_FILE_DIR/item.csv:$INTER_FILE_DIR/rating_sampled${id}.csv"
done
@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
mkdir ./ckpts
env > env.log
echo "start training for device $DEVICE_ID"

python train.py --datapath=../data_mr &> log &

cd ..
@ -0,0 +1,263 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Architecture"""
import mindspore.nn as nn
from mindspore import Parameter
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer


class MeanConv(nn.Cell):
    """
    BGCF mean aggregate layer.

    Args:
        feature_in_dim (int): The input feature dimension.
        feature_out_dim (int): The output feature dimension.
        activation (str): Activation function applied to the output of the layer, eg. 'relu'. Default: 'tanh'.
        dropout (float): Dropout ratio for the dropout layer. Default: 0.2.

    Inputs:
        - self_feature (Tensor) - Tensor of shape :math:`(batch_size, feature_dim)`.
        - neigh_feature (Tensor) - Tensor of shape :math:`(batch_size, neighbour_num, feature_dim)`.

    Outputs:
        Tensor, output tensor.
    """

    def __init__(self,
                 name,
                 feature_in_dim,
                 feature_out_dim,
                 activation,
                 dropout=0.2):
        super(MeanConv, self).__init__()

        self.out_weight = Parameter(
            initializer("XavierUniform", [feature_in_dim * 2, feature_out_dim], dtype=mstype.float32),
            name=name + 'out_weight')

        if activation == "tanh":
            self.act = P.Tanh()
        elif activation == "relu":
            self.act = P.ReLU()
        else:
            raise ValueError("activation should be tanh or relu")

        self.cast = P.Cast()
        self.matmul = P.MatMul()
        self.concat = P.Concat(axis=1)
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        self.dropout = nn.Dropout(keep_prob=1 - dropout)

    def construct(self, self_feature, neigh_feature):
        neigh_matrix = self.reduce_mean(neigh_feature, 1)
        neigh_matrix = self.dropout(neigh_matrix)

        output = self.concat((self_feature, neigh_matrix))
        output = self.act(self.matmul(output, self.out_weight))
        return output


class AttenConv(nn.Cell):
    """
    BGCF attention aggregate layer.

    Args:
        feature_in_dim (int): The input feature dimension.
        feature_out_dim (int): The output feature dimension.
        dropout (float): Dropout ratio for the dropout layer. Default: 0.2.

    Inputs:
        - self_feature (Tensor) - Tensor of shape :math:`(batch_size, feature_dim)`.
        - neigh_feature (Tensor) - Tensor of shape :math:`(batch_size, neighbour_num, feature_dim)`.

    Outputs:
        Tensor, output tensor.
    """

    def __init__(self,
                 name,
                 feature_in_dim,
                 feature_out_dim,
                 dropout=0.2):
        super(AttenConv, self).__init__()

        self.out_weight = Parameter(
            initializer("XavierUniform", [feature_in_dim * 2, feature_out_dim], dtype=mstype.float32),
            name=name + 'out_weight')
        self.cast = P.Cast()
        self.squeeze = P.Squeeze(1)
        self.concat = P.Concat(axis=1)
        self.expanddims = P.ExpandDims()
        self.softmax = P.Softmax(axis=-1)
        self.matmul = P.MatMul()
        self.matmul_3 = P.BatchMatMul()
        self.matmul_t = P.BatchMatMul(transpose_b=True)
        self.dropout = nn.Dropout(keep_prob=1 - dropout)

    def construct(self, self_feature, neigh_feature):
        """Attention aggregation"""
        query = self.expanddims(self_feature, 1)
        neigh_matrix = self.dropout(neigh_feature)

        score = self.matmul_t(query, neigh_matrix)
        score = self.softmax(score)
        atten_agg = self.matmul_3(score, neigh_matrix)
        atten_agg = self.squeeze(atten_agg)

        output = self.matmul(self.concat((atten_agg, self_feature)), self.out_weight)
        return output


class BGCF(nn.Cell):
    """
    BGCF architecture.

    Args:
        dataset_argv (list[int]): A list of the dataset argv.
        architect_argv (list[int]): A list of the model layer argv.
        activation (str): Activation function applied to the output of the layer, eg. 'relu'. Default: 'tanh'.
        neigh_drop_rate (list[float]): A list of the dropout ratio.
        num_user (int): The num of user.
        num_item (int): The num of item.
        input_dim (int): The feature dim.
    """

    def __init__(self,
                 dataset_argv,
                 architect_argv,
                 activation,
                 neigh_drop_rate,
                 num_user,
                 num_item,
                 input_dim):
        super(BGCF, self).__init__()

        self.user_embeddings = Parameter(initializer("XavierUniform", [num_user, input_dim], dtype=mstype.float32),
                                         name='user_embed')
        self.item_embeddings = Parameter(initializer("XavierUniform", [num_item, input_dim], dtype=mstype.float32),
                                         name='item_embed')
        self.cast = P.Cast()
        self.tanh = P.Tanh()
        self.shape = P.Shape()
        self.split = P.Split(0, 2)
        self.gather = P.GatherV2()
        self.reshape = P.Reshape()
        self.concat_0 = P.Concat(0)
        self.concat_1 = P.Concat(1)

        (self.input_dim, self.num_user, self.num_item) = dataset_argv
        self.layer_dim = architect_argv

        self.gnew_agg_mean = MeanConv('gnew_agg_mean', self.input_dim, self.layer_dim,
                                      activation=activation, dropout=neigh_drop_rate[1])
        self.gnew_agg_mean.to_float(mstype.float16)

        self.gnew_agg_user = AttenConv('gnew_agg_att_user', self.input_dim,
                                       self.layer_dim, dropout=neigh_drop_rate[2])
        self.gnew_agg_user.to_float(mstype.float16)

        self.gnew_agg_item = AttenConv('gnew_agg_att_item', self.input_dim,
                                       self.layer_dim, dropout=neigh_drop_rate[2])
        self.gnew_agg_item.to_float(mstype.float16)

        self.user_feature_dim = self.input_dim
        self.item_feature_dim = self.input_dim

        self.final_weight = Parameter(
            initializer("XavierUniform", [self.input_dim * 3, self.input_dim * 3], dtype=mstype.float32),
            name='final_weight')

        self.raw_agg_funcs_user = MeanConv('raw_agg_user', self.input_dim, self.layer_dim,
                                           activation=activation, dropout=neigh_drop_rate[0])
        self.raw_agg_funcs_user.to_float(mstype.float16)

        self.raw_agg_funcs_item = MeanConv('raw_agg_item', self.input_dim, self.layer_dim,
                                           activation=activation, dropout=neigh_drop_rate[0])
        self.raw_agg_funcs_item.to_float(mstype.float16)

    def construct(self,
                  u_id, pos_item_id, neg_item_id, pos_users, pos_items,
                  u_group_nodes, u_neighs, u_gnew_neighs,
                  i_group_nodes, i_neighs, i_gnew_neighs,
                  neg_group_nodes, neg_neighs, neg_gnew_neighs,
                  neg_item_num):
        """Aggregate user and item embeddings"""
        all_user_embed = self.gather(self.user_embeddings, self.concat_0((u_id, pos_users)), 0)

        u_self_matrix_at_layers = self.gather(self.user_embeddings, u_group_nodes, 0)
        u_neigh_matrix_at_layers = self.gather(self.item_embeddings, u_neighs, 0)

        u_output_mean = self.raw_agg_funcs_user(u_self_matrix_at_layers, u_neigh_matrix_at_layers)

        u_gnew_neighs_matrix = self.gather(self.item_embeddings, u_gnew_neighs, 0)
        u_output_from_gnew_mean = self.gnew_agg_mean(u_self_matrix_at_layers, u_gnew_neighs_matrix)

        u_output_from_gnew_att = self.gnew_agg_user(u_self_matrix_at_layers,
                                                    self.concat_1((u_neigh_matrix_at_layers, u_gnew_neighs_matrix)))

        u_output = self.concat_1((u_output_mean, u_output_from_gnew_mean, u_output_from_gnew_att))
        all_user_rep = self.tanh(u_output)

        all_pos_item_embed = self.gather(self.item_embeddings, self.concat_0((pos_item_id, pos_items)), 0)

        i_self_matrix_at_layers = self.gather(self.item_embeddings, i_group_nodes, 0)
        i_neigh_matrix_at_layers = self.gather(self.user_embeddings, i_neighs, 0)

        i_output_mean = self.raw_agg_funcs_item(i_self_matrix_at_layers, i_neigh_matrix_at_layers)

        i_gnew_neighs_matrix = self.gather(self.user_embeddings, i_gnew_neighs, 0)
        i_output_from_gnew_mean = self.gnew_agg_mean(i_self_matrix_at_layers, i_gnew_neighs_matrix)

        i_output_from_gnew_att = self.gnew_agg_item(i_self_matrix_at_layers,
                                                    self.concat_1((i_neigh_matrix_at_layers, i_gnew_neighs_matrix)))

        i_output = self.concat_1((i_output_mean, i_output_from_gnew_mean, i_output_from_gnew_att))
        all_pos_item_rep = self.tanh(i_output)

        neg_item_embed = self.gather(self.item_embeddings, neg_item_id, 0)

        neg_self_matrix_at_layers = self.gather(self.item_embeddings, neg_group_nodes, 0)
        neg_neigh_matrix_at_layers = self.gather(self.user_embeddings, neg_neighs, 0)

        neg_output_mean = self.raw_agg_funcs_item(neg_self_matrix_at_layers, neg_neigh_matrix_at_layers)

        neg_gnew_neighs_matrix = self.gather(self.user_embeddings, neg_gnew_neighs, 0)
        neg_output_from_gnew_mean = self.gnew_agg_mean(neg_self_matrix_at_layers, neg_gnew_neighs_matrix)

        neg_output_from_gnew_att = self.gnew_agg_item(neg_self_matrix_at_layers,
                                                      self.concat_1((neg_neigh_matrix_at_layers,
                                                                     neg_gnew_neighs_matrix)))

        neg_output = self.concat_1((neg_output_mean, neg_output_from_gnew_mean, neg_output_from_gnew_att))
        neg_output = self.tanh(neg_output)

        neg_output_shape = self.shape(neg_output)
        neg_item_rep = self.reshape(neg_output,
                                    (self.shape(neg_item_embed)[0], neg_item_num, neg_output_shape[-1]))

        return all_user_embed, all_user_rep, all_pos_item_embed, all_pos_item_rep, neg_item_embed, neg_item_rep
@ -0,0 +1,374 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
callback
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.common.parameter import ParameterTuple
|
||||||
|
|
||||||
|
from src.utils import convert_item_id
|
||||||
|
|
||||||
|
|
||||||
|
def TestBGCF(forward_net, num_user, num_item, input_dim, test_graph_dataset):
|
||||||
|
"""BGCF test wrapper"""
|
||||||
|
user_reps = np.zeros([num_user, input_dim * 3])
|
||||||
|
item_reps = np.zeros([num_item, input_dim * 3])
|
||||||
|
|
||||||
|
for _ in range(50):
|
||||||
|
test_graph_dataset.random_select_sampled_graph()
|
||||||
|
u_test_neighs, u_test_gnew_neighs = test_graph_dataset.get_user_sapmled_neighbor()
|
||||||
|
i_test_neighs, i_test_gnew_neighs = test_graph_dataset.get_item_sampled_neighbor()
|
||||||
|
|
||||||
|
u_test_neighs = Tensor(convert_item_id(u_test_neighs, num_user), mstype.int32)
|
||||||
|
u_test_gnew_neighs = Tensor(convert_item_id(u_test_gnew_neighs, num_user), mstype.int32)
|
||||||
|
i_test_neighs = Tensor(i_test_neighs, mstype.int32)
|
||||||
|
i_test_gnew_neighs = Tensor(i_test_gnew_neighs, mstype.int32)
|
||||||
|
|
||||||
|
users = Tensor(np.arange(num_user).reshape(-1,), mstype.int32)
|
||||||
|
items = Tensor(np.arange(num_item).reshape(-1,), mstype.int32)
|
||||||
|
neg_items = Tensor(np.arange(num_item).reshape(-1, 1), mstype.int32)
|
||||||
|
|
||||||
|
user_rep, item_rep = forward_net(users,
|
||||||
|
items,
|
||||||
|
neg_items,
|
||||||
|
u_test_neighs,
|
||||||
|
u_test_gnew_neighs,
|
||||||
|
i_test_neighs,
|
||||||
|
i_test_gnew_neighs)
|
||||||
|
|
||||||
|
user_reps += user_rep.asnumpy()
|
||||||
|
item_reps += item_rep.asnumpy()
|
||||||
|
|
||||||
|
user_reps /= 50
|
||||||
|
item_reps /= 50
|
||||||
|
return user_reps, item_reps
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardBGCF(nn.Cell):
|
||||||
|
"""Calculate the forward output"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
network):
|
||||||
|
super(ForwardBGCF, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self, users, items, neg_items, u_neighs, u_gnew_neighs, i_neighs, i_gnew_neighs):
|
||||||
|
"""Calculate the user and item representation"""
|
||||||
|
_, user_rep, _, item_rep, _, _, = self.network(users,
|
||||||
|
items,
|
||||||
|
neg_items,
|
||||||
|
users,
|
||||||
|
items,
|
||||||
|
users,
|
||||||
|
u_neighs,
|
||||||
|
u_gnew_neighs,
|
||||||
|
items,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
items,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
1)
|
||||||
|
return user_rep, item_rep
|
||||||
|
|
||||||
|
|
||||||
|
class BGCFLoss(nn.Cell):
|
||||||
|
"""BGCF loss with user and item embedding"""
|
||||||
|
|
||||||
|
def __init__(self, neg_item_num, l2_embed, dist_reg):
|
||||||
|
super(BGCFLoss, self).__init__()
|
||||||
|
|
||||||
|
self.neg_item_num = neg_item_num
|
||||||
|
self.l2_embed = l2_embed
|
||||||
|
self.dist_reg = dist_reg
|
||||||
|
|
||||||
|
self.log = P.Log()
|
||||||
|
self.pow = P.Pow()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.tile = P.Tile()
|
||||||
|
self.shape = P.Shape()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.concat = P.Concat(1)
|
||||||
|
self.concat2 = P.Concat(2)
|
||||||
|
self.split = P.Split(0, 2)
|
||||||
|
self.reduce_sum = P.ReduceSum()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.multiply = P.Mul()
|
||||||
|
self.matmul = P.BatchMatMul()
|
||||||
|
self.squeeze = P.Squeeze(1)
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
self.l2_loss = P.L2Loss()
|
||||||
|
self.sigmoid = P.Sigmoid()
|
||||||
|
|
||||||
|
def construct(self, all_user_embed, all_user_rep, all_pos_item_embed,
|
||||||
|
all_pos_item_rep, neg_item_embed, neg_item_rep):
|
||||||
|
"""Calculate loss"""
|
||||||
|
all_user_embed = self.cast(all_user_embed, mstype.float16)
|
||||||
|
all_user_rep = self.concat((all_user_rep, all_user_embed))
|
||||||
|
|
||||||
|
user_rep, pos_user_rep = self.split(all_user_rep)
|
||||||
|
user_embed, pos_user_embed = self.split(all_user_embed)
|
||||||
|
|
||||||
|
user_user_distance = self.reduce_sum(self.pow(user_rep - pos_user_rep, 2)) \
|
||||||
|
+ self.reduce_sum(self.pow(user_embed - pos_user_embed, 2))
|
||||||
|
user_user_distance = self.cast(user_user_distance, mstype.float32)
|
||||||
|
|
||||||
|
user_rep = self.expand_dims(user_rep, 1)
|
||||||
|
|
||||||
|
all_pos_item_embed = self.cast(all_pos_item_embed, mstype.float16)
|
||||||
|
all_pos_item_rep = self.concat((all_pos_item_rep, all_pos_item_embed))
|
||||||
|
|
||||||
|
pos_item_rep, pos_item_neigh_rep = self.split(all_pos_item_rep)
|
||||||
|
pos_item_embed, pos_item_neigh_embed = self.split(all_pos_item_embed)
|
||||||
|
|
||||||
|
pos_item_item_distance = self.reduce_sum(self.pow(pos_item_rep - pos_item_neigh_rep, 2)) \
|
||||||
|
+ self.reduce_sum(self.pow(pos_item_embed - pos_item_neigh_embed, 2))
|
||||||
|
pos_item_item_distance = self.cast(pos_item_item_distance, mstype.float32)
|
||||||
|
|
||||||
|
neg_item_embed = self.cast(neg_item_embed, mstype.float16)
|
||||||
|
neg_item_rep = self.concat2((neg_item_rep, neg_item_embed))
|
||||||
|
|
||||||
|
item_rep = self.concat((self.expand_dims(pos_item_rep, 1), neg_item_rep))
|
||||||
|
|
||||||
|
pos_rating = self.reduce_sum(self.multiply(self.squeeze(user_rep), pos_item_rep), 1)
|
||||||
|
pos_rating = self.expand_dims(pos_rating, 1)
|
||||||
|
pos_rating = self.tile(pos_rating, (1, self.neg_item_num))
|
||||||
|
pos_rating = self.reshape(pos_rating, (self.shape(pos_rating)[0] * self.neg_item_num, 1))
|
||||||
|
pos_rating = self.cast(pos_rating, mstype.float32)
|
||||||
|
|
||||||
|
batch_neg_item_embedding = self.transpose(neg_item_rep, (0, 2, 1))
|
||||||
|
neg_rating = self.matmul(user_rep, batch_neg_item_embedding)
|
||||||
|
neg_rating = self.squeeze(neg_rating)
|
||||||
|
neg_rating = self.reshape(neg_rating, (self.shape(neg_rating)[0] * self.neg_item_num, 1))
|
||||||
|
neg_rating = self.cast(neg_rating, mstype.float32)
|
||||||
|
|
||||||
|
bpr_loss = pos_rating - neg_rating
|
||||||
|
bpr_loss = self.sigmoid(bpr_loss)
|
||||||
|
bpr_loss = - self.log(bpr_loss)
|
||||||
|
bpr_loss = self.reduce_sum(bpr_loss)
|
||||||
|
|
||||||
|
reg_loss = self.l2_embed * (self.l2_loss(user_rep) + self.l2_loss(item_rep))
|
||||||
|
|
||||||
|
loss = bpr_loss + reg_loss + self.dist_reg * (user_user_distance + pos_item_item_distance)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class LossWrapper(nn.Cell):
|
||||||
|
"""
|
||||||
|
Wraps the BGCF model with loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): BGCF network.
|
||||||
|
neg_item_num (Number): The num of negative instances for a positive instance.
|
||||||
|
l2_embed (Number): The coefficient of l2 loss.
|
||||||
|
dist_reg (Number): The coefficient of distance loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, neg_item_num, l2_embed, dist_reg=0.002):
|
||||||
|
super(LossWrapper, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.loss_func = BGCFLoss(neg_item_num, l2_embed, dist_reg)
|
||||||
|
|
||||||
|
def construct(self,
|
||||||
|
u_id,
|
||||||
|
pos_item_id,
|
||||||
|
neg_item_id,
|
||||||
|
pos_users,
|
||||||
|
pos_items,
|
||||||
|
u_group_nodes,
|
||||||
|
u_neighs,
|
||||||
|
u_gnew_neighs,
|
||||||
|
i_group_nodes,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
neg_group_nodes,
|
||||||
|
neg_neighs,
|
||||||
|
neg_gnew_neighs):
|
||||||
|
"""Return loss"""
|
||||||
|
all_user_embed, all_user_rep, all_pos_item_embed, \
|
||||||
|
all_pos_item_rep, neg_item_embed, neg_item_rep = self.network(u_id,
|
||||||
|
pos_item_id,
|
||||||
|
neg_item_id,
|
||||||
|
pos_users,
|
||||||
|
pos_items,
|
||||||
|
u_group_nodes,
|
||||||
|
u_neighs,
|
||||||
|
u_gnew_neighs,
|
||||||
|
i_group_nodes,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
neg_group_nodes,
|
||||||
|
neg_neighs,
|
||||||
|
neg_gnew_neighs,
|
||||||
|
10)
|
||||||
|
loss = self.loss_func(all_user_embed, all_user_rep, all_pos_item_embed,
|
||||||
|
all_pos_item_rep, neg_item_embed, neg_item_rep)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class TrainOneStepCell(nn.Cell):
|
||||||
|
r"""
|
||||||
|
Network training package class.
|
||||||
|
|
||||||
|
Wraps the network with an optimizer. The resulting Cell can be trained with sample inputs.
|
||||||
|
Backward graph will be created in the construct function to do parameter updating. Different
|
||||||
|
parallel modes are available to run the training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The training network.
|
||||||
|
optimizer (Cell): Optimizer for updating the weights.
|
||||||
|
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, a scalar Tensor with shape :math:`()`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = Net()
|
||||||
|
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
|
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
||||||
|
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, optimizer, sens=1.0):
|
||||||
|
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||||
|
|
||||||
|
self.network = network
|
||||||
|
self.network.add_flags(defer_inline=True)
|
||||||
|
self.weights = ParameterTuple(network.trainable_params())
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||||
|
self.sens = sens
|
||||||
|
|
||||||
|
def construct(self,
|
||||||
|
u_id,
|
||||||
|
pos_item_id,
|
||||||
|
neg_item_id,
|
||||||
|
pos_users,
|
||||||
|
pos_items,
|
||||||
|
u_group_nodes,
|
||||||
|
u_neighs,
|
||||||
|
u_gnew_neighs,
|
||||||
|
i_group_nodes,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
neg_group_nodes,
|
||||||
|
neg_neighs,
|
||||||
|
neg_gnew_neighs):
|
||||||
|
"""Grad process"""
|
||||||
|
weights = self.weights
|
||||||
|
loss = self.network(u_id,
|
||||||
|
pos_item_id,
|
||||||
|
neg_item_id,
|
||||||
|
pos_users,
|
||||||
|
pos_items,
|
||||||
|
u_group_nodes,
|
||||||
|
u_neighs,
|
||||||
|
u_gnew_neighs,
|
||||||
|
i_group_nodes,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
neg_group_nodes,
|
||||||
|
neg_neighs,
|
||||||
|
neg_gnew_neighs)
|
||||||
|
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||||
|
grads = self.grad(self.network, weights)(u_id,
|
||||||
|
pos_item_id,
|
||||||
|
neg_item_id,
|
||||||
|
pos_users,
|
||||||
|
pos_items,
|
||||||
|
u_group_nodes,
|
||||||
|
u_neighs,
|
||||||
|
u_gnew_neighs,
|
||||||
|
i_group_nodes,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
neg_group_nodes,
|
||||||
|
neg_neighs,
|
||||||
|
neg_gnew_neighs,
|
||||||
|
sens)
|
||||||
|
return F.depend(loss, self.optimizer(grads))
|
||||||
|
|
||||||
|
|
||||||
|
class TrainBGCF(nn.Cell):
|
||||||
|
"""
|
||||||
|
Wraps the BGCF model with optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): BGCF network.
|
||||||
|
neg_item_num (Number): The num of negative instances for a positive instance.
|
||||||
|
l2_embed (Number): The coefficient of l2 loss.
|
||||||
|
learning_rate (Number): The learning rate.
|
||||||
|
epsilon (Number): The term added to the denominator to improve numerical stability.
|
||||||
|
dist_reg (Number): The coefficient of distance loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
network,
|
||||||
|
neg_item_num,
|
||||||
|
l2_embed,
|
||||||
|
learning_rate,
|
||||||
|
epsilon,
|
||||||
|
dist_reg=0.002):
|
||||||
|
super(TrainBGCF, self).__init__(auto_prefix=False)
|
||||||
|
|
||||||
|
self.network = network
|
||||||
|
loss_net = LossWrapper(network,
|
||||||
|
neg_item_num,
|
||||||
|
l2_embed,
|
||||||
|
dist_reg)
|
||||||
|
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
eps=epsilon)
|
||||||
|
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
||||||
|
|
||||||
|
def construct(self,
|
||||||
|
u_id,
|
||||||
|
pos_item_id,
|
||||||
|
neg_item_id,
|
||||||
|
pos_users,
|
||||||
|
pos_items,
|
||||||
|
u_group_nodes,
|
||||||
|
u_neighs,
|
||||||
|
u_gnew_neighs,
|
||||||
|
i_group_nodes,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
neg_group_nodes,
|
||||||
|
neg_neighs,
|
||||||
|
neg_gnew_neighs):
|
||||||
|
"""Return loss"""
|
||||||
|
loss = self.loss_train_net(u_id,
|
||||||
|
pos_item_id,
|
||||||
|
neg_item_id,
|
||||||
|
pos_users,
|
||||||
|
pos_items,
|
||||||
|
u_group_nodes,
|
||||||
|
u_neighs,
|
||||||
|
u_gnew_neighs,
|
||||||
|
i_group_nodes,
|
||||||
|
i_neighs,
|
||||||
|
i_gnew_neighs,
|
||||||
|
neg_group_nodes,
|
||||||
|
neg_neighs,
|
||||||
|
neg_gnew_neighs)
|
||||||
|
return loss
|
|
@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
import argparse


def parser_args():
    """Config for BGCF"""
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", type=str, default="Beauty")
    parser.add_argument("-dpath", "--datapath", type=str, default="./scripts/data_mr")
    parser.add_argument("-de", "--device", type=str, default='0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--Ks', type=list, default=[5, 10, 20, 100])
    parser.add_argument('--test_ratio', type=float, default=0.2)
    parser.add_argument('--val_ratio', type=float, default=None)
    parser.add_argument('-w', '--workers', type=int, default=10)

    parser.add_argument("-eps", "--epsilon", type=float, default=1e-8)
    parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3)
    parser.add_argument("-l2", "--l2", type=float, default=0.03)
    parser.add_argument("-wd", "--weight_decay", type=float, default=0.01)
    parser.add_argument("-act", "--activation", type=str, default='tanh', choices=['relu', 'tanh'])
    parser.add_argument("-ndrop", "--neighbor_dropout", type=list, default=[0.0, 0.2, 0.3])
    parser.add_argument("-log", "--log_name", type=str, default='test')

    parser.add_argument("-e", "--num_epoch", type=int, default=600)
    parser.add_argument('-input', '--input_dim', type=int, default=64, choices=[64, 128])
    parser.add_argument("-b", "--batch_pairs", type=int, default=5000)
    parser.add_argument('--eval_interval', type=int, default=20)

    parser.add_argument("-neg", "--num_neg", type=int, default=10)
    parser.add_argument('-max', '--max_degree', type=str, default='[128,128]')
    parser.add_argument("-g1", "--raw_neighs", type=int, default=40)
    parser.add_argument("-g2", "--gnew_neighs", type=int, default=20)
    parser.add_argument("-emb", "--embedded_dimension", type=int, default=64)
    parser.add_argument('-dist', '--distance', type=str, default='iou')
    parser.add_argument('--dist_reg', type=float, default=0.003)

    parser.add_argument('-ng', '--num_graphs', type=int, default=5)
    parser.add_argument('-geps', '--graph_epsilon', type=float, default=0.01)

    return parser.parse_args()
@ -0,0 +1,191 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
preprocess raw data; generate batched data and sample neighbors on the graph for training and test;
the Amazon-Beauty dataset is supported by this example; the original version of the dataset is as follows:
|
||||||
|
@article{Amazon Beauty,
|
||||||
|
title = {Ups and Downs: Modeling the Visual Evolution of Fashion Trends with One-Class Collaborative Filtering},
|
||||||
|
author = {R. He, J. McAuley},
|
||||||
|
journal = {WWW},
|
||||||
|
year = {2016},
|
||||||
|
url = {http://jmcauley.ucsd.edu/data/amazon}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
|
||||||
|
|
||||||
|
class RandomBatchedSampler(ds.Sampler):
|
||||||
|
"""RandomBatchedSampler generate random sequence without replacement in a batched manner"""
|
||||||
|
|
||||||
|
sampled_graph_index = 0
|
||||||
|
|
||||||
|
def __init__(self, index_range, num_edges_per_sample):
|
||||||
|
super().__init__()
|
||||||
|
self.index_range = index_range
|
||||||
|
self.num_edges_per_sample = num_edges_per_sample
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.sampled_graph_index += 1
|
||||||
|
indices = [i for i in range(self.index_range)]
|
||||||
|
np.random.shuffle(indices)
|
||||||
|
for i in range(0, self.index_range, self.num_edges_per_sample):
|
||||||
|
if i + self.num_edges_per_sample <= self.index_range:
|
||||||
|
result = indices[i: i + self.num_edges_per_sample]
|
||||||
|
result.append(self.sampled_graph_index)
|
||||||
|
yield result
|
||||||
|
|
||||||
|
|
||||||
|
class TrainGraphDataset():
|
||||||
|
"""Sample node neighbors on graphs for training"""
|
||||||
|
|
||||||
|
def __init__(self, train_graph, sampled_graphs, batch_num, num_samples, num_bgcn_neigh, num_neg):
|
||||||
|
self.g = train_graph
|
||||||
|
self.batch_num = batch_num
|
||||||
|
self.sampled_graphs = sampled_graphs
|
||||||
|
self.sampled_graph_num = len(sampled_graphs)
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.num_bgcn_neigh = num_bgcn_neigh
|
||||||
|
self.num_neg = num_neg
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.g.graph_info()['edge_num'][0] // self.batch_num
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""
|
||||||
|
Sample negative items with their neighbors, user neighbors, pos item neighbors
|
||||||
|
based on the user-item pairs
|
||||||
|
"""
|
||||||
|
sampled_graph_index = index[-1] % self.sampled_graph_num
|
||||||
|
index = index[0:-1]
|
||||||
|
train_graph = self.g
|
||||||
|
sampled_graph = self.sampled_graphs[sampled_graph_index]
|
||||||
|
|
||||||
|
rating = train_graph.get_nodes_from_edges(index.astype(np.int32))
|
||||||
|
users = rating[:, 0]
|
||||||
|
|
||||||
|
u_group_nodes = train_graph.get_sampled_neighbors(
|
||||||
|
node_list=users, neighbor_nums=[1], neighbor_types=[0])
|
||||||
|
pos_users = u_group_nodes[:, 1]
|
||||||
|
u_group_nodes = np.concatenate((users, pos_users), axis=0)
|
||||||
|
u_group_nodes = u_group_nodes.reshape(-1,).tolist()
|
||||||
|
u_neighs = train_graph.get_sampled_neighbors(
|
||||||
|
node_list=u_group_nodes, neighbor_nums=[self.num_samples], neighbor_types=[1])
|
||||||
|
u_neighs = u_neighs[:, 1:]
|
||||||
|
u_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||||
|
node_list=u_group_nodes, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[1])
|
||||||
|
u_gnew_neighs = u_gnew_neighs[:, 1:]
|
||||||
|
|
||||||
|
items = rating[:, 1]
|
||||||
|
i_group_nodes = train_graph.get_sampled_neighbors(
|
||||||
|
node_list=items, neighbor_nums=[1], neighbor_types=[1])
|
||||||
|
pos_items = i_group_nodes[:, 1]
|
||||||
|
i_group_nodes = np.concatenate((items, pos_items), axis=0)
|
||||||
|
i_group_nodes = i_group_nodes.reshape(-1,).tolist()
|
||||||
|
i_neighs = train_graph.get_sampled_neighbors(
|
||||||
|
node_list=i_group_nodes, neighbor_nums=[self.num_samples], neighbor_types=[0])
|
||||||
|
i_neighs = i_neighs[:, 1:]
|
||||||
|
i_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||||
|
node_list=i_group_nodes, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[0])
|
||||||
|
i_gnew_neighs = i_gnew_neighs[:, 1:]
|
||||||
|
|
||||||
|
neg_item_id = train_graph.get_neg_sampled_neighbors(
|
||||||
|
node_list=users, neg_neighbor_num=self.num_neg, neg_neighbor_type=1)
|
||||||
|
neg_item_id = neg_item_id[:, 1:]
|
||||||
|
neg_group_nodes = neg_item_id.reshape(-1,)
|
||||||
|
neg_neighs = train_graph.get_sampled_neighbors(
|
||||||
|
node_list=neg_group_nodes, neighbor_nums=[self.num_samples], neighbor_types=[0])
|
||||||
|
neg_neighs = neg_neighs[:, 1:]
|
||||||
|
neg_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||||
|
node_list=neg_group_nodes, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[0])
|
||||||
|
neg_gnew_neighs = neg_gnew_neighs[:, 1:]
|
||||||
|
|
||||||
|
return users, items, neg_item_id, pos_users, pos_items, u_group_nodes, u_neighs, u_gnew_neighs, \
|
||||||
|
i_group_nodes, i_neighs, i_gnew_neighs, neg_group_nodes, neg_neighs, neg_gnew_neighs
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphDataset():
|
||||||
|
"""Sample node neighbors on graphs for test"""
|
||||||
|
|
||||||
|
def __init__(self, g, sampled_graphs, num_samples, num_bgcn_neigh, num_neg):
|
||||||
|
self.g = g
|
||||||
|
self.sampled_graphs = sampled_graphs
|
||||||
|
self.sampled_graph_index = 0
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.num_bgcn_neigh = num_bgcn_neigh
|
||||||
|
self.num_neg = num_neg
|
||||||
|
self.num_user = self.g.graph_info()["node_num"][0]
|
||||||
|
self.num_item = self.g.graph_info()["node_num"][1]
|
||||||
|
|
||||||
|
def random_select_sampled_graph(self):
|
||||||
|
self.sampled_graph_index = np.random.randint(len(self.sampled_graphs))
|
||||||
|
|
||||||
|
def get_user_sapmled_neighbor(self):
|
||||||
|
"""Sample all users neighbors for test"""
|
||||||
|
users = np.arange(self.num_user, dtype=np.int32)
|
||||||
|
u_neighs = self.g.get_sampled_neighbors(
|
||||||
|
node_list=users, neighbor_nums=[self.num_samples], neighbor_types=[1])
|
||||||
|
u_neighs = u_neighs[:, 1:]
|
||||||
|
sampled_graph = self.sampled_graphs[self.sampled_graph_index]
|
||||||
|
u_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||||
|
node_list=users, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[1])
|
||||||
|
u_gnew_neighs = u_gnew_neighs[:, 1:]
|
||||||
|
return u_neighs, u_gnew_neighs
|
||||||
|
|
||||||
|
def get_item_sampled_neighbor(self):
|
||||||
|
"""Sample all items neighbors for test"""
|
||||||
|
items = np.arange(self.num_user, self.num_user + self.num_item, dtype=np.int32)
|
||||||
|
i_neighs = self.g.get_sampled_neighbors(
|
||||||
|
node_list=items, neighbor_nums=[self.num_samples], neighbor_types=[0])
|
||||||
|
i_neighs = i_neighs[:, 1:]
|
||||||
|
|
||||||
|
sampled_graph = self.sampled_graphs[self.sampled_graph_index]
|
||||||
|
i_gnew_neighs = sampled_graph.get_sampled_neighbors(
|
||||||
|
node_list=items, neighbor_nums=[self.num_bgcn_neigh], neighbor_types=[0])
|
||||||
|
i_gnew_neighs = i_gnew_neighs[:, 1:]
|
||||||
|
return i_neighs, i_gnew_neighs
|
||||||
|
|
||||||
|
|
||||||
|
def load_graph(data_path):
|
||||||
|
"""Load train graph, test graph and sampled graph"""
|
||||||
|
train_graph = ds.GraphData(
|
||||||
|
data_path + "/train_mr", num_parallel_workers=8)
|
||||||
|
|
||||||
|
test_graph = ds.GraphData(
|
||||||
|
data_path + "/test_mr", num_parallel_workers=8)
|
||||||
|
|
||||||
|
sampled_graph_list = []
|
||||||
|
for i in range(0, 5):
|
||||||
|
sampled_graph = ds.GraphData(
|
||||||
|
data_path + "/sampled" + str(i) + "_mr", num_parallel_workers=8)
|
||||||
|
sampled_graph_list.append(sampled_graph)
|
||||||
|
|
||||||
|
return train_graph, test_graph, sampled_graph_list
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(train_graph, sampled_graph_list, batch_size=32, repeat_size=1, num_samples=40, num_bgcn_neigh=20,
|
||||||
|
num_neg=10):
|
||||||
|
"""Data generator for training"""
|
||||||
|
edge_num = train_graph.graph_info()['edge_num'][0]
|
||||||
|
out_column_names = ["users", "items", "neg_item_id", "pos_users", "pos_items", "u_group_nodes", "u_neighs",
|
||||||
|
"u_gnew_neighs", "i_group_nodes", "i_neighs", "i_gnew_neighs", "neg_group_nodes",
|
||||||
|
"neg_neighs", "neg_gnew_neighs"]
|
||||||
|
train_graph_dataset = TrainGraphDataset(
|
||||||
|
train_graph, sampled_graph_list, batch_size, num_samples, num_bgcn_neigh, num_neg)
|
||||||
|
dataset = ds.GeneratorDataset(source=train_graph_dataset, column_names=out_column_names,
|
||||||
|
sampler=RandomBatchedSampler(edge_num, batch_size), num_parallel_workers=8)
|
||||||
|
dataset = dataset.repeat(repeat_size)
|
||||||
|
|
||||||
|
return dataset
|
|
@ -0,0 +1,184 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Recommendation metrics
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import heapq
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.utils import convert_item_id
|
||||||
|
|
||||||
|
|
||||||
|
def ndcg_k(actual, predicted, topk):
|
||||||
|
"""Calculates the normalized discounted cumulative gain at k"""
|
||||||
|
idcg = idcg_k(actual, topk)
|
||||||
|
res = 0
|
||||||
|
|
||||||
|
dcg_k = sum([int(predicted[j] in set(actual)) / math.log(j + 2, 2) for j in range(topk)])
|
||||||
|
res += dcg_k / idcg
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def idcg_k(actual, k):
|
||||||
|
"""Calculates the ideal discounted cumulative gain at k"""
|
||||||
|
res = sum([1.0 / math.log(i + 2, 2) for i in range(min(k, len(actual)))])
|
    return 1.0 if not res else res


def recall_at_k_2(r, k, all_pos_num):
    """Calculates the recall at k"""
    r = np.asfarray(r)[:k]
    return np.sum(r) / all_pos_num


def novelty_at_k(topk_items, item_degree_dict, num_user, k):
    """Calculate the novelty at k"""
    avg_nov = []
    for item in topk_items[:k]:
        avg_nov.append(-np.log2(item_degree_dict[item] / num_user))
    return np.mean(avg_nov)


def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
    """Return the n largest score from the item_score by heap algorithm"""
    item_score = {}
    for i in test_items:
        item_score[i] = rating[i]

    K_max = max(Ks)
    K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)

    r = []
    for i in K_max_item_score:
        if i in user_pos_test:
            r.append(1)
        else:
            r.append(0)
    return r, K_max_item_score


def get_performance(user_pos_test, r, K_max_item, item_degree_dict, num_user, Ks):
    """Wraps the model metrics"""
    recall, ndcg, novelty = [], [], []

    for K in Ks:
        recall.append(recall_at_k_2(r, K, len(user_pos_test)))
        ndcg.append(ndcg_k(user_pos_test, K_max_item, K))
        novelty.append(novelty_at_k(K_max_item, item_degree_dict, num_user, K))
    return {'recall': np.array(recall), 'ndcg': np.array(ndcg), 'nov': np.array(novelty)}


class BGCFEvaluate:
    """
    Evaluate the model recommendation performance
    """

    def __init__(self, parser, train_graph, test_graph, Ks):
        self.num_user = train_graph.graph_info()["node_num"][0]
        self.num_item = train_graph.graph_info()["node_num"][1]
        self.Ks = Ks

        self.test_set = []
        self.train_set = []
        for i in range(0, self.num_user):
            train_item = train_graph.get_all_neighbors(node_list=[i], neighbor_type=1)
            train_item = train_item[1:]
            self.train_set.append(train_item)
        for i in range(0, self.num_user):
            test_item = test_graph.get_all_neighbors(node_list=[i], neighbor_type=1)
            test_item = test_item[1:]
            self.test_set.append(test_item)
        self.train_set = convert_item_id(self.train_set, self.num_user).tolist()
        self.test_set = convert_item_id(self.test_set, self.num_user).tolist()

        self.item_deg_dict = {}
        self.item_full_set = []
        for i in range(self.num_user, self.num_user + self.num_item):
            train_users = train_graph.get_all_neighbors(node_list=[i], neighbor_type=0)
            train_users = train_users.tolist()
            if isinstance(train_users, int):
                train_users = []
            else:
                train_users = train_users[1:]
            self.item_deg_dict[i - self.num_user] = len(train_users)
            test_users = test_graph.get_all_neighbors(node_list=[i], neighbor_type=0)
            test_users = test_users.tolist()
            if isinstance(test_users, int):
                test_users = []
            else:
                test_users = test_users[1:]
            self.item_full_set.append(train_users + test_users)

    def test_one_user(self, x):
        """Calculate one user metrics"""
        rating = x[0]
        u = x[1]

        training_items = self.train_set[u]

        user_pos_test = self.test_set[u]

        all_items = set(range(self.num_item))

        test_items = list(all_items - set(training_items))

        r, k_max_items = ranklist_by_heapq(user_pos_test, test_items, rating, self.Ks)

        return get_performance(user_pos_test, r, k_max_items, self.item_deg_dict, self.num_user, self.Ks), \
               [k_max_items[:self.Ks[x]] for x in range(len(self.Ks))]

    def eval_with_rep(self, user_rep, item_rep, parser):
        """Evaluation with user and item rep"""
        result = {'recall': np.zeros(len(self.Ks)), 'ndcg': np.zeros(len(self.Ks)),
                  'nov': np.zeros(len(self.Ks))}
        pool = Pool(parser.workers)
        user_indexes = np.arange(self.num_user)

        rating_preds = user_rep @ item_rep.transpose()
        user_rating_uid = zip(rating_preds, user_indexes)
        all_result = pool.map(self.test_one_user, user_rating_uid)

        top20 = []

        for re in all_result:
            result['recall'] += re[0]['recall'] / self.num_user
            result['ndcg'] += re[0]['ndcg'] / self.num_user
            result['nov'] += re[0]['nov'] / self.num_user
            top20.append(re[1][2])

        pool.close()

        sedp = [[] for i in range(len(self.Ks) - 1)]

        num_all_links = np.sum([len(x) for x in self.item_full_set])

        for k in range(len(self.Ks) - 1):
            for u in range(self.num_user):
                diff = []
                pred_items_at_k = all_result[u][1][k]
                for item in pred_items_at_k:
                    if item in self.test_set[u]:
                        avg_prob_all_user = len(self.item_full_set[item]) / num_all_links
                        diff.append(max((self.Ks[k] - pred_items_at_k.index(item) - 1)
                                        / (self.Ks[k] - 1) - avg_prob_all_user, 0))
                one_user_sedp = sum(diff) / self.Ks[k]
                sedp[k].append(one_user_sedp)

        sedp = np.array(sedp).mean(1)

        return result['recall'].tolist(), result['ndcg'].tolist(), \
               [sedp[1], sedp[2]], result['nov'].tolist()
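
For reference, a minimal sketch of how the ranking helpers above fit together for a single user. The scores, positives, and degrees below are toy numbers, and the snippet assumes the `src` package is importable from the project root:

```python
import numpy as np
from src.metrics import ranklist_by_heapq, recall_at_k_2, novelty_at_k

rating = np.array([0.9, 0.1, 0.8, 0.3, 0.7, 0.2])   # model scores for candidate items 0..5
user_pos_test = [0, 4]                               # held-out positives for this user
test_items = [0, 1, 2, 3, 4, 5]                      # training items already filtered out
Ks = [1, 3, 5]

# r is a hit indicator over the top max(Ks) ranked items, topk the item ids in rank order
r, topk = ranklist_by_heapq(user_pos_test, test_items, rating, Ks)   # r == [1, 0, 1, 0, 0]
item_degree = {0: 5, 1: 1, 2: 4, 3: 2, 4: 3, 5: 1}                   # toy training degrees

print(recall_at_k_2(r, 3, len(user_pos_test)))            # 1.0: both positives sit in the top-3
print(novelty_at_k(topk, item_degree, num_user=10, k=3))  # mean of -log2(item popularity)
```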
@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for training BGCF"""
import os
import sys
import glob
import shutil
import pickle as pkl

import numpy as np


def load_pickle(path, name):
    """Load pickle"""
    with open(path + name, 'rb') as f:
        return pkl.load(f, encoding='latin1')


class BGCFLogger:
    """log the output metrics"""

    def __init__(self, logname, now, foldername, copy):
        self.terminal = sys.stdout
        self.file = None

        path = os.path.join(foldername, logname, now)
        os.makedirs(path)

        if copy:
            filenames = glob.glob('*.py')
            for filename in filenames:
                shutil.copy(filename, path)

    def open(self, file, mode=None):
        if mode is None:
            mode = 'w'
        self.file = open(file, mode)

    def write(self, message, is_terminal=True, is_file=True):
        """Write log"""
        if '\r' in message:
            is_file = False

        if is_terminal:
            self.terminal.write(message)
            self.terminal.flush()

        if is_file:
            self.file.write(message)
            self.file.flush()


def convert_item_id(item_list, num_user):
    """Convert the graph node id into item id"""
    return np.array(item_list) - num_user
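
A minimal usage sketch for the logger, mirroring how the training script drives it; the directory, timestamp, and metric values below are placeholders:

```python
from src.utils import BGCFLogger

log = BGCFLogger(logname='bgcf-demo', now='Jan_01_00_00_00', foldername='log-files', copy=False)
log.open('log-files/bgcf-demo/Jan_01_00_00_00/log.train.txt', mode='a')
log.write('recall_@20:0.15000\n')   # echoed to stdout and appended to the log file
```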
@ -0,0 +1,173 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
BGCF training script.
"""
import os
import datetime

from mindspore import Tensor
import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import save_checkpoint, load_checkpoint

from src.bgcf import BGCF
from src.metrics import BGCFEvaluate
from src.config import parser_args
from src.utils import BGCFLogger, convert_item_id
from src.callback import ForwardBGCF, TrainBGCF, TestBGCF
from src.dataset import load_graph, create_dataset, TestGraphDataset


def train_and_eval():
    """Train and eval"""
    num_user = train_graph.graph_info()["node_num"][0]
    num_item = train_graph.graph_info()["node_num"][1]
    num_pairs = train_graph.graph_info()['edge_num'][0]

    bgcfnet = BGCF([parser.input_dim, num_user, num_item],
                   parser.embedded_dimension,
                   parser.activation,
                   parser.neighbor_dropout,
                   num_user,
                   num_item,
                   parser.input_dim)

    train_net = TrainBGCF(bgcfnet, parser.num_neg, parser.l2, parser.learning_rate,
                          parser.epsilon, parser.dist_reg)
    train_net.set_train(True)

    eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)

    itr = train_ds.create_dict_iterator(parser.num_epoch)
    num_iter = int(num_pairs / parser.batch_pairs)

    for _epoch in range(1, parser.num_epoch + 1):

        iter_num = 1

        for data in itr:

            u_id = Tensor(data["users"], mstype.int32)
            pos_item_id = Tensor(convert_item_id(data["items"], num_user), mstype.int32)
            neg_item_id = Tensor(convert_item_id(data["neg_item_id"], num_user), mstype.int32)
            pos_users = Tensor(data["pos_users"], mstype.int32)
            pos_items = Tensor(convert_item_id(data["pos_items"], num_user), mstype.int32)

            u_group_nodes = Tensor(data["u_group_nodes"], mstype.int32)
            u_neighs = Tensor(convert_item_id(data["u_neighs"], num_user), mstype.int32)
            u_gnew_neighs = Tensor(convert_item_id(data["u_gnew_neighs"], num_user), mstype.int32)

            i_group_nodes = Tensor(convert_item_id(data["i_group_nodes"], num_user), mstype.int32)
            i_neighs = Tensor(data["i_neighs"], mstype.int32)
            i_gnew_neighs = Tensor(data["i_gnew_neighs"], mstype.int32)

            neg_group_nodes = Tensor(convert_item_id(data["neg_group_nodes"], num_user), mstype.int32)
            neg_neighs = Tensor(data["neg_neighs"], mstype.int32)
            neg_gnew_neighs = Tensor(data["neg_gnew_neighs"], mstype.int32)

            train_loss = train_net(u_id,
                                   pos_item_id,
                                   neg_item_id,
                                   pos_users,
                                   pos_items,
                                   u_group_nodes,
                                   u_neighs,
                                   u_gnew_neighs,
                                   i_group_nodes,
                                   i_neighs,
                                   i_gnew_neighs,
                                   neg_group_nodes,
                                   neg_neighs,
                                   neg_gnew_neighs)

            if iter_num == num_iter:
                print('Epoch', '%03d' % _epoch, 'iter', '%02d' % iter_num,
                      'loss',
                      '{}'.format(train_loss))
            iter_num += 1

        if _epoch % parser.eval_interval == 0:
            if os.path.exists("ckpts/bgcf.ckpt"):
                os.remove("ckpts/bgcf.ckpt")
            save_checkpoint(bgcfnet, "ckpts/bgcf.ckpt")

            bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
                                parser.embedded_dimension,
                                parser.activation,
                                [0.0, 0.0, 0.0],
                                num_user,
                                num_item,
                                parser.input_dim)

            load_checkpoint("ckpts/bgcf.ckpt", net=bgcfnet_test)

            forward_net = ForwardBGCF(bgcfnet_test)
            user_reps, item_reps = TestBGCF(forward_net, num_user, num_item, parser.input_dim, test_graph_dataset)

            test_recall_bgcf, test_ndcg_bgcf, \
            test_sedp, test_nov = eval_class.eval_with_rep(user_reps, item_reps, parser)

            if parser.log_name:
                log.write(
                    'epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
                    'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
                                                                                    test_recall_bgcf[1],
                                                                                    test_recall_bgcf[2],
                                                                                    test_ndcg_bgcf[1],
                                                                                    test_ndcg_bgcf[2],
                                                                                    test_sedp[0],
                                                                                    test_sedp[1],
                                                                                    test_nov[1],
                                                                                    test_nov[2]))
            else:
                print('epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
                      'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
                                                                                      test_recall_bgcf[1],
                                                                                      test_recall_bgcf[2],
                                                                                      test_ndcg_bgcf[1],
                                                                                      test_ndcg_bgcf[2],
                                                                                      test_sedp[0],
                                                                                      test_sedp[1],
                                                                                      test_nov[1],
                                                                                      test_nov[2]))


if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False)

    parser = parser_args()

    train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
    train_ds = create_dataset(train_graph, sampled_graph_list, batch_size=parser.batch_pairs)
    test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,
                                          num_bgcn_neigh=parser.gnew_neighs,
                                          num_neg=parser.num_neg)

    if parser.log_name:
        now = datetime.datetime.now().strftime("%b_%d_%H_%M_%S")
        name = "bgcf" + '-' + parser.log_name + '-' + parser.dataset
        log_save_path = './log-files/' + name + '/' + now
        log = BGCFLogger(logname=name, now=now, foldername='log-files', copy=False)
        log.open(log_save_path + '/log.train.txt', mode='a')
        for arg in vars(parser):
            log.write(arg + '=' + str(getattr(parser, arg)) + '\n')
    else:
        for arg in vars(parser):
            print(arg + '=' + str(getattr(parser, arg)))

    train_and_eval()
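
For intuition, a small standalone sketch of the scoring step that `eval_with_rep` applies to the representations returned by `TestBGCF`. The shapes and values are toy numbers, nothing is loaded from a checkpoint:

```python
import numpy as np

user_rep = np.random.rand(4, 64)    # 4 users, 64-dim learned representations
item_rep = np.random.rand(6, 64)    # 6 items

scores = user_rep @ item_rep.transpose()    # (4, 6): predicted preference per user/item pair
top3 = np.argsort(-scores, axis=1)[:, :3]   # each user's three highest-scoring items
print(top3)
```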
@ -0,0 +1,365 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Preprocess data.
"""
import time
import random
import argparse
import pickle as pkl
from copy import deepcopy
from functools import partial
from multiprocessing import Pool

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.neighbors import kneighbors_graph
from sklearn.model_selection import train_test_split


def load_pickle(path, name):
    """Load pickle"""
    with open(path + name, 'rb') as f:
        return pkl.load(f, encoding='latin1')


def generate_inverse_mapping(data_list):
    """Generate inverse id map"""
    ds_matrix_mapping = dict()
    for inner_id, true_id in enumerate(data_list):
        ds_matrix_mapping[true_id] = inner_id
    return ds_matrix_mapping


def convert_to_inner_index(user_records, user_mapping, item_mapping):
    """Convert real id to inner id"""
    inner_user_records = []
    user_inverse_mapping = generate_inverse_mapping(user_mapping)
    item_inverse_mapping = generate_inverse_mapping(item_mapping)
    for user_id, _ in enumerate(user_mapping):
        real_user_id = user_mapping[user_id]
        item_list = list(user_records[real_user_id])
        for index, real_item_id in enumerate(item_list):
            item_list[index] = item_inverse_mapping[real_item_id]
        inner_user_records.append(item_list)
    return inner_user_records, user_inverse_mapping, item_inverse_mapping


def split_data_randomly(user_records, test_ratio, seed=0):
    """Split data"""
    print('seed %d ' % seed)
    train_set = []
    test_set = []
    for _, item_list in enumerate(user_records):
        tmp_train_sample, tmp_test_sample = train_test_split(
            item_list, test_size=test_ratio, random_state=seed)

        train_sample = []
        for place in item_list:
            if place not in tmp_test_sample:
                train_sample.append(place)

        test_sample = []
        for place in tmp_test_sample:
            if place not in tmp_train_sample:
                test_sample.append(place)

        train_set.append(train_sample)
        test_set.append(test_sample)

    return train_set, test_set


def create_adj_matrix(train_matrix):
    """Create adj matrix"""
    user2item, item2user = {}, {}
    user_item_ratings = train_matrix.toarray()
    for i, _ in enumerate(user_item_ratings):
        neigh_items = np.where(user_item_ratings[i] != 0)[0].tolist()
        user2item[i] = set(neigh_items)
    item_user_ratings = user_item_ratings.transpose()
    for j, _ in enumerate(item_user_ratings):
        neigh_users = np.where(item_user_ratings[j] != 0)[0].tolist()
        item2user[j] = set(neigh_users)
    return user2item, item2user


def generate_rating_matrix(train_set, num_users, num_items, user_shift=0, item_shift=0):
    """Generate rating matrix"""
    row = []
    col = []
    data = []
    for user_id, article_list in enumerate(train_set):
        for article in article_list:
            row.append(user_id + user_shift)
            col.append(article + item_shift)
            data.append(1)
    row = np.array(row)
    col = np.array(col)
    data = np.array(data)
    rating_matrix = csr_matrix(
        (data, (row, col)), shape=(num_users, num_items))
    return rating_matrix


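# Worked example (illustration only, not called anywhere): with two users whose inner
# item ids are [0, 2] and [1], generate_rating_matrix([[0, 2], [1]], 2, 3) is the
# 2 x 3 sparse interaction matrix
#     [[1, 0, 1],
#      [0, 1, 0]]

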
def flatten(distance, adj, thre=10):
    """Flatten the distance matrix for the smoother sampling"""
    print('start flattening the dataset with threshold = {}'.format(thre))
    top_ids = np.argsort(distance, 1)[:, -thre:]

    flat_distance = np.zeros_like(distance)
    values = 1 / thre
    for i, _ in enumerate(flat_distance):
        adj_len = len(adj[i])
        if adj_len == 0 or adj_len > thre:
            flat_distance[i][top_ids[i]] = values
        else:
            flat_distance[i][top_ids[i][thre - adj_len:]] = 1 / adj_len
    return flat_distance


def sample_graph_copying(node_neighbors_dict, distances, epsilon=0.01, seed=0, set_seed=False):
    """node copying node by node"""
    if set_seed:
        np.random.seed(seed)
        random.seed(seed)

    N = len(distances)

    sampled_graph = dict()
    nodes = np.arange(0, N).astype(np.int)

    for i in range(N):
        if random.uniform(0, 1) < 1 - epsilon:
            sampled_node = np.random.choice(nodes, 1, p=distances[i])
        else:
            sampled_node = [i]

        sampled_graph[i] = node_neighbors_dict[sampled_node[0]]

    return sampled_graph


def remove_infrequent_users(data, min_counts=10):
    """Remove infrequent users"""
    df = deepcopy(data)
    counts = df['user_id'].value_counts()
    df = df[df["user_id"].isin(counts[counts >= min_counts].index)]

    print("users with < {} interactions are removed".format(min_counts))
    return df


def remove_infrequent_items(data, min_counts=5):
    """Remove infrequent items"""
    df = deepcopy(data)
    counts = df['item_id'].value_counts()
    df = df[df["item_id"].isin(counts[counts >= min_counts].index)]

    print("items with < {} interactions are removed".format(min_counts))
    return df


def save_obj(obj, data_path, name):
    """Save object"""
    with open(data_path + "/" + name + '.pkl', 'wb') as f:
        pkl.dump(obj, f)


def preprocess_data(data_path, data_name):
    """Preprocess data"""
    rating_file = 'ratings_{}.csv'.format(data_name)
    col_names = ['user_id', 'item_id', 'rating', 'timestamp']
    data_records = pd.read_csv(data_path + "/" + rating_file, sep=',', names=col_names, engine='python')

    data_records.loc[data_records.rating != 0, 'rating'] = 1
    data_records = data_records[data_records.rating > 0]
    filtered_data = remove_infrequent_users(data_records, 10)
    filtered_data = remove_infrequent_items(filtered_data, 10)

    data = filtered_data.groupby('user_id')['item_id'].apply(list)
    unique_data = filtered_data.groupby('user_id')['item_id'].nunique()
    data = data[unique_data[unique_data >= 5].index]

    user_item_dict = data.to_dict()
    user_mapping = []
    item_set = set()
    for user_id, item_list in data.iteritems():
        user_mapping.append(user_id)
        for item_id in item_list:
            item_set.add(item_id)
    item_mapping = list(item_set)
    return user_item_dict, user_mapping, item_mapping


def iou_set(set1, set2):
    """Calculate iou_set"""
    union = set1.union(set2)
    return len(set1.intersection(set2)) / len(union) if union else 0


def build_func(train_set, data):
    """Build function"""
    res = []
    res.append([iou_set(set(train_set), x) for x in data.values()])
    return res


def build_distance_mp_map(train_set, u_adj_list, v_adj_list, num_workers=5, tag='user', norm=True):
    """Build distance matrix"""
    start = time.time()
    pool = Pool(processes=num_workers)

    if tag == 'user':
        results = pool.map_async(partial(build_func, data=u_adj_list), train_set)

    if tag == 'item':
        results = pool.map_async(partial(build_func, data=v_adj_list), train_set)

    results.wait()

    pool.close()
    pool.join()

    distances = np.array(results.get()).squeeze(1)
    np.fill_diagonal(distances, 0)
    print('=== info: elapsed time with mp for building ' + tag + ' distance matrix: ', time.time() - start)

    for i, _ in enumerate(distances):
        if sum(distances[i]) == 0:
            distances[i] = 1.

    if norm:
        distances = distances / np.sum(distances, axis=1).reshape(-1, 1)
    distances.astype(np.float16)
    return distances


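# Worked example (illustration only): iou_set({1, 2, 3}, {2, 3, 4}) == 2 / 4 == 0.5,
# the Jaccard similarity that build_distance_mp_map uses as user/user (or item/item)
# closeness before row-normalisation.

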
def trans(src_path, data_name, out_path):
    """Convert into MindSpore data"""
    print('=== loading datasets')
    user_records, user_mapping, item_mapping = preprocess_data(src_path, data_name)
    inner_data_records, user_inverse_mapping, \
        item_inverse_mapping = convert_to_inner_index(
            user_records, user_mapping, item_mapping)

    test_ratio = 0.2
    train_set, test_set = split_data_randomly(
        inner_data_records, test_ratio=test_ratio, seed=0)
    train_matrix = generate_rating_matrix(
        train_set, len(user_mapping), len(item_mapping))
    u_adj_list, v_adj_list = create_adj_matrix(train_matrix)
    num_user, num_item = train_matrix.shape

    print('=== building user-user graph and item-item graph')
    num_self_neigh = 10
    user_user_graph = kneighbors_graph(train_matrix, num_self_neigh,
                                       mode='connectivity', include_self=False)
    user_self_neighs = user_user_graph.tocoo().col
    user_self_neighs = np.array(np.array_split(user_self_neighs, num_user)).tolist()

    item_item_graph = kneighbors_graph(train_matrix.transpose(), num_self_neigh,
                                       mode='connectivity', include_self=False)
    item_self_neighs = item_item_graph.tocoo().col
    item_self_neighs = np.array(np.array_split(item_self_neighs, num_item)).tolist()

    assert len(train_set) == len(user_self_neighs)

    user_distances = build_distance_mp_map(train_set, u_adj_list, v_adj_list, num_workers=10, tag='user', norm=True)
    user_distances = flatten(user_distances, u_adj_list,
                             thre=10)

    item_start_id = num_user
    user_file = out_path + "/user.csv"
    item_file = out_path + "/item.csv"
    train_file = out_path + "/rating_train.csv"
    test_file = out_path + "/rating_test.csv"
    with open(user_file, 'a+') as user_f:
        for k in user_inverse_mapping:
            print(k + ',' + str(user_inverse_mapping[k]), file=user_f)
    with open(item_file, 'a+') as item_f:
        for k in item_inverse_mapping:
            print(k + ',' + str(item_inverse_mapping[k] + item_start_id), file=item_f)
    with open(train_file, 'a+') as train_f:
        print("src_id,dst_id,type", file=train_f)
        for user in u_adj_list:
            for item in sorted(list(u_adj_list[user])):
                print(str(user) + ',' + str(item + item_start_id) + ',0', file=train_f)
        for item in v_adj_list:
            for user in v_adj_list[item]:
                print(str(item + item_start_id) + ',' + str(user) + ',1', file=train_f)
        src_user = 0
        for users in user_self_neighs:
            for dst_user in users:
                print(str(src_user) + ',' + str(dst_user) + ',2', file=train_f)
            src_user += 1
        src_item = 0
        for items in item_self_neighs:
            for dst_item in items:
                print(str(src_item + item_start_id) + ',' + str(dst_item + item_start_id) + ',3', file=train_f)
            src_item += 1
    with open(test_file, 'a+') as test_f:
        print("src_id,dst_id,type", file=test_f)
        user = 0
        for items in test_set:
            for item in items:
                print(str(user) + ',' + str(item + item_start_id) + ',0', file=test_f)
            user += 1
        user = 0
        for items in test_set:
            for item in items:
                print(str(item + item_start_id) + ',' + str(user) + ',1', file=test_f)
            user += 1

    print('start generating sampled graphs...')
    num_graphs = 5
    for i in range(num_graphs):
        print('=== info: sampling graph {} / {}'.format(i + 1, num_graphs))
        sampled_user_graph = sample_graph_copying(node_neighbors_dict=u_adj_list,
                                                  distances=user_distances)

        print('avg. sampled user-item graph degree: ',
              np.mean([len(x) for x in [*sampled_user_graph.values()]]))

        sampled_item_graph = {x: set() for x in range(num_item)}

        for k, items in sampled_user_graph.items():
            for x in items:
                sampled_item_graph[x].add(k)

        print('avg. sampled item-user graph degree: ',
              np.mean([len(x) for x in [*sampled_item_graph.values()]]))

        sampled_file = out_path + "/rating_sampled" + str(i) + ".csv"
        with open(sampled_file, 'a+') as sampled_f:
            print("src_id,dst_id,type", file=sampled_f)
            for user in sampled_user_graph:
                for item in sampled_user_graph[user]:
                    print(str(user) + ',' + str(item + item_start_id) + ',0', file=sampled_f)
            for item in sampled_item_graph:
                for user in sampled_item_graph[item]:
                    print(str(item + item_start_id) + ',' + str(user) + ',1', file=sampled_f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Converting Data')
    parser.add_argument('--src_path', type=str, default="/tmp/",
                        help='source data directory')
    parser.add_argument('--out_path', type=str, default="/tmp/",
                        help='output directory')
    args = parser.parse_args()

    trans(args.src_path + "/", "Beauty", args.out_path + "/")
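
For intuition, a tiny self-contained walk-through of the node-copying idea behind `sample_graph_copying`: with probability 1 - epsilon a user's neighbour list is copied from a similar user drawn according to `distances[i]`, otherwise the user keeps its own list. All numbers below are made up:

```python
import random
import numpy as np

u_adj_list = {0: {10, 11}, 1: {11}, 2: {12, 13}}   # user -> item-neighbour sets
distances = np.array([[0.0, 0.8, 0.2],             # row i: row-normalised similarity of
                      [0.7, 0.0, 0.3],             # user i to every other user
                      [0.5, 0.5, 0.0]])
epsilon = 0.01
nodes = np.arange(len(distances))

sampled_graph = {}
for i in range(len(distances)):
    if random.uniform(0, 1) < 1 - epsilon:                   # usually: copy a similar user's list
        donor = np.random.choice(nodes, 1, p=distances[i])[0]
    else:                                                    # rarely: keep the original list
        donor = i
    sampled_graph[i] = u_adj_list[donor]

print(sampled_graph)   # e.g. {0: {11}, 1: {10, 11}, 2: {10, 11}}
```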
@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord GNN writer.
"""
import os
import csv

args = os.environ['graph_api_args'].split(':')
USER_FILE = args[0]
ITEM_FILE = args[1]
RATING_FILE = args[2]

node_profile = (0, [], [])
edge_profile = (0, [], [])


def yield_nodes(task_id=0):
    """
    Generate node data

    Yields:
        data (dict): data row which is dict.
    """
    print("Node task is {}".format(task_id))
    with open(USER_FILE) as user_file:
        user_reader = csv.reader(user_file, delimiter=',')
        line_count = 0
        for row in user_reader:
            node = {'id': int(row[1]), 'type': 0}
            yield node
            line_count += 1
        print('Processed {} lines for users.'.format(line_count))

    with open(ITEM_FILE) as item_file:
        item_reader = csv.reader(item_file, delimiter=',')
        line_count = 0
        for row in item_reader:
            node = {'id': int(row[1]), 'type': 1}
            yield node
            line_count += 1
        print('Processed {} lines for items.'.format(line_count))


def yield_edges(task_id=0):
    """
    Generate edge data

    Yields:
        data (dict): data row which is dict.
    """
    print("Edge task is {}".format(task_id))
    with open(RATING_FILE) as rating_file:
        rating_reader = csv.reader(rating_file, delimiter=',')
        line_count = 0
        for row in rating_reader:
            if line_count == 0:
                line_count += 1
                continue
            edge = {'id': line_count - 1, 'src_id': int(row[0]), 'dst_id': int(row[1]), 'type': int(row[2])}
            yield edge
            line_count += 1
        print('Processed {} lines for edges.'.format(line_count))
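
A quick sanity-check sketch for these writer hooks. The CSV paths and the module name are placeholders (the real file and paths come from the data-processing script), and `graph_api_args` has to be set before the module is imported because it is read at import time:

```python
import os

os.environ['graph_api_args'] = '/tmp/user.csv:/tmp/item.csv:/tmp/rating_train.csv'

import bgcf_writer as writer_api   # hypothetical module name for the file above

for edge in writer_api.yield_edges():
    print(edge)   # e.g. {'id': 0, 'src_id': 0, 'dst_id': 12, 'type': 0}
    break
```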