add tbnet for open source in model_zoo/official/recommend

This commit is contained in:
lihaoyang 2021-07-28 11:39:42 +08:00
parent 7971b57dcf
commit 9b9dc5ee94
18 changed files with 1826 additions and 0 deletions

@@ -0,0 +1,250 @@
# Contents
- [Contents](#contents)
- [TBNet Description](#tbnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance)
    - [Inference and Explanation Performance](#inference-and-explanation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [TBNet Description](#contents)
TB-Net is a knowledge-graph-based explainable recommender system.
Paper: Shendi Wang, Haoyang Li, Xiao-Hui Li, Caleb Chen Cao, Lei Chen. Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems
# [Model Architecture](#contents)
TB-Net constructs subgraphs in a knowledge graph based on the interactions between users and items as well as the features of items, then computes paths in the subgraphs using a bidirectional conduction algorithm, and finally obtains explainable recommendation results.
# [Dataset](#contents)
The [interaction data between users and games](https://www.kaggle.com/tamber/steam-video-games) and the [games' feature data](https://www.kaggle.com/nikdavis/steam-store-games?select=steam.csv) from the game platform Steam are publicly available on Kaggle.
Dataset directory: `./data/{DATASET}/`, e.g. `./data/steam/`.
- training: train.csv, evaluation: test.csv
Each line represents a \<user\>, an \<item\>, the user-item \<rating\> (1 or 0), and PER_ITEM_NUM_PATHS paths between the item and the user's \<hist_item\> (\<hist_item\> is an item whose user-item \<rating\> is 1 in the historical data).
```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
- inference and explanation: infer.csv
Each line represents the \<user\> and the \<item\> to be inferred, a \<rating\>, and PER_ITEM_NUM_PATHS paths between the item and the user's \<hist_item\> (\<hist_item\> is an item whose user-item \<rating\> is 1 in the historical data).
Note that \<item\> needs to traverse all candidate items in the dataset (all items by default). \<rating\> can be assigned arbitrarily (all 0 by default) and is not used in the inference and explanation phases. A minimal parsing sketch of this row layout follows the format line below.
```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
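The sketch below is not part of the released scripts; it only illustrates, with made-up IDs and `PER_ITEM_NUM_PATHS=2`, how one CSV row in the format above decomposes into the user, item, rating and the repeated path modules:
```python
# Hypothetical row, for illustration only; real rows come from train.csv / test.csv / infer.csv.
PER_ITEM_NUM_PATHS = 2  # the real value is 'per_item_num_paths' in data/{DATASET}/config.json

row = "7,42,1,3,425,0,15,1,426,1,20"  # user,item,rating + 2 x [relation1,entity,relation2,hist_item]
ids = list(map(int, row.split(',')))

user, item, rating = ids[0], ids[1], ids[2]
paths = [tuple(ids[3 + 4 * p: 3 + 4 * (p + 1)]) for p in range(PER_ITEM_NUM_PATHS)]
print(user, item, rating, paths)  # 7 42 1 [(3, 425, 0, 15), (1, 426, 1, 20)]
```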
# [Environment Requirements](#contents)
- Hardware (GPU)
    - Prepare the hardware environment with a GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- Data preprocessing
Process the data into the format described in the [Dataset](#dataset) section (e.g. the 'steam' dataset), and then run the commands below.
- Training
```bash
python train.py \
--dataset [DATASET] \
--epochs [EPOCHS]
```
Example:
```bash
python train.py \
--dataset steam \
--epochs 20
```
- Evaluation
```bash
python eval.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID]
```
Argument `--checkpoint_id` is required.
Example:
```bash
python eval.py \
--dataset steam \
--checkpoint_id 8
```
- Inference and Explanation
```bash
python infer.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID] \
--user [USER] \
--items [ITEMS] \
--explanations [EXPLANATIONS]
```
Arguments `--checkpoint_id` and `--user` are required.
Example:
```bash
python infer.py \
--dataset steam \
--checkpoint_id 8 \
--user 1 \
--items 1 \
--explanations 3
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
.
└─tbnet
├─README.md
├─data
├─steam
├─config.json # data and training parameter configuration
├─infer.csv # inference and explanation dataset
├─test.csv # evaluation dataset
├─train.csv # training dataset
└─translate.json # explanation configuration
├─src
├─aggregator.py # inference result aggregation
├─config.py # parsing parameter configuration
├─dataset.py # generate dataset
├─embedding.py # 3-dim embedding matrix initialization
├─metrics.py # model metrics
├─steam.py # 'steam' dataset text explainer
└─tbnet.py # TB-Net model
├─eval.py # evaluation
├─infer.py # inference and explanation
└─train.py # training
```
## [Script Parameters](#contents)
- train.py parameters
```text
--dataset 'steam' dataset is supported currently
--train_csv the train csv datafile inside the dataset folder
--test_csv the test csv datafile inside the dataset folder
--device_id device id
--epochs number of training epochs
--device_target run code on GPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
- eval.py parameters
```text
--dataset 'steam' dataset is supported currently
--csv the csv datafile inside the dataset folder (e.g. test.csv)
--checkpoint_id use which checkpoint(.ckpt) file to eval
--device_id device id
--device_target run code on GPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
- infer.py parameters
```text
--dataset 'steam' dataset is supported currently
--csv the csv datafile inside the dataset folder (e.g. infer.csv)
--checkpoint_id use which checkpoint(.ckpt) file to infer
--user id of the user to be recommended to
--items no. of items to be recommended
--explanations no. of recommendation explanations to be shown
--device_id device id
--device_target run code on GPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | GPU |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | TB-Net |
| Resource |Tesla V100-SXM2-32GB |
| Uploaded Date | 2021-08-01 |
| MindSpore Version | 1.3.0 |
| Dataset | steam |
| Training Parameter | epoch=20, batch_size=1024, lr=0.001 |
| Optimizer | Adam |
| Loss Function | Sigmoid Cross Entropy |
| Outputs | AUC=0.8596, Accuracy=0.7761 |
| Loss | 0.57 |
| Speed | 1pc: 90ms/step |
| Total Time | 1pc: 297s |
| Checkpoint for Fine Tuning | 104.66M (.ckpt file) |
| Scripts | [TB-Net scripts](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/tbnet) |
### Evaluation Performance
| Parameters | GPU |
| ------------------------- | ----------------------------- |
| Model Version | TB-Net |
| Resource | Tesla V100-SXM2-32GB |
| Uploaded Date | 2021-08-01 |
| MindSpore Version | 1.3.0 |
| Dataset | steam |
| Batch Size | 1024 |
| Outputs | AUC=0.8252, Accuracy=0.7503 |
| Total Time | 1pc: 5.7s |
### Inference and Explanation Performance
| Parameters | GPU |
| --------------------------| ------------------------------------- |
| Model Version | TB-Net |
| Resource | Tesla V100-SXM2-32GB |
| Uploaded Date | 2021-08-01 |
| MindSpore Version | 1.3.0 |
| Dataset | steam |
| Outputs | Recommendation Result and Explanation |
| Total Time | 1pc: 3.66s |
# [Description of Random Situation](#contents)
- Initialization of embedding matrix in `tbnet.py` and `embedding.py`.
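If reproducible results are needed, the global random seed can be fixed before the network is created; a minimal sketch (the seed value is arbitrary, and `mindspore.set_seed` is assumed to be available in the installed MindSpore version):
```python
import mindspore as ms

ms.set_seed(1234)  # fixes the seed used by the 'normal' initializer of the embedding matrices
```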
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,252 @@
# 目录
<!-- TOC -->
- [目录](#目录)
- [TBNet概述](#tbnet概述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [环境要求](#环境要求)
- [快速入门](#快速入门)
- [脚本说明](#脚本说明)
- [脚本和样例代码](#脚本和样例代码)
- [脚本参数](#脚本参数)
- [模型描述](#模型描述)
- [性能](#性能)
- [训练性能](#训练性能)
- [评估性能](#评估性能)
- [推理和解释性能](#推理和解释性能)
- [随机情况说明](#随机情况说明)
- [ModelZoo主页](#modelzoo主页)
# [TBNet概述](#目录)
TB-Net是一个基于知识图谱的可解释推荐系统。
论文：Shendi Wang, Haoyang Li, Xiao-Hui Li, Caleb Chen Cao, Lei Chen. Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems
# [模型架构](#目录)
TB-Net将用户和物品的交互信息以及物品的属性信息在知识图谱中构建子图，并利用双向传导的计算方法对图谱中的路径进行计算，最后得到可解释的推荐结果。
# [数据集](#目录)
本示例提供Kaggle上的Steam游戏平台公开数据集，包含[用户与游戏的交互记录](https://www.kaggle.com/tamber/steam-video-games)和[游戏的属性信息](https://www.kaggle.com/nikdavis/steam-store-games?select=steam.csv)。
数据集路径:`./data/{DATASET}/`,如:`./data/steam/`。
- 训练：train.csv，评估：test.csv
每一行记录代表某\<user\>对某\<item\>的\<rating\>(1或0),以及该\<item\>与\<hist_item\>(即该\<user\>历史\<rating\>为1的\<item\>)的PER_ITEM_NUM_PATHS条路径。
```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
- 推理和解释：infer.csv
每一行记录代表**待推理**的\<user\>和\<item\>、\<rating\>，以及该\<item\>与\<hist_item\>(即该\<user\>历史\<rating\>为1的\<item\>)的PER_ITEM_NUM_PATHS条路径。
其中\<item\>需要遍历数据集中**所有**待推荐物品(默认所有物品)；\<rating\>可随机赋值(默认全部赋值为0)，在推理和解释阶段不会使用。
```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
# [环境要求](#目录)
- 硬件（GPU）
- 使用GPU处理器准备硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# [快速入门](#目录)
通过官方网站安装MindSpore后，您可以按照如下步骤进行训练、评估、推理和解释：
- 数据准备
将数据处理成上一节[数据集](#数据集)中的格式(以'steam'数据集为例),然后按照以下步骤运行代码。
- 训练
```bash
python train.py \
--dataset [DATASET] \
--epochs [EPOCHS]
```
示例:
```bash
python train.py \
--dataset steam \
--epochs 20
```
- 评估
```bash
python eval.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID]
```
参数`--checkpoint_id`是必填项。
示例:
```bash
python eval.py \
--dataset steam \
--checkpoint_id 8
```
- 推理和解释
```bash
python infer.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID] \
--user [USER] \
--items [ITEMS] \
--explanations [EXPLANATIONS]
```
参数`--checkpoint_id`和`--user`是必填项。
示例:
```bash
python infer.py \
--dataset steam \
--checkpoint_id 8 \
--user 1 \
--items 1 \
--explanations 3
```
# [脚本说明](#目录)
## [脚本和样例代码](#目录)
```text
.
└─tbnet
├─README.md
├─data
├─steam
├─config.json # 数据和训练参数配置
├─infer.csv # 推理和解释数据集
├─test.csv # 测试数据集
├─train.csv # 训练数据集
└─translate.json # 输出解释相关配置
├─src
├─aggregator.py # 推理结果聚合
├─config.py # 参数配置解析
├─dataset.py # 创建数据集
├─embedding.py # 三维embedding矩阵初始化
├─metrics.py # 模型度量
├─steam.py # 'steam'数据集文本解释
└─tbnet.py # TB-Net网络
├─eval.py # 评估网络
├─infer.py # 推理和解释
└─train.py # 训练网络
```
## [脚本参数](#目录)
- train.py参数
```text
--dataset 'steam' dataset is supported currently
--train_csv the train csv datafile inside the dataset folder
--test_csv the test csv datafile inside the dataset folder
--device_id device id
--epochs number of training epochs
--device_target run code on GPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
- eval.py参数
```text
--dataset 'steam' dataset is supported currently
--csv the csv datafile inside the dataset folder (e.g. test.csv)
--checkpoint_id use which checkpoint(.ckpt) file to eval
--device_id device id
--device_target run code on GPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
- infer.py参数
```text
--dataset 'steam' dataset is supported currently
--csv the csv datafile inside the dataset folder (e.g. infer.csv)
--checkpoint_id use which checkpoint(.ckpt) file to infer
--user id of the user to be recommended to
--items no. of items to be recommended
--explanations no. of recommendation explanations to be shown
--device_id device id
--device_target run code on GPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
# [模型描述](#目录)
## [性能](#目录)
### [训练性能](#目录)
| 参数 | GPU |
| ------------------- | --------------------------------------------------- |
| 模型版本 | TB-Net |
| 资源 |Tesla V100-SXM2-32GB |
| 上传日期 | 2021-08-01 |
| MindSpore版本 | 1.3.0 |
| 数据集 | steam |
| 训练参数 | epoch=20, batch_size=1024, lr=0.001 |
| 优化器 | Adam |
| 损失函数 | Sigmoid交叉熵 |
| 输出 | AUC=0.8596,准确率=0.7761 |
| 损失 | 0.57 |
| 速度 | 单卡90毫秒/步 |
| 总时长 | 单卡297秒 |
| 微调检查点 | 104.66M (.ckpt 文件) |
| 脚本 | [TB-Net脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/tbnet) |
### [评估性能](#目录)
| 参数 | GPU |
| -------------------------- | ----------------------------- |
| 模型版本 | TB-Net |
| 资源 | Tesla V100-SXM2-32GB |
| 上传日期 | 2021-08-01 |
| MindSpore版本 | 1.3.0 |
| 数据集 | steam |
| 批次大小 | 1024 |
| 输出 | AUC=0.8252,准确率=0.7503 |
| 总时长 | 单卡5.7秒 |
### [推理和解释性能](#目录)
| 参数 | GPU |
| -------------------------- | ----------------------------- |
| 模型版本 | TB-Net |
| 资源 | Tesla V100-SXM2-32GB |
| 上传日期 | 2021-08-01 |
| MindSpore版本 | 1.3.0 |
| 数据集 | steam |
| 输出 | 推荐结果和解释结果 |
| 总时长 | 单卡3.66秒 |
# [随机情况说明](#目录)
- `tbnet.py`和`embedding.py`中Embedding矩阵的随机初始化。
# [ModelZoo主页](#目录)
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。

@@ -0,0 +1,12 @@
{
"num_item": 3005,
"num_relation": 5,
"num_entity": 5138,
"per_item_num_paths": 39,
"embedding_dim": 26,
"batch_size": 1024,
"lr": 0.001,
"kge_weight": 0.05,
"node_weight": 0.002,
"l2_weight": 1e-6
}

@@ -0,0 +1 @@
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times

@@ -0,0 +1 @@
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times

@@ -0,0 +1 @@
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times

@@ -0,0 +1,14 @@
{
"item": {
"0": "Star Wars",
"1": "Battlefield 1"
},
"relation": {
"0": "Developer",
"1": "Genre"
},
"entity": {
"425": "EA Games",
"426": "Shooting"
}
}

@@ -0,0 +1,117 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TB-Net evaluation."""
import os
import argparse
from mindspore import context, Model, load_checkpoint, load_param_into_net
from src import tbnet, config, metrics, dataset
def get_args():
"""Parse commandline arguments."""
parser = argparse.ArgumentParser(description='Evaluate TBNet.')
parser.add_argument(
'--dataset',
type=str,
required=False,
default='steam',
help="'steam' dataset is supported currently"
)
parser.add_argument(
'--csv',
type=str,
required=False,
default='test.csv',
help="the csv datafile inside the dataset folder (e.g. test.csv)"
)
parser.add_argument(
'--checkpoint_id',
type=int,
required=True,
help="use which checkpoint(.ckpt) file to eval"
)
parser.add_argument(
'--device_id',
type=int,
required=False,
default=0,
help="device id"
)
parser.add_argument(
'--device_target',
type=str,
required=False,
default='GPU',
choices=['GPU'],
help="run code on GPU"
)
parser.add_argument(
'--run_mode',
type=str,
required=False,
default='graph',
choices=['graph', 'pynative'],
help="run code by GRAPH mode or PYNATIVE mode"
)
return parser.parse_args()
def eval_tbnet():
"""Evaluation process."""
args = get_args()
home = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(home, 'data', args.dataset, 'config.json')
test_csv_path = os.path.join(home, 'data', args.dataset, args.csv)
ckpt_path = os.path.join(home, 'checkpoints')
context.set_context(device_id=args.device_id)
if args.run_mode == 'graph':
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
print(f"creating dataset from {test_csv_path}...")
net_config = config.TBNetConfig(config_path)
eval_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
print(f"creating TBNet from checkpoint {args.checkpoint_id} for evaluation...")
network = tbnet.TBNet(net_config)
param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
load_param_into_net(network, param_dict)
loss_net = tbnet.NetWithLossClass(network, net_config)
train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
train_net.set_train()
eval_net = tbnet.PredictWithSigmoid(network)
model = Model(network=train_net, eval_network=eval_net, metrics={'auc': metrics.AUC(), 'acc': metrics.ACC()})
print("evaluating...")
e_out = model.eval(eval_ds, dataset_sink_mode=False)
print(f'Test AUC:{e_out["auc"]} ACC:{e_out["acc"]}')
if __name__ == '__main__':
eval_tbnet()

@@ -0,0 +1,162 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TBNet inference."""
import os
import argparse
from mindspore import load_checkpoint, load_param_into_net, context
from src.config import TBNetConfig
from src.tbnet import TBNet
from src.aggregator import InferenceAggregator
from src import dataset
from src import steam
def get_args():
"""Parse commandline arguments."""
parser = argparse.ArgumentParser(description='Infer TBNet.')
parser.add_argument(
'--dataset',
type=str,
required=False,
default='steam',
help="'steam' dataset is supported currently"
)
parser.add_argument(
'--csv',
type=str,
required=False,
default='infer.csv',
help="the csv datafile inside the dataset folder (e.g. infer.csv)"
)
parser.add_argument(
'--checkpoint_id',
type=int,
required=True,
help="use which checkpoint(.ckpt) file to infer"
)
parser.add_argument(
'--user',
type=int,
required=True,
help="id of the user to be recommended to"
)
parser.add_argument(
'--items',
type=int,
required=False,
default=1,
help="no. of items to be recommended"
)
parser.add_argument(
'--explanations',
type=int,
required=False,
default=3,
help="no. of recommendation explanations to be shown"
)
parser.add_argument(
'--device_id',
type=int,
required=False,
default=0,
help="device id"
)
parser.add_argument(
'--device_target',
type=str,
required=False,
default='GPU',
choices=['GPU'],
help="run code on GPU"
)
parser.add_argument(
'--run_mode',
type=str,
required=False,
default='graph',
choices=['graph', 'pynative'],
help="run code by GRAPH mode or PYNATIVE mode"
)
return parser.parse_args()
def infer_tbnet():
"""Inference process."""
args = get_args()
context.set_context(device_id=args.device_id)
if args.run_mode == 'graph':
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
home = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(home, 'data', args.dataset, 'config.json')
translate_path = os.path.join(home, 'data', args.dataset, 'translate.json')
data_path = os.path.join(home, 'data', args.dataset, args.csv)
ckpt_path = os.path.join(home, 'checkpoints')
print(f"creating TBNet from checkpoint {args.checkpoint_id}...")
config = TBNetConfig(config_path)
network = TBNet(config)
param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
load_param_into_net(network, param_dict)
print(f"creating dataset from {data_path}...")
infer_ds = dataset.create(data_path, config.per_item_num_paths, train=False, users=args.user)
infer_ds = infer_ds.batch(config.batch_size)
print("inferring...")
# infer and aggregate results
aggregator = InferenceAggregator(top_k=args.items)
for user, item, relation1, entity, relation2, hist_item, rating in infer_ds:
del rating
result = network(item, relation1, entity, relation2, hist_item)
item_score = result[0]
path_importance = result[1]
aggregator.aggregate(user, item, relation1, entity, relation2, hist_item, item_score, path_importance)
# show recommendations with explanations
explainer = steam.TextExplainer(translate_path)
recomms = aggregator.recommend()
for user, recomm in recomms.items():
for item_rec in recomm.item_records:
item_name = explainer.translate_item(item_rec.item)
print(f"Recommend <{item_name}> to user:{user}, because:")
# show explanations
explanation = 0
for path in item_rec.paths:
print(" - " + explainer.explain(path))
explanation += 1
if explanation >= args.explanations:
break
print("")
if __name__ == '__main__':
infer_tbnet()

@@ -0,0 +1 @@
scikit-learn

@@ -0,0 +1,151 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inference result aggregator."""
import copy
class Recommendation:
"""Recommendation."""
class Path:
"""Item path."""
def __init__(self, relation1, entity, relation2, hist_item, importance):
self.relation1 = relation1
self.entity = entity
self.relation2 = relation2
self.hist_item = hist_item
self.importance = importance
class ItemRecord:
"""Recommended item info."""
def __init__(self, item, score):
self.item = item
self.score = score
# paths must be sorted with importance in descending order
self.paths = []
def __init__(self, user):
self.user = user
# item_records must be sorted with score in descending order
self.item_records = []
class InferenceAggregator:
"""
Inference result aggregator.
Args:
top_k (int): The number of items to be recommended for each distinct user.
"""
def __init__(self, top_k=1):
if top_k < 1:
raise ValueError('top_k is less than 1.')
self._top_k = top_k
self._user_recomms = dict()
self._paths_sorted = False
def aggregate(self, user, item, relation1, entity, relation2, hist_item, item_score, path_importance):
"""
Aggregate inference results.
Args:
user (Tensor): User IDs, int Tensor in shape of [N, ].
item (Tensor): Candidate item IDs, int Tensor in shape of [N, ].
relation1 (Tensor): IDs of item-entity relations, int Tensor in shape of [N, <no. of per-item path>].
entity (Tensor): Entity IDs, int Tensor in shape of [N, <no. of per-item path>].
relation2 (Tensor): IDs of entity-hist_item relations, int Tensor in shape of [N, <no. of per-item path>].
hist_item (Tensor): Historical item IDs, int Tensor in shape of [N, <no. of per-item path>].
item_score (Tensor): TBNet output, recommendation scores of candidate items, float Tensor in shape of [N, ].
path_importance (Tensor): TBNet output, the importance of each item to hist_item path for the
recommendations, float Tensor in shape of [N, <no. of per-item path>].
"""
user = user.asnumpy()
item = item.asnumpy()
relation1 = relation1.asnumpy()
entity = entity.asnumpy()
relation2 = relation2.asnumpy()
hist_item = hist_item.asnumpy()
item_score = item_score.asnumpy()
path_importance = path_importance.asnumpy()
batch_size = user.shape[0]
added_users = set()
for i in range(batch_size):
if self._add(user[i], item[i], relation1[i], entity[i], relation2[i],
hist_item[i], item_score[i], path_importance[i]):
added_users.add(user[i])
self._paths_sorted = False
for added_user in added_users:
recomm = self._user_recomms[added_user]
if len(recomm.item_records) > self._top_k:
recomm.item_records = recomm.item_records[0:self._top_k]
def recommend(self):
"""
Generate recommendations for all distinct users.
Returns:
dict[int, Recommendation], a dictionary with user id as keys and Recommendation objects as values.
"""
if not self._paths_sorted:
self._sort_paths()
return copy.deepcopy(self._user_recomms)
def _add(self, user, item, relation1, entity, relation2, hist_item, item_score, path_importance):
"""Add a single infer record."""
recomm = self._user_recomms.get(user, None)
if recomm is None:
recomm = Recommendation(user)
self._user_recomms[user] = recomm
# insert at the appropriate position
for i, old_item_rec in enumerate(recomm.item_records):
if i >= self._top_k:
return False
if item_score > old_item_rec.score:
rec = self._infer_2_item_rec(item, relation1, entity, relation2,
hist_item, item_score, path_importance)
recomm.item_records.insert(i, rec)
return True
# append if there is room
if len(recomm.item_records) < self._top_k:
rec = self._infer_2_item_rec(item, relation1, entity, relation2,
hist_item, item_score, path_importance)
recomm.item_records.append(rec)
return True
return False
@staticmethod
def _infer_2_item_rec(item, relation1, entity, relation2, hist_item, item_score, path_importance):
"""Converts a single infer result to a item record."""
item_rec = Recommendation.ItemRecord(item, item_score)
num_paths = path_importance.shape[0]
for i in range(num_paths):
path = Recommendation.Path(relation1[i], entity[i], relation2[i], hist_item[i], path_importance[i])
item_rec.paths.append(path)
return item_rec
def _sort_paths(self):
"""Sort all item paths."""
for recomm in self._user_recomms.values():
for item_rec in recomm.item_records:
item_rec.paths.sort(key=lambda x: x.importance, reverse=True)
self._paths_sorted = True
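A minimal usage sketch of `InferenceAggregator` with dummy tensors (all shapes and values below are made up; `infer.py` feeds it real TB-Net outputs in exactly the same way):
```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

from src.aggregator import InferenceAggregator

N, P = 2, 2  # batch size and per-item path count, hypothetical values
agg = InferenceAggregator(top_k=1)
agg.aggregate(
    Tensor(np.array([5, 5]), mstype.int32),                      # user
    Tensor(np.array([10, 11]), mstype.int32),                    # candidate items
    Tensor(np.zeros((N, P)), mstype.int32),                      # relation1
    Tensor(np.zeros((N, P)), mstype.int32),                      # entity
    Tensor(np.zeros((N, P)), mstype.int32),                      # relation2
    Tensor(np.zeros((N, P)), mstype.int32),                      # hist_item
    Tensor(np.array([0.9, 0.4]), mstype.float32),                # item_score
    Tensor(np.array([[0.7, 0.3], [0.5, 0.5]]), mstype.float32))  # path_importance

for user, recomm in agg.recommend().items():
    print(user, [(rec.item, rec.score) for rec in recomm.item_records])
```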

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TBNet configurations."""
import json
class TBNetConfig:
"""
TBNet config file parser and holder.
Args:
config_path (str): json config file path.
"""
def __init__(self, config_path):
with open(config_path) as f:
json_dict = json.load(f)
self.num_item = int(json_dict['num_item'])
self.num_relation = int(json_dict['num_relation'])
self.num_entity = int(json_dict['num_entity'])
self.per_item_num_paths = int(json_dict['per_item_num_paths'])
self.embedding_dim = int(json_dict['embedding_dim'])
self.batch_size = int(json_dict['batch_size'])
self.lr = float(json_dict['lr'])
self.kge_weight = float(json_dict['kge_weight'])
self.node_weight = float(json_dict['node_weight'])
self.l2_weight = float(json_dict['l2_weight'])
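A quick usage sketch (the path follows the `./data/steam/` layout described in the README):
```python
from src.config import TBNetConfig

cfg = TBNetConfig('./data/steam/config.json')
# every key in config.json is exposed as a typed attribute
print(cfg.num_entity, cfg.embedding_dim, cfg.per_item_num_paths, cfg.batch_size, cfg.lr)
```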

@@ -0,0 +1,88 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset loader."""
from functools import partial
import numpy as np
from mindspore.dataset import GeneratorDataset
def create(data_path, per_item_num_paths, train, users=None, **kwargs):
"""
Create a dataset for TBNet.
Args:
data_path (str): The csv datafile path.
per_item_num_paths (int): The number of paths per item.
train (bool): True to create for training with columns:
'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating'
otherwise:
'user', 'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating'
users (Union[list[int], int], optional): Users data to be loaded, if None is provided, all data will be loaded.
**kwargs (any): Other arguments for GeneratorDataset(), except 'source' and 'column_names'.
Returns:
GeneratorDataset, the generator dataset that reads from the csv datafile.
"""
if isinstance(users, int):
users = (users,)
kwargs['source'] = partial(csv_generator, data_path, per_item_num_paths, users, train)
if train:
kwargs['column_names'] = ['item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
else:
kwargs['column_names'] = ['user', 'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
return GeneratorDataset(**kwargs)
def csv_generator(csv_path, per_item_num_paths, users, train):
"""Generator for csv datafile."""
expected_columns = 3 + (per_item_num_paths * 4)
file = open(csv_path)
for line in file:
line = line.strip()
if not line or line[0] == '#':
continue
id_list = line.split(',')
if len(id_list) < expected_columns:
raise ValueError(f'Expecting {expected_columns} values but got {len(id_list)} only!')
id_list = list(map(int, id_list))
user = id_list[0]
if users and user not in users:
continue
item = id_list[1]
rating = id_list[2]
relation1 = np.empty(shape=(per_item_num_paths,), dtype=np.int)
entity = np.empty_like(relation1)
relation2 = np.empty_like(relation1)
hist_item = np.empty_like(relation1)
for p in range(per_item_num_paths):
offset = 3 + (p * 4)
relation1[p] = id_list[offset]
entity[p] = id_list[offset + 1]
relation2[p] = id_list[offset + 2]
hist_item[p] = id_list[offset + 3]
if train:
# item, relation1, entity, relation2, hist_item, rating
yield np.array(item, dtype=np.int), relation1, entity, relation2, hist_item, \
np.array(rating, dtype=np.float32)
else:
# user, item, relation1, entity, relation2, hist_item, rating
yield np.array(user, dtype=np.int), np.array(item, dtype=np.int),\
relation1, entity, relation2, hist_item, np.array(rating, dtype=np.float32)
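For reference, a minimal sketch of building the training pipeline with this loader (mirroring what `train.py` does; the file paths and the printed shapes assume the steam layout and config from this commit):
```python
from src import dataset
from src.config import TBNetConfig

cfg = TBNetConfig('./data/steam/config.json')
train_ds = dataset.create('./data/steam/train.csv', cfg.per_item_num_paths, train=True)
train_ds = train_ds.batch(cfg.batch_size)

for item, relation1, entity, relation2, hist_item, rating in train_ds.create_tuple_iterator():
    # expected shapes, e.g.: item (1024,), relation1 (1024, 39), rating (1024,)
    print(item.shape, relation1.shape, rating.shape)
    break
```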

@@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Three-dimension embedding vector initialization."""
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.nn import Cell
class EmbeddingMatrix(Cell):
"""
Support three-dimension embedding vector initialization.
"""
def __init__(self, vocab_size, embedding_size, embedding_table='normal',
dtype=mstype.float32, padding_idx=None):
super(EmbeddingMatrix, self).__init__()
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = validator.check_value_type('embedding_size', embedding_size,
[int, tuple, list], self.cls_name)
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.dtype = dtype
if isinstance(self.embedding_size, int):
self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
self.embedding_out = (self.embedding_size,)
else:
if len(self.embedding_size) != 2:
raise ValueError("embedding_size should be an int or a tuple of two ints")
self.init_tensor = initializer(embedding_table, [vocab_size, self.embedding_size[0],
self.embedding_size[1]])
self.embedding_out = (self.embedding_size[0], self.embedding_size[1],)
self.padding_idx = padding_idx
if padding_idx is not None:
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
"padding_idx", self.cls_name)
if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None:
self.init_tensor = self.init_tensor.init_data()
self.init_tensor = self.init_tensor.asnumpy()
self.init_tensor[self.padding_idx] = 0
self.init_tensor = Tensor(self.init_tensor)
self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
self.expand = P.ExpandDims()
self.reshape_flat = P.Reshape()
self.shp_flat = (-1,)
self.gather = P.Gather()
self.reshape = P.Reshape()
self.get_shp = P.Shape()
def construct(self, ids):
"""
Look up and return the embedding vectors (two- or three-dimensional per id) for the given ids
"""
extended_ids = self.expand(ids, -1)
out_shape = self.get_shp(ids) + self.embedding_out
flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, out_shape)
return output
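A small shape-check sketch of the two initialization modes the class supports (vocabulary sizes and dimensions are made up):
```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

from src.embedding import EmbeddingMatrix

ids = Tensor(np.array([[1, 2, 3], [0, 4, 2]]), mstype.int32)         # shape [2, 3]

entity_emb = EmbeddingMatrix(vocab_size=10, embedding_size=8)        # 2-dim table, as used for entities
relation_emb = EmbeddingMatrix(vocab_size=5, embedding_size=(8, 8))  # 3-dim table, as used for relations
print(entity_emb(ids).shape)    # (2, 3, 8)
print(relation_emb(ids).shape)  # (2, 3, 8, 8)
```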

@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TB-Net metrics."""
import numpy as np
from sklearn.metrics import roc_auc_score
from mindspore.nn.metrics import Metric
class AUC(Metric):
"""TB-Net metrics method. Compute model metrics AUC."""
def __init__(self):
super(AUC, self).__init__()
self.clear()
def clear(self):
"""Clear the internal evaluation result."""
self.true_labels = []
self.pred_probs = []
def update(self, *inputs):
"""Update list of predictions and labels."""
all_predict = inputs[1].asnumpy().flatten().tolist()
all_label = inputs[2].asnumpy().flatten().tolist()
self.pred_probs.extend(all_predict)
self.true_labels.extend(all_label)
def eval(self):
"""Return AUC score"""
if len(self.true_labels) != len(self.pred_probs):
raise RuntimeError(
'true_labels and pred_probs do not have the same length')
auc = roc_auc_score(self.true_labels, self.pred_probs)
return auc
class ACC(Metric):
"""TB-Net metrics method. Compute model metrics ACC."""
def __init__(self):
super(ACC, self).__init__()
self.clear()
def clear(self):
"""Clear the internal evaluation result."""
self.true_labels = []
self.pred_probs = []
def update(self, *inputs):
"""Update list of predictions and labels."""
all_predict = inputs[1].asnumpy().flatten().tolist()
all_label = inputs[2].asnumpy().flatten().tolist()
self.pred_probs.extend(all_predict)
self.true_labels.extend(all_label)
def eval(self):
"""Return accuracy score"""
if len(self.true_labels) != len(self.pred_probs):
raise RuntimeError(
'true_labels and pred_probs do not have the same length')
predictions = [1 if i >= 0.5 else 0 for i in self.pred_probs]
acc = np.mean(np.equal(predictions, self.true_labels))
return acc
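A standalone sketch of the metric flow with dummy values: during evaluation, `Model` passes the outputs of `PredictWithSigmoid` to `update()`, which only reads the predicted probabilities (index 1) and the labels (index 2).
```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

from src import metrics

pred_probs = Tensor(np.array([0.9, 0.2, 0.7, 0.4]), mstype.float32)
labels = Tensor(np.array([1.0, 0.0, 1.0, 1.0]), mstype.float32)
placeholder = Tensor(np.zeros(4), mstype.float32)  # stands in for the unused scores / path importances

for metric in (metrics.AUC(), metrics.ACC()):
    metric.clear()
    metric.update(placeholder, pred_probs, labels, placeholder)
    print(type(metric).__name__, metric.eval())  # AUC 1.0, ACC 0.75 for these dummy values
```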

@@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""STEAM dataset explainer."""
import json
from src.aggregator import Recommendation
class TextExplainer:
"""Text explainer for STEAM game recommendations."""
SAME_RELATION_TPL = 'User played the game <%s> before, which has the same %s '\
'<%s> as the recommended game.'
DIFF_RELATION_TPL = 'User played the game <%s> before, which has the %s <%s> '\
'while <%s> is the %s of the recommended game.'
def __init__(self, translate_path: str):
"""Construct from the translate json file."""
with open(translate_path) as file:
self._translator = json.load(file)
def explain(self, path: Recommendation.Path) -> str:
"""Explain the path."""
rel1_str = self.translate_relation(path.relation1)
entity_str = self.translate_entity(path.entity)
hist_item_str = self.translate_item(path.hist_item)
if path.relation1 == path.relation2:
return self.SAME_RELATION_TPL % (hist_item_str, rel1_str, entity_str)
rel2_str = self.translate_relation(path.relation2)
return self.DIFF_RELATION_TPL % (hist_item_str, rel2_str, entity_str, entity_str, rel1_str)
def translate_item(self, item: int) -> str:
"""Translate an item."""
return self._translate('item', item)
def translate_entity(self, entity: int) -> str:
"""Translate an entity."""
return self._translate('entity', entity)
def translate_relation(self, relation: int) -> str:
"""Translate a relation."""
return self._translate('relation', relation)
def _translate(self, obj_type, obj_id):
"""Translate an object."""
try:
return self._translator[obj_type][str(obj_id)]
except KeyError:
return f'[{obj_type}:{obj_id}]'
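Using the sample `translate.json` shown earlier in this commit, the explainer turns IDs back into readable text; a quick sketch (the IDs come from that sample file, the file path is assumed):
```python
from src.aggregator import Recommendation
from src.steam import TextExplainer

explainer = TextExplainer('./data/steam/translate.json')
path = Recommendation.Path(relation1=0, entity=425, relation2=0, hist_item=1, importance=0.9)

print(explainer.translate_item(0))  # Star Wars
print(explainer.explain(path))      # same-relation template: '... same Developer <EA Games> ...'
```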

@@ -0,0 +1,366 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TB-Net Model."""
from mindspore import nn
from mindspore import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from src.embedding import EmbeddingMatrix
class TBNet(nn.Cell):
"""
TB-Net model architecture.
Args:
config (TBNetConfig): parsed configuration, providing the fields:
num_entity (int): number of entities, depends on dataset
num_relation (int): number of relations, depends on dataset
embedding_dim (int): dimension of entity and relation embedding vectors
kge_weight (float): weight of the KG embedding loss term
node_weight (float): weight of the node loss term
l2_weight (float): weight of the L2 regularization term
lr (float): learning rate used by the training wrapper
batch_size (int): batch size
"""
def __init__(self, config):
super(TBNet, self).__init__()
self._parse_config(config)
self.matmul = C.matmul
self.sigmoid = P.Sigmoid()
self.embedding_initializer = "normal"
self.entity_emb_matrix = EmbeddingMatrix(int(self.num_entity),
self.dim,
embedding_table=self.embedding_initializer)
self.relation_emb_matrix = EmbeddingMatrix(int(self.num_relation),
embedding_size=(self.dim, self.dim),
embedding_table=self.embedding_initializer)
self.expand_dims = P.ExpandDims()
self.squeeze = P.Squeeze(3)
self.abs = P.Abs()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.softmax = nn.Softmax()
def _parse_config(self, config):
"""Argument parsing."""
self.num_entity = config.num_entity
self.num_relation = config.num_relation
self.dim = config.embedding_dim
self.kge_weight = config.kge_weight
self.node_weight = config.node_weight
self.l2_weight = config.l2_weight
self.lr = config.lr
self.batch_size = config.batch_size
def construct(self, items, relation1, mid_entity, relation2, hist_item):
"""
TB-Net main computation process.
Args:
items (Tensor): rated item IDs, int Tensor in shape of [batch size, ].
relation1 (Tensor): relation1 IDs, int Tensor in shape of [batch size, per_item_num_paths]
mid_entity (Tensor): middle entity IDs, int Tensor in shape of [batch size, per_item_num_paths]
relation2 (Tensor): relation2 IDs, int Tensor in shape of [batch size, per_item_num_paths]
hist_item (Tensor): historical item IDs, int Tensor in shape of [batch size, per_item_num_paths]
Returns:
scores (Tensor): model prediction score, float Tensor in shape of [batch size, ]
probs_exp (Tensor): path probability/importance, float Tensor in shape of [batch size, per_item_num_paths]
item_embeddings (Tensor): rated item embeddings, float Tensor in shape of [batch size, dim]
relation1_emb (Tensor): relation1 embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim, dim]
mid_entity_emb (Tensor): middle entity embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim]
relation2_emb (Tensor): relation2 embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim, dim]
hist_item_emb (Tensor): historical item embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim]
"""
item_embeddings = self.entity_emb_matrix(items)
relation1_emb = self.relation_emb_matrix(relation1)
mid_entity_emb = self.entity_emb_matrix(mid_entity)
relation2_emb = self.relation_emb_matrix(relation2)
hist_item_emb = self.entity_emb_matrix(hist_item)
response, probs_exp = self._key_pathing(item_embeddings,
relation1_emb,
mid_entity_emb,
relation2_emb,
hist_item_emb)
scores = P.Squeeze()(self._predict(item_embeddings, response))
return scores, probs_exp, item_embeddings, relation1_emb, mid_entity_emb, relation2_emb, hist_item_emb
def _key_pathing(self, item_embeddings, relation1_emb, mid_entity_emb, relation2_emb, hist_item_emb):
"""
Compute the response and path probability using item and entity embedding.
Path structure: (rated item, relation1, entity, relation2, historical item).
Args:
item_embeddings (Tensor): rated item embeddings, float Tensor in shape of [batch size, dim]
relation1_emb (Tensor): relation1 embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim, dim]
mid_entity_emb (Tensor): middle entity embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim]
relation2_emb (Tensor): relation2 embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim, dim]
hist_item_emb (Tensor): historical item embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim]
Returns:
response (Tensor): user's response towards middle entity, float Tensor in shape of [batch size, dim]
probs_exp (Tensor): path probability/importance, float Tensor in shape of [batch size, per_item_num_paths]
"""
hist_item_e_4d = self.expand_dims(hist_item_emb, 3)
mul_r2_hist = self.squeeze(self.matmul(relation2_emb, hist_item_e_4d))
# path_right shape: [batch size, per_item_num_paths, dim]
path_right = self.abs(mul_r2_hist + self.reduce_sum(relation2_emb, 2))
item_emb_3d = self.expand_dims(item_embeddings, 2)
mul_r1_item = self.squeeze(self.matmul(relation1_emb, self.expand_dims(item_emb_3d, 1)))
path_left = self.abs(mul_r1_item + self.reduce_sum(relation1_emb, 2))
# path_left shape: [batch size, dim, per_item_num_paths]
path_left = self.transpose(path_left, (0, 2, 1))
probs = self.reduce_sum(self.matmul(path_right, path_left), 2)
# probs_exp shape: [batch size, per_item_num_paths]
probs_exp = self.softmax(probs)
probs_3d = self.expand_dims(probs_exp, 2)
# response shape: [batch size, dim]
response = self.reduce_sum(mid_entity_emb * probs_3d, 1)
return response, probs_exp
def _predict(self, item_embeddings, response):
scores = self.reduce_sum(item_embeddings * response, 1)
return scores
class NetWithLossClass(nn.Cell):
"""NetWithLossClass definition."""
def __init__(self, network, config):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.network = network
self.loss = P.SigmoidCrossEntropyWithLogits()
self.matmul = C.matmul
self.expand_dims = P.ExpandDims()
self.squeeze = P.Squeeze(3)
self.abs = P.Abs()
self.maximum = P.Maximum()
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.sigmoid = P.Sigmoid()
self.kge_weight = config.kge_weight
self.node_weight = config.node_weight
self.l2_weight = config.l2_weight
self.batch_size = config.batch_size
self.dim = config.embedding_dim
self.embedding_initializer = "normal"
def construct(self, items, relation1, mid_entity, relation2, hist_item, labels):
"""
Args:
items (Tensor): rated item IDs, int Tensor in shape of [batch size, ].
relation1 (Tensor): relation1 IDs, int Tensor in shape of [batch size, per_item_num_paths]
mid_entity (Tensor): middle entity IDs, int Tensor in shape of [batch size, per_item_num_paths]
relation2 (Tensor): relation2 IDs, int Tensor in shape of [batch size, per_item_num_paths]
hist_item (Tensor): historical item IDs, int Tensor in shape of [batch size, per_item_num_paths]
labels (Tensor): label of rated item record, int Tensor in shape of [batch size, ]
Returns:
loss (float): loss value
"""
scores, _, item_embeddings, relation1_emb, mid_entity_emb, relation2_emb, hist_item_emb = \
self.network(items, relation1, mid_entity, relation2, hist_item)
loss = self._loss_fun(item_embeddings, relation1_emb, mid_entity_emb,
relation2_emb, hist_item_emb, scores, labels)
return loss
def _loss_fun(self, item_embeddings, relation1_emb, mid_entity_emb, relation2_emb, hist_item_emb, scores, labels):
"""
Loss function definition.
Args:
item_embeddings (Tensor): rated item embeddings, float Tensor in shape of [batch size, dim]
relation1_emb (Tensor): relation1 embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim, dim]
mid_entity_emb (Tensor): middle entity embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim]
relation2_emb (Tensor): relation2 embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim, dim]
hist_item_emb (Tensor): historical item embeddings,
float Tensor in shape of [batch size, per_item_num_paths, dim]
scores (Tensor): model prediction score, float Tensor in shape of [batch size, ]
labels (Tensor): label of rated item record, int Tensor in shape of [batch size, ]
Returns:
loss: includes four part:
pred_loss: cross entropy of the model prediction score and labels
transr_loss: TransR KG Embedding loss
node_loss: node matching loss
l2_loss: L2 regularization loss
"""
pred_loss = self.reduce_mean(self.loss(scores, labels))
item_emb_3d = self.expand_dims(item_embeddings, 2)
item_emb_4d = self.expand_dims(item_emb_3d, 1)
mul_r1_item = self.squeeze(self.matmul(relation1_emb, item_emb_4d))
hist_item_e_4d = self.expand_dims(hist_item_emb, 3)
mul_r2_hist = self.squeeze(self.matmul(relation2_emb, hist_item_e_4d))
relation1_3d = self.reduce_sum(relation1_emb, 2)
relation2_3d = self.reduce_sum(relation2_emb, 2)
path_left = self.reduce_sum(self.abs(mul_r1_item + relation1_3d), 2)
path_right = self.reduce_sum(self.abs(mul_r2_hist + relation2_3d), 2)
transr_loss = self.reduce_sum(self.maximum(self.abs(path_left - path_right), 0))
transr_loss = self.reduce_mean(self.sigmoid(transr_loss))
mid_entity_emb_4d = self.expand_dims(mid_entity_emb, 3)
mul_r2_mid = self.squeeze(self.matmul(relation2_emb, mid_entity_emb_4d))
path_r2_mid = self.abs(mul_r2_mid + relation2_3d)
node_loss = self.reduce_sum(self.maximum(mul_r2_hist - path_r2_mid, 0))
node_loss = self.reduce_mean(self.sigmoid(node_loss))
l2_loss = self.reduce_mean(self.reduce_sum(relation1_emb * relation1_emb))
l2_loss += self.reduce_mean(self.reduce_sum(mid_entity_emb * mid_entity_emb))
l2_loss += self.reduce_mean(self.reduce_sum(relation2_emb * relation2_emb))
l2_loss += self.reduce_mean(self.reduce_sum(hist_item_emb * hist_item_emb))
transr_loss = self.kge_weight * transr_loss
node_loss = self.node_weight * node_loss
l2_loss = self.l2_weight * l2_loss
loss = pred_loss + transr_loss + node_loss + l2_loss
return loss
class TrainStepWrap(nn.Cell):
"""TrainStepWrap definition."""
def __init__(self, network, lr, sens=1):
super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network
self.network.set_train()
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.lr = lr
self.optimizer = nn.Adam(self.weights,
learning_rate=self.lr,
beta1=0.9,
beta2=0.999,
eps=1e-8,
loss_scale=sens)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
def construct(self, items, relation1, mid_entity, relation2, hist_item, labels):
"""
Args:
items (Tensor): rated item IDs, int Tensor in shape of [batch size, ].
relation1 (Tensor): relation1 IDs, int Tensor in shape of [batch size, per_item_num_paths]
mid_entity (Tensor): middle entity IDs, int Tensor in shape of [batch size, per_item_num_paths]
relation2 (Tensor): relation2 IDs, int Tensor in shape of [batch size, per_item_num_paths]
hist_item (Tensor): historical item IDs, int Tensor in shape of [batch size, per_item_num_paths]
labels (Tensor): label of rated item record, int Tensor in shape of [batch size, ]
Returns:
loss and gradient
"""
weights = self.weights
loss = self.network(items, relation1, mid_entity, relation2, hist_item, labels)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(items, relation1, mid_entity, relation2, hist_item, labels, sens)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
class PredictWithSigmoid(nn.Cell):
"""Predict method."""
def __init__(self, network):
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
self.network = network
self.sigmoid = P.Sigmoid()
def construct(self, items, relation1, mid_entity, relation2, hist_item, labels):
"""
Predict with sigmoid definition.
Args:
items (Tensor): rated item IDs, int Tensor in shape of [batch size, ].
relation1 (Tensor): relation1 IDs, int Tensor in shape of [batch size, per_item_num_paths]
mid_entity (Tensor): middle entity IDs, int Tensor in shape of [batch size, per_item_num_paths]
relation2 (Tensor): relation2 IDs, int Tensor in shape of [batch size, per_item_num_paths]
hist_item (Tensor): historical item IDs, int Tensor in shape of [batch size, per_item_num_paths]
labels (Tensor): label of rated item record, int Tensor in shape of [batch size, ]
Returns:
scores (Tensor): model prediction score, float Tensor in shape of [batch size, ]
pred_probs (Tensor): prediction probability, float Tensor in shape of [batch size, ]
labels (Tensor): label of rated item record, int Tensor in shape of [batch size, ]
probs_exp (Tensor): path probability/importance, float Tensor in shape of [batch size, per_item_num_paths]
"""
scores, probs_exp, _, _, _, _, _ = self.network(items, relation1, mid_entity, relation2, hist_item)
pred_probs = self.sigmoid(scores)
return scores, pred_probs, labels, probs_exp
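Mirroring `train.py` below, the cells above are meant to be stacked as follows (a construction-only sketch; the config path assumes the steam layout from the README):
```python
from src import tbnet
from src.config import TBNetConfig

cfg = TBNetConfig('./data/steam/config.json')
network = tbnet.TBNet(cfg)                         # backbone: embeddings + key pathing
loss_net = tbnet.NetWithLossClass(network, cfg)    # adds the prediction/TransR/node/L2 loss terms
train_net = tbnet.TrainStepWrap(loss_net, cfg.lr)  # Adam optimizer, optional distributed grad reducer
eval_net = tbnet.PredictWithSigmoid(network)       # sigmoid scores consumed by the AUC/ACC metrics
```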

@@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TB-Net training."""
import os
import argparse
import numpy as np
from mindspore import context, Model, Tensor
from mindspore.train.serialization import save_checkpoint
from mindspore.train.callback import Callback, TimeMonitor
from src import tbnet, config, metrics, dataset
class MyLossMonitor(Callback):
"""My loss monitor definition."""
def epoch_end(self, run_context):
"""Print loss at each epoch end."""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
print('loss:' + str(loss))
def get_args():
"""Parse commandline arguments."""
parser = argparse.ArgumentParser(description='Train TBNet.')
parser.add_argument(
'--dataset',
type=str,
required=False,
default='steam',
help="'steam' dataset is supported currently"
)
parser.add_argument(
'--train_csv',
type=str,
required=False,
default='train.csv',
help="the train csv datafile inside the dataset folder"
)
parser.add_argument(
'--test_csv',
type=str,
required=False,
default='test.csv',
help="the test csv datafile inside the dataset folder"
)
parser.add_argument(
'--device_id',
type=int,
required=False,
default=0,
help="device id"
)
parser.add_argument(
'--epochs',
type=int,
required=False,
default=20,
help="number of training epochs"
)
parser.add_argument(
'--device_target',
type=str,
required=False,
default='GPU',
choices=['GPU'],
help="run code on GPU"
)
parser.add_argument(
'--run_mode',
type=str,
required=False,
default='graph',
choices=['graph', 'pynative'],
help="run code by GRAPH mode or PYNATIVE mode"
)
return parser.parse_args()
def train_tbnet():
"""Training process."""
args = get_args()
home = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(home, 'data', args.dataset, 'config.json')
train_csv_path = os.path.join(home, 'data', args.dataset, args.train_csv)
test_csv_path = os.path.join(home, 'data', args.dataset, args.test_csv)
ckpt_path = os.path.join(home, 'checkpoints')
context.set_context(device_id=args.device_id)
if args.run_mode == 'graph':
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
if not os.path.exists(ckpt_path):
os.makedirs(ckpt_path)
print(f"creating dataset from {train_csv_path}...")
net_config = config.TBNetConfig(config_path)
train_ds = dataset.create(train_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
test_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
print("creating TBNet for training...")
network = tbnet.TBNet(net_config)
loss_net = tbnet.NetWithLossClass(network, net_config)
train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
train_net.set_train()
eval_net = tbnet.PredictWithSigmoid(network)
time_callback = TimeMonitor(data_size=train_ds.get_dataset_size())
loss_callback = MyLossMonitor()
model = Model(network=train_net, eval_network=eval_net, metrics={'auc': metrics.AUC(), 'acc': metrics.ACC()})
print("training...")
for i in range(args.epochs):
print(f'===================== Epoch {i} =====================')
model.train(epoch=1, train_dataset=train_ds, callbacks=[time_callback, loss_callback], dataset_sink_mode=False)
train_out = model.eval(train_ds, dataset_sink_mode=False)
test_out = model.eval(test_ds, dataset_sink_mode=False)
print(f'Train AUC:{train_out["auc"]} ACC:{train_out["acc"]} Test AUC:{test_out["auc"]} ACC:{test_out["acc"]}')
save_checkpoint(network, os.path.join(ckpt_path, f'tbnet_epoch{i}.ckpt'))
if __name__ == '__main__':
train_tbnet()