forked from mindspore-Ecosystem/mindspore
add tbnet for open source in model_zoo/official/recommend
parent 7971b57dcf
commit 9b9dc5ee94
README.md
@ -0,0 +1,250 @@

# Contents

- [Contents](#contents)
- [TBNet Description](#tbnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference and Explanation Performance](#inference-and-explanation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [TBNet Description](#contents)

TB-Net is a knowledge-graph-based explainable recommender system.

Paper: Shendi Wang, Haoyang Li, Xiao-Hui Li, Caleb Chen Cao, Lei Chen. Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems

# [Model Architecture](#contents)

TB-Net constructs subgraphs of a knowledge graph from the interactions between users and items as well as from the features of items, then computes paths in those subgraphs with a bidirectional conduction algorithm, which yields explainable recommendation results.

# [Dataset](#contents)

The [interactions of users and games](https://www.kaggle.com/tamber/steam-video-games) and the [games' feature data](https://www.kaggle.com/nikdavis/steam-store-games?select=steam.csv) from the game platform Steam are publicly available on Kaggle.

Dataset directory: `./data/{DATASET}/`, e.g. `./data/steam/`.

- train: train.csv, evaluation: test.csv

Each line contains a \<user\>, an \<item\>, the user-item \<rating\> (1 or 0), and PER_ITEM_NUM_PATHS paths between the item and the user's \<hist_item\>s (a \<hist_item\> is an item whose user-item \<rating\> in the historical data is 1).

```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
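
For illustration, a hypothetical record with PER_ITEM_NUM_PATHS = 2 (all IDs invented) would read:

```text
42,7,1,0,425,0,3,1,426,1,5
```

Here user 42 rated item 7 with 1, followed by two [relation1,entity,relation2,hist_item] quadruples: (0,425,0,3) and (1,426,1,5).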

- infer and explain: infer.csv

Each line contains the \<user\> and \<item\> to be inferred, a \<rating\>, and PER_ITEM_NUM_PATHS paths between the item and the user's \<hist_item\>s (a \<hist_item\> is an item whose user-item \<rating\> in the historical data is 1).
Note that \<item\> must traverse the candidate items (all items by default) in the dataset. \<rating\> can be assigned arbitrarily (all values are 0 by default) and is not used in the inference and explanation phases. A sketch of building such a file follows the format block below.

```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
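
Since \<item\> enumerates every candidate item while \<rating\> is a placeholder, infer.csv can be produced with a loop like this minimal sketch (not part of this commit; `find_paths` is a hypothetical helper returning the precomputed [relation1,entity,relation2,hist_item] quadruples for a user-item pair):

```python
def write_infer_csv(path, user, candidate_items, find_paths):
    """Write infer.csv rows for one user over every candidate item."""
    with open(path, 'w') as f:
        for item in candidate_items:
            fields = [user, item, 0]          # rating fixed to 0: unused at inference
            for quad in find_paths(user, item):
                # quad = [relation1, entity, relation2, hist_item]
                fields.extend(quad)
            f.write(','.join(map(str, fields)) + '\n')
```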

# [Environment Requirements](#contents)

- Hardware (GPU)
    - Prepare a hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

- Data preprocessing

Process the data into the format described in the [Dataset](#dataset) section (e.g. the 'steam' dataset), then run the code as follows.

- Training

```bash
python train.py \
  --dataset [DATASET] \
  --epochs [EPOCHS]
```

Example:

```bash
python train.py \
  --dataset steam \
  --epochs 20
```

- Evaluation

```bash
python eval.py \
  --dataset [DATASET] \
  --checkpoint_id [CHECKPOINT_ID]
```

Argument `--checkpoint_id` is required.

Example:

```bash
python eval.py \
  --dataset steam \
  --checkpoint_id 8
```

- Inference and Explanation

```bash
python infer.py \
  --dataset [DATASET] \
  --checkpoint_id [CHECKPOINT_ID] \
  --user [USER] \
  --items [ITEMS] \
  --explanations [EXPLANATIONS]
```

Arguments `--checkpoint_id` and `--user` are required.

Example:

```bash
python infer.py \
  --dataset steam \
  --checkpoint_id 8 \
  --user 1 \
  --items 1 \
  --explanations 3
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
.
└─tbnet
  ├─README.md
  ├─data
  │ └─steam
  │   ├─config.json               # data and training parameter configuration
  │   ├─infer.csv                 # inference and explanation dataset
  │   ├─test.csv                  # evaluation dataset
  │   ├─train.csv                 # training dataset
  │   └─trainslate.json           # explanation configuration
  ├─src
  │ ├─aggregator.py               # inference result aggregation
  │ ├─config.py                   # parameter configuration parsing
  │ ├─dataset.py                  # dataset generation
  │ ├─embedding.py                # 3-dim embedding matrix initialization
  │ ├─metrics.py                  # model metrics
  │ ├─steam.py                    # 'steam' dataset text explainer
  │ └─tbnet.py                    # TB-Net model
  ├─eval.py                       # evaluation
  ├─infer.py                      # inference and explanation
  └─train.py                      # training
```

## [Script Parameters](#contents)

- train.py parameters

```text
--dataset         'steam' dataset is supported currently
--train_csv       the train csv datafile inside the dataset folder
--test_csv        the test csv datafile inside the dataset folder
--device_id       device id
--epochs          number of training epochs
--device_target   run code on GPU
--run_mode        run code by GRAPH mode or PYNATIVE mode
```

- eval.py parameters

```text
--dataset         'steam' dataset is supported currently
--csv             the csv datafile inside the dataset folder (e.g. test.csv)
--checkpoint_id   which checkpoint (.ckpt) file to evaluate
--device_id       device id
--device_target   run code on GPU
--run_mode        run code by GRAPH mode or PYNATIVE mode
```

- infer.py parameters

```text
--dataset         'steam' dataset is supported currently
--csv             the csv datafile inside the dataset folder (e.g. infer.csv)
--checkpoint_id   which checkpoint (.ckpt) file to infer with
--user            id of the user to be recommended to
--items           no. of items to be recommended
--explanations    no. of recommendation explanations to be shown
--device_id       device id
--device_target   run code on GPU
--run_mode        run code by GRAPH mode or PYNATIVE mode
```

# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | GPU                                                          |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | TB-Net                                                       |
| Resource                   | Tesla V100-SXM2-32GB                                         |
| Uploaded Date              | 2021-08-01                                                   |
| MindSpore Version          | 1.3.0                                                        |
| Dataset                    | steam                                                        |
| Training Parameters        | epoch=20, batch_size=1024, lr=0.001                          |
| Optimizer                  | Adam                                                         |
| Loss Function              | Sigmoid Cross Entropy                                        |
| Outputs                    | AUC=0.8596, Accuracy=0.7761                                  |
| Loss                       | 0.57                                                         |
| Speed                      | 1pc: 90ms/step                                               |
| Total Time                 | 1pc: 297s                                                    |
| Checkpoint for Fine Tuning | 104.66M (.ckpt file)                                         |
| Scripts                    | [TB-Net scripts](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/tbnet) |

### Evaluation Performance

| Parameters                | GPU                           |
| ------------------------- | ----------------------------- |
| Model Version             | TB-Net                        |
| Resource                  | Tesla V100-SXM2-32GB          |
| Uploaded Date             | 2021-08-01                    |
| MindSpore Version         | 1.3.0                         |
| Dataset                   | steam                         |
| Batch Size                | 1024                          |
| Outputs                   | AUC=0.8252, Accuracy=0.7503   |
| Total Time                | 1pc: 5.7s                     |

### Inference and Explanation Performance

| Parameters                | GPU                                   |
| ------------------------- | ------------------------------------- |
| Model Version             | TB-Net                                |
| Resource                  | Tesla V100-SXM2-32GB                  |
| Uploaded Date             | 2021-08-01                            |
| MindSpore Version         | 1.3.0                                 |
| Dataset                   | steam                                 |
| Outputs                   | Recommendation Result and Explanation |
| Total Time                | 1pc: 3.66s                            |

# [Description of Random Situation](#contents)

- Initialization of the embedding matrices in `tbnet.py` and `embedding.py`.

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
README_CN.md
@ -0,0 +1,252 @@

# Contents

<!-- TOC -->

- [Contents](#contents)
- [TBNet Overview](#tbnet-overview)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference and Explanation Performance](#inference-and-explanation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [TBNet Overview](#contents)

TB-Net is a knowledge-graph-based explainable recommender system.

Paper: Shendi Wang, Haoyang Li, Xiao-Hui Li, Caleb Chen Cao, Lei Chen. Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems

# [Model Architecture](#contents)

TB-Net builds subgraphs in a knowledge graph from the interaction information between users and items and from the attribute information of items, computes the paths in the graphs with a bidirectional conduction method, and finally obtains explainable recommendation results.

# [Dataset](#contents)

This example uses the public datasets of the Steam game platform on Kaggle, consisting of the [records of user-game interactions](https://www.kaggle.com/tamber/steam-video-games) and the [attribute information of the games](https://www.kaggle.com/nikdavis/steam-store-games?select=steam.csv).

Dataset path: `./data/{DATASET}/`, e.g. `./data/steam/`.

- Training: train.csv; evaluation: test.csv

Each line records a \<user\>'s \<rating\> (1 or 0) of an \<item\>, together with PER_ITEM_NUM_PATHS paths between that \<item\> and the \<hist_item\>s (i.e. the \<item\>s whose historical \<rating\> by that \<user\> is 1).

```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```

- Inference and explanation: infer.csv

Each line records the \<user\> and \<item\> **to be inferred**, a \<rating\>, and the PER_ITEM_NUM_PATHS paths between that \<item\> and the \<hist_item\>s (i.e. the \<item\>s whose historical \<rating\> by that \<user\> is 1).
\<item\> has to traverse **all** candidate items in the dataset (all items by default); \<rating\> can be assigned randomly (all 0 by default) and is not used in the inference and explanation phases.

```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```

# [Environment Requirements](#contents)

- Hardware (GPU)
    - Prepare a hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more details, please refer to the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can perform training, evaluation, inference and explanation as follows:

- Data preparation

Process the data into the format described in the [Dataset](#dataset) section above (taking the 'steam' dataset as an example), then run the code as follows.

- Training

```bash
python train.py \
  --dataset [DATASET] \
  --epochs [EPOCHS]
```

Example:

```bash
python train.py \
  --dataset steam \
  --epochs 20
```

- Evaluation

```bash
python eval.py \
  --dataset [DATASET] \
  --checkpoint_id [CHECKPOINT_ID]
```

The argument `--checkpoint_id` is required.

Example:

```bash
python eval.py \
  --dataset steam \
  --checkpoint_id 8
```

- Inference and Explanation

```bash
python infer.py \
  --dataset [DATASET] \
  --checkpoint_id [CHECKPOINT_ID] \
  --user [USER] \
  --items [ITEMS] \
  --explanations [EXPLANATIONS]
```

The arguments `--checkpoint_id` and `--user` are required.

Example:

```bash
python infer.py \
  --dataset steam \
  --checkpoint_id 8 \
  --user 1 \
  --items 1 \
  --explanations 3
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
.
└─tbnet
  ├─README.md
  ├─data
  │ └─steam
  │   ├─config.json               # data and training parameter configuration
  │   ├─infer.csv                 # inference and explanation dataset
  │   ├─test.csv                  # test dataset
  │   ├─train.csv                 # training dataset
  │   └─trainslate.json           # explanation output configuration
  ├─src
  │ ├─aggregator.py               # inference result aggregation
  │ ├─config.py                   # parameter configuration parsing
  │ ├─dataset.py                  # dataset creation
  │ ├─embedding.py                # three-dimension embedding matrix initialization
  │ ├─metrics.py                  # model metrics
  │ ├─steam.py                    # 'steam' dataset text parser
  │ └─tbnet.py                    # TB-Net network
  ├─eval.py                       # evaluation network
  ├─infer.py                      # inference and explanation
  └─train.py                      # training network
```

## [Script Parameters](#contents)

- train.py parameters

```text
--dataset         'steam' dataset is supported currently
--train_csv       the train csv datafile inside the dataset folder
--test_csv        the test csv datafile inside the dataset folder
--device_id       device id
--epochs          number of training epochs
--device_target   run code on GPU
--run_mode        run code by GRAPH mode or PYNATIVE mode
```

- eval.py parameters

```text
--dataset         'steam' dataset is supported currently
--csv             the csv datafile inside the dataset folder (e.g. test.csv)
--checkpoint_id   which checkpoint (.ckpt) file to evaluate
--device_id       device id
--device_target   run code on GPU
--run_mode        run code by GRAPH mode or PYNATIVE mode
```

- infer.py parameters

```text
--dataset         'steam' dataset is supported currently
--csv             the csv datafile inside the dataset folder (e.g. infer.csv)
--checkpoint_id   which checkpoint (.ckpt) file to infer with
--user            id of the user to be recommended to
--items           no. of items to be recommended
--explanations    no. of recommendation explanations to be shown
--device_id       device id
--device_target   run code on GPU
--run_mode        run code by GRAPH mode or PYNATIVE mode
```

# [Model Description](#contents)

## [Performance](#contents)

### [Training Performance](#contents)

| Parameters                 | GPU                                                 |
| -------------------------- | --------------------------------------------------- |
| Model Version              | TB-Net                                              |
| Resource                   | Tesla V100-SXM2-32GB                                |
| Uploaded Date              | 2021-08-01                                          |
| MindSpore Version          | 1.3.0                                               |
| Dataset                    | steam                                               |
| Training Parameters        | epoch=20, batch_size=1024, lr=0.001                 |
| Optimizer                  | Adam                                                |
| Loss Function              | Sigmoid Cross Entropy                               |
| Outputs                    | AUC=0.8596, Accuracy=0.7761                         |
| Loss                       | 0.57                                                |
| Speed                      | 1pc: 90 ms/step                                     |
| Total Time                 | 1pc: 297 s                                          |
| Checkpoint for Fine Tuning | 104.66M (.ckpt file)                                |
| Scripts                    | [TB-Net scripts](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/tbnet) |

### [Evaluation Performance](#contents)

| Parameters        | GPU                         |
| ----------------- | --------------------------- |
| Model Version     | TB-Net                      |
| Resource          | Tesla V100-SXM2-32GB        |
| Uploaded Date     | 2021-08-01                  |
| MindSpore Version | 1.3.0                       |
| Dataset           | steam                       |
| Batch Size        | 1024                        |
| Outputs           | AUC=0.8252, Accuracy=0.7503 |
| Total Time        | 1pc: 5.7 s                  |

### [Inference and Explanation Performance](#contents)

| Parameters        | GPU                                      |
| ----------------- | ---------------------------------------- |
| Model Version     | TB-Net                                   |
| Resource          | Tesla V100-SXM2-32GB                     |
| Uploaded Date     | 2021-08-01                               |
| MindSpore Version | 1.3.0                                    |
| Dataset           | steam                                    |
| Outputs           | Recommendation results and explanations  |
| Total Time        | 1pc: 3.66 s                              |

# [Description of Random Situation](#contents)

- Random initialization of the embedding matrices in `tbnet.py` and `embedding.py`.

# [ModelZoo Homepage](#contents)

Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
data/steam/config.json
@ -0,0 +1,12 @@

{
    "num_item": 3005,
    "num_relation": 5,
    "num_entity": 5138,
    "per_item_num_paths": 39,
    "embedding_dim": 26,
    "batch_size": 1024,
    "lr": 0.001,
    "kge_weight": 0.05,
    "node_weight": 0.002,
    "l2_weight": 1e-6
}
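
The fields above are consumed by `src/config.py` (included later in this commit); a minimal usage sketch, assuming the file sits at `data/steam/config.json` relative to the working directory:

```python
from src.config import TBNetConfig

cfg = TBNetConfig('data/steam/config.json')
# hyper-parameters become plain attributes
print(cfg.num_entity, cfg.embedding_dim, cfg.batch_size)  # 5138 26 1024
```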
data/steam/infer.csv
@ -0,0 +1 @@

#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times

data/steam/test.csv
@ -0,0 +1 @@

#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times

data/steam/train.csv
@ -0,0 +1 @@

#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
data/steam/trainslate.json
@ -0,0 +1,14 @@

{
    "item": {
        "0": "Star Wars",
        "1": "Battlefield 1"
    },
    "relation": {
        "0": "Developer",
        "1": "Genre"
    },
    "entity": {
        "425": "EA Games",
        "426": "Shooting"
    }
}
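
This mapping backs the text explainer in `src/steam.py` (included later in this commit). A minimal sketch of how it is queried; note that `infer.py` resolves the file as `translate.json` while the directory tree lists `trainslate.json`, so the path below is an assumption:

```python
from src.steam import TextExplainer

explainer = TextExplainer('data/steam/trainslate.json')  # path is an assumption, see note above
print(explainer.translate_item(0))      # Star Wars
print(explainer.translate_relation(1))  # Genre
print(explainer.translate_entity(425))  # EA Games
print(explainer.translate_item(999))    # [item:999] -- unknown IDs fall back to a tag
```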
eval.py
@ -0,0 +1,117 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TB-Net evaluation."""

import os
import argparse

from mindspore import context, Model, load_checkpoint, load_param_into_net

from src import tbnet, config, metrics, dataset


def get_args():
    """Parse commandline arguments."""
    parser = argparse.ArgumentParser(description='Evaluate TBNet.')

    parser.add_argument(
        '--dataset',
        type=str,
        required=False,
        default='steam',
        help="'steam' dataset is supported currently"
    )

    parser.add_argument(
        '--csv',
        type=str,
        required=False,
        default='test.csv',
        help="the csv datafile inside the dataset folder (e.g. test.csv)"
    )

    parser.add_argument(
        '--checkpoint_id',
        type=int,
        required=True,
        help="which checkpoint (.ckpt) file to evaluate"
    )

    parser.add_argument(
        '--device_id',
        type=int,
        required=False,
        default=0,
        help="device id"
    )

    parser.add_argument(
        '--device_target',
        type=str,
        required=False,
        default='GPU',
        choices=['GPU'],
        help="run code on GPU"
    )

    parser.add_argument(
        '--run_mode',
        type=str,
        required=False,
        default='graph',
        choices=['graph', 'pynative'],
        help="run code by GRAPH mode or PYNATIVE mode"
    )

    return parser.parse_args()


def eval_tbnet():
    """Evaluation process."""
    args = get_args()

    home = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(home, 'data', args.dataset, 'config.json')
    test_csv_path = os.path.join(home, 'data', args.dataset, args.csv)
    ckpt_path = os.path.join(home, 'checkpoints')

    context.set_context(device_id=args.device_id)
    if args.run_mode == 'graph':
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    else:
        context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)

    print(f"creating dataset from {test_csv_path}...")
    net_config = config.TBNetConfig(config_path)
    eval_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)

    print(f"creating TBNet from checkpoint {args.checkpoint_id} for evaluation...")
    network = tbnet.TBNet(net_config)
    param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
    load_param_into_net(network, param_dict)

    loss_net = tbnet.NetWithLossClass(network, net_config)
    train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
    train_net.set_train()
    eval_net = tbnet.PredictWithSigmoid(network)
    model = Model(network=train_net, eval_network=eval_net, metrics={'auc': metrics.AUC(), 'acc': metrics.ACC()})

    print("evaluating...")
    e_out = model.eval(eval_ds, dataset_sink_mode=False)
    print(f'Test AUC:{e_out["auc"]} ACC:{e_out["acc"]}')


if __name__ == '__main__':
    eval_tbnet()
infer.py
@ -0,0 +1,162 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TBNet inference."""

import os
import argparse

from mindspore import load_checkpoint, load_param_into_net, context
from src.config import TBNetConfig
from src.tbnet import TBNet
from src.aggregator import InferenceAggregator
from src import dataset
from src import steam


def get_args():
    """Parse commandline arguments."""
    parser = argparse.ArgumentParser(description='Infer TBNet.')

    parser.add_argument(
        '--dataset',
        type=str,
        required=False,
        default='steam',
        help="'steam' dataset is supported currently"
    )

    parser.add_argument(
        '--csv',
        type=str,
        required=False,
        default='infer.csv',
        help="the csv datafile inside the dataset folder (e.g. infer.csv)"
    )

    parser.add_argument(
        '--checkpoint_id',
        type=int,
        required=True,
        help="which checkpoint (.ckpt) file to infer with"
    )

    parser.add_argument(
        '--user',
        type=int,
        required=True,
        help="id of the user to be recommended to"
    )

    parser.add_argument(
        '--items',
        type=int,
        required=False,
        default=1,
        help="no. of items to be recommended"
    )

    parser.add_argument(
        '--explanations',
        type=int,
        required=False,
        default=3,
        help="no. of recommendation explanations to be shown"
    )

    parser.add_argument(
        '--device_id',
        type=int,
        required=False,
        default=0,
        help="device id"
    )

    parser.add_argument(
        '--device_target',
        type=str,
        required=False,
        default='GPU',
        choices=['GPU'],
        help="run code on GPU"
    )

    parser.add_argument(
        '--run_mode',
        type=str,
        required=False,
        default='graph',
        choices=['graph', 'pynative'],
        help="run code by GRAPH mode or PYNATIVE mode"
    )

    return parser.parse_args()


def infer_tbnet():
    """Inference process."""
    args = get_args()
    context.set_context(device_id=args.device_id)
    if args.run_mode == 'graph':
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    else:
        context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)

    home = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(home, 'data', args.dataset, 'config.json')
    translate_path = os.path.join(home, 'data', args.dataset, 'translate.json')
    data_path = os.path.join(home, 'data', args.dataset, args.csv)
    ckpt_path = os.path.join(home, 'checkpoints')

    print(f"creating TBNet from checkpoint {args.checkpoint_id}...")
    config = TBNetConfig(config_path)
    network = TBNet(config)
    param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
    load_param_into_net(network, param_dict)

    print(f"creating dataset from {data_path}...")
    infer_ds = dataset.create(data_path, config.per_item_num_paths, train=False, users=args.user)
    infer_ds = infer_ds.batch(config.batch_size)

    print("inferring...")
    # infer and aggregate results
    aggregator = InferenceAggregator(top_k=args.items)
    for user, item, relation1, entity, relation2, hist_item, rating in infer_ds:
        del rating
        result = network(item, relation1, entity, relation2, hist_item)
        item_score = result[0]
        path_importance = result[1]
        aggregator.aggregate(user, item, relation1, entity, relation2, hist_item, item_score, path_importance)

    # show recommendations with explanations
    explainer = steam.TextExplainer(translate_path)
    recomms = aggregator.recommend()
    for user, recomm in recomms.items():
        for item_rec in recomm.item_records:

            item_name = explainer.translate_item(item_rec.item)
            print(f"Recommend <{item_name}> to user:{user}, because:")

            # show explanations
            explanation = 0
            for path in item_rec.paths:
                print(" - " + explainer.explain(path))
                explanation += 1
                if explanation >= args.explanations:
                    break
            print("")


if __name__ == '__main__':
    infer_tbnet()
requirements.txt
@ -0,0 +1 @@

sklearn
src/aggregator.py
@ -0,0 +1,151 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inference result aggregator."""

import copy


class Recommendation:
    """Recommendation."""

    class Path:
        """Item path."""

        def __init__(self, relation1, entity, relation2, hist_item, importance):
            self.relation1 = relation1
            self.entity = entity
            self.relation2 = relation2
            self.hist_item = hist_item
            self.importance = importance

    class ItemRecord:
        """Recommended item info."""

        def __init__(self, item, score):
            self.item = item
            self.score = score
            # paths must be sorted by importance in descending order
            self.paths = []

    def __init__(self, user):
        self.user = user
        # item_records must be sorted by score in descending order
        self.item_records = []


class InferenceAggregator:
    """
    Inference result aggregator.

    Args:
        top_k (int): The number of items to be recommended for each distinct user.
    """

    def __init__(self, top_k=1):
        if top_k < 1:
            raise ValueError('top_k is less than 1.')
        self._top_k = top_k
        self._user_recomms = dict()
        self._paths_sorted = False

    def aggregate(self, user, item, relation1, entity, relation2, hist_item, item_score, path_importance):
        """
        Aggregate inference results.

        Args:
            user (Tensor): User IDs, int Tensor in shape of [N, ].
            item (Tensor): Candidate item IDs, int Tensor in shape of [N, ].
            relation1 (Tensor): IDs of item-entity relations, int Tensor in shape of [N, <no. of per-item path>].
            entity (Tensor): Entity IDs, int Tensor in shape of [N, <no. of per-item path>].
            relation2 (Tensor): IDs of entity-hist_item relations, int Tensor in shape of [N, <no. of per-item path>].
            hist_item (Tensor): Historical item IDs, int Tensor in shape of [N, <no. of per-item path>].
            item_score (Tensor): TBNet output, recommendation scores of candidate items, float Tensor in shape of [N, ].
            path_importance (Tensor): TBNet output, the importance of each item to hist_item path for the
                recommendations, float Tensor in shape of [N, <no. of per-item path>].
        """
        user = user.asnumpy()
        item = item.asnumpy()
        relation1 = relation1.asnumpy()
        entity = entity.asnumpy()
        relation2 = relation2.asnumpy()
        hist_item = hist_item.asnumpy()
        item_score = item_score.asnumpy()
        path_importance = path_importance.asnumpy()

        batch_size = user.shape[0]

        added_users = set()
        for i in range(batch_size):
            if self._add(user[i], item[i], relation1[i], entity[i], relation2[i],
                         hist_item[i], item_score[i], path_importance[i]):
                added_users.add(user[i])
                self._paths_sorted = False

        for added_user in added_users:
            recomm = self._user_recomms[added_user]
            if len(recomm.item_records) > self._top_k:
                recomm.item_records = recomm.item_records[0:self._top_k]

    def recommend(self):
        """
        Generate recommendations for all distinct users.

        Returns:
            dict[int, Recommendation], a dictionary with user ids as keys and Recommendation objects as values.
        """
        if not self._paths_sorted:
            self._sort_paths()
        return copy.deepcopy(self._user_recomms)

    def _add(self, user, item, relation1, entity, relation2, hist_item, item_score, path_importance):
        """Add a single infer record."""

        recomm = self._user_recomms.get(user, None)
        if recomm is None:
            recomm = Recommendation(user)
            self._user_recomms[user] = recomm

        # insert at the appropriate position
        for i, old_item_rec in enumerate(recomm.item_records):
            if i >= self._top_k:
                return False
            if item_score > old_item_rec.score:
                rec = self._infer_2_item_rec(item, relation1, entity, relation2,
                                             hist_item, item_score, path_importance)
                recomm.item_records.insert(i, rec)
                return True

        # append if there is room
        if len(recomm.item_records) < self._top_k:
            rec = self._infer_2_item_rec(item, relation1, entity, relation2,
                                         hist_item, item_score, path_importance)
            recomm.item_records.append(rec)
            return True

        return False

    @staticmethod
    def _infer_2_item_rec(item, relation1, entity, relation2, hist_item, item_score, path_importance):
        """Convert a single infer result to an item record."""
        item_rec = Recommendation.ItemRecord(item, item_score)
        num_paths = path_importance.shape[0]
        for i in range(num_paths):
            path = Recommendation.Path(relation1[i], entity[i], relation2[i], hist_item[i], path_importance[i])
            item_rec.paths.append(path)
        return item_rec

    def _sort_paths(self):
        """Sort all item paths."""
        for recomm in self._user_recomms.values():
            for item_rec in recomm.item_records:
                item_rec.paths.sort(key=lambda x: x.importance, reverse=True)
        self._paths_sorted = True
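
A toy, hand-made run of the aggregator above (all IDs and scores invented): two candidate items for one user, keeping only the top-scoring one.

```python
import numpy as np
from mindspore import Tensor
from src.aggregator import InferenceAggregator

agg = InferenceAggregator(top_k=1)
# batch of 2 candidate items for user 7, with 2 paths per item
user = Tensor(np.array([7, 7], np.int32))
item = Tensor(np.array([0, 1], np.int32))
relation1 = Tensor(np.array([[0, 1], [0, 1]], np.int32))
entity = Tensor(np.array([[425, 426], [425, 426]], np.int32))
relation2 = Tensor(np.array([[0, 1], [0, 1]], np.int32))
hist_item = Tensor(np.array([[0, 0], [0, 0]], np.int32))
item_score = Tensor(np.array([0.3, 0.9], np.float32))
path_importance = Tensor(np.array([[0.2, 0.8], [0.6, 0.4]], np.float32))

agg.aggregate(user, item, relation1, entity, relation2, hist_item,
              item_score, path_importance)
for uid, recomm in agg.recommend().items():
    best = recomm.item_records[0]
    print(uid, best.item, best.score)  # 7 1 0.9 (item 1 wins with top_k=1)
```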
src/config.py
@ -0,0 +1,39 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TBNet configurations."""

import json


class TBNetConfig:
    """
    TBNet config file parser and holder.

    Args:
        config_path (str): json config file path.
    """

    def __init__(self, config_path):
        with open(config_path) as f:
            json_dict = json.load(f)
        self.num_item = int(json_dict['num_item'])
        self.num_relation = int(json_dict['num_relation'])
        self.num_entity = int(json_dict['num_entity'])
        self.per_item_num_paths = int(json_dict['per_item_num_paths'])
        self.embedding_dim = int(json_dict['embedding_dim'])
        self.batch_size = int(json_dict['batch_size'])
        self.lr = float(json_dict['lr'])
        self.kge_weight = float(json_dict['kge_weight'])
        self.node_weight = float(json_dict['node_weight'])
        self.l2_weight = float(json_dict['l2_weight'])
src/dataset.py
@ -0,0 +1,88 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset loader."""

from functools import partial

import numpy as np
from mindspore.dataset import GeneratorDataset


def create(data_path, per_item_num_paths, train, users=None, **kwargs):
    """
    Create a dataset for TBNet.

    Args:
        data_path (str): The csv datafile path.
        per_item_num_paths (int): The number of paths per item.
        train (bool): True to create for training with columns:
            'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating',
            otherwise: 'user', 'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating'.
        users (Union[list[int], int], optional): Users data to be loaded; if None is provided, all data will be loaded.
        **kwargs (any): Other arguments for GeneratorDataset(), except 'source' and 'column_names'.

    Returns:
        GeneratorDataset, the generator dataset that reads from the csv datafile.
    """
    if isinstance(users, int):
        users = (users,)
    kwargs['source'] = partial(csv_generator, data_path, per_item_num_paths, users, train)

    if train:
        kwargs['column_names'] = ['item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
    else:
        kwargs['column_names'] = ['user', 'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
    return GeneratorDataset(**kwargs)


def csv_generator(csv_path, per_item_num_paths, users, train):
    """Generator for csv datafile."""
    expected_columns = 3 + (per_item_num_paths * 4)
    with open(csv_path) as file:
        for line in file:
            line = line.strip()
            if not line or line[0] == '#':
                continue
            id_list = line.split(',')
            if len(id_list) < expected_columns:
                raise ValueError(f'Expecting {expected_columns} values but got {len(id_list)} only!')
            id_list = list(map(int, id_list))
            user = id_list[0]
            if users and user not in users:
                continue
            item = id_list[1]
            rating = id_list[2]

            relation1 = np.empty(shape=(per_item_num_paths,), dtype=np.int64)
            entity = np.empty_like(relation1)
            relation2 = np.empty_like(relation1)
            hist_item = np.empty_like(relation1)

            for p in range(per_item_num_paths):
                offset = 3 + (p * 4)
                relation1[p] = id_list[offset]
                entity[p] = id_list[offset + 1]
                relation2[p] = id_list[offset + 2]
                hist_item[p] = id_list[offset + 3]

            if train:
                # item, relation1, entity, relation2, hist_item, rating
                yield np.array(item, dtype=np.int64), relation1, entity, relation2, hist_item, \
                    np.array(rating, dtype=np.float32)
            else:
                # user, item, relation1, entity, relation2, hist_item, rating
                yield np.array(user, dtype=np.int64), np.array(item, dtype=np.int64), \
                    relation1, entity, relation2, hist_item, np.array(rating, dtype=np.float32)
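
A minimal sketch of loading a training batch with `create()` above, following the call pattern in `eval.py` (paths assume the `data/steam` layout from the README):

```python
from src import dataset
from src.config import TBNetConfig

cfg = TBNetConfig('data/steam/config.json')
train_ds = dataset.create('data/steam/train.csv', cfg.per_item_num_paths, train=True)
train_ds = train_ds.batch(cfg.batch_size)

# each batch holds the columns declared in create(): item, relation1,
# entity, relation2, hist_item, rating
for item, relation1, entity, relation2, hist_item, rating in train_ds:
    print(item.shape, relation1.shape, rating.shape)
    break
```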
src/embedding.py
@ -0,0 +1,75 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Three-dimension embedding vector initialization."""

import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.nn import Cell


class EmbeddingMatrix(Cell):
    """
    Support three-dimension embedding vector initialization.
    """

    def __init__(self, vocab_size, embedding_size, embedding_table='normal',
                 dtype=mstype.float32, padding_idx=None):
        super(EmbeddingMatrix, self).__init__()
        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = validator.check_value_type('embedding_size', embedding_size,
                                                         [int, tuple, list], self.cls_name)
        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.dtype = dtype
        if isinstance(self.embedding_size, int):
            self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
            self.embedding_out = (self.embedding_size,)
        else:
            if len(self.embedding_size) != 2:
                raise ValueError("embedding_size should be an int or a tuple of two ints")
            self.init_tensor = initializer(embedding_table, [vocab_size, self.embedding_size[0],
                                                             self.embedding_size[1]])
            self.embedding_out = (self.embedding_size[0], self.embedding_size[1],)
        self.padding_idx = padding_idx
        if padding_idx is not None:
            self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                         "padding_idx", self.cls_name)
            if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None:
                self.init_tensor = self.init_tensor.init_data()
            self.init_tensor = self.init_tensor.asnumpy()
            self.init_tensor[self.padding_idx] = 0
            self.init_tensor = Tensor(self.init_tensor)
        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
        self.expand = P.ExpandDims()
        self.reshape_flat = P.Reshape()
        self.shp_flat = (-1,)
        self.gather = P.Gather()
        self.reshape = P.Reshape()
        self.get_shp = P.Shape()

    def construct(self, ids):
        """
        Look up and return the embedding vectors for the given IDs.
        """
        extended_ids = self.expand(ids, -1)
        out_shape = self.get_shp(ids) + self.embedding_out
        flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
        output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, out_shape)
        return output
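
A toy shape check for `EmbeddingMatrix` (values invented): with a 2-tuple `embedding_size` it returns one `dim x dim` matrix per looked-up ID, which is how `tbnet.py` embeds relations.

```python
import numpy as np
from mindspore import Tensor
from src.embedding import EmbeddingMatrix

rel_emb = EmbeddingMatrix(vocab_size=5, embedding_size=(3, 3))
ids = Tensor(np.array([[0, 2], [1, 4]], np.int32))
out = rel_emb(ids)
print(out.shape)  # (2, 2, 3, 3): one 3x3 matrix per looked-up ID
```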
src/metrics.py
@ -0,0 +1,80 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TB-Net metrics."""

import numpy as np
from sklearn.metrics import roc_auc_score
from mindspore.nn.metrics import Metric


class AUC(Metric):
    """TB-Net metrics method. Computes the model metric AUC."""

    def __init__(self):
        super(AUC, self).__init__()
        self.clear()

    def clear(self):
        """Clear the internal evaluation result."""
        self.true_labels = []
        self.pred_probs = []

    def update(self, *inputs):
        """Update the lists of predictions and labels."""
        all_predict = inputs[1].asnumpy().flatten().tolist()
        all_label = inputs[2].asnumpy().flatten().tolist()
        self.pred_probs.extend(all_predict)
        self.true_labels.extend(all_label)

    def eval(self):
        """Return the AUC score."""
        if len(self.true_labels) != len(self.pred_probs):
            raise RuntimeError('true_labels and pred_probs have unequal lengths.')

        auc = roc_auc_score(self.true_labels, self.pred_probs)

        return auc


class ACC(Metric):
    """TB-Net metrics method. Computes the model metric accuracy (ACC)."""

    def __init__(self):
        super(ACC, self).__init__()
        self.clear()

    def clear(self):
        """Clear the internal evaluation result."""
        self.true_labels = []
        self.pred_probs = []

    def update(self, *inputs):
        """Update the lists of predictions and labels."""
        all_predict = inputs[1].asnumpy().flatten().tolist()
        all_label = inputs[2].asnumpy().flatten().tolist()
        self.pred_probs.extend(all_predict)
        self.true_labels.extend(all_label)

    def eval(self):
        """Return the accuracy score."""
        if len(self.true_labels) != len(self.pred_probs):
            raise RuntimeError('true_labels and pred_probs have unequal lengths.')

        predictions = [1 if i >= 0.5 else 0 for i in self.pred_probs]
        acc = np.mean(np.equal(predictions, self.true_labels))

        return acc
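
A toy check of the metrics above (values invented). `update()` reads predictions from `inputs[1]` and labels from `inputs[2]`, so a placeholder occupies index 0:

```python
import numpy as np
from mindspore import Tensor
from src.metrics import AUC, ACC

auc, acc = AUC(), ACC()
placeholder = Tensor(np.array([0.0], np.float32))  # index 0 is ignored by update()
preds = Tensor(np.array([0.9, 0.2, 0.7, 0.4], np.float32))
labels = Tensor(np.array([1, 0, 1, 0], np.float32))
for m in (auc, acc):
    m.update(placeholder, preds, labels)
print(auc.eval(), acc.eval())  # 1.0 1.0 for this perfectly separable toy input
```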
src/steam.py
@ -0,0 +1,61 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""STEAM dataset explainer."""

import json

from src.aggregator import Recommendation


class TextExplainer:
    """Text explainer for STEAM game recommendations."""

    SAME_RELATION_TPL = 'User played the game <%s> before, which has the same %s '\
                        '<%s> as the recommended game.'
    DIFF_RELATION_TPL = 'User played the game <%s> before, which has the %s <%s> '\
                        'while <%s> is the %s of the recommended game.'

    def __init__(self, translate_path: str):
        """Construct from the translate json file."""
        with open(translate_path) as file:
            self._translator = json.load(file)

    def explain(self, path: Recommendation.Path) -> str:
        """Explain the path."""
        rel1_str = self.translate_relation(path.relation1)
        entity_str = self.translate_entity(path.entity)
        hist_item_str = self.translate_item(path.hist_item)
        if path.relation1 == path.relation2:
            return self.SAME_RELATION_TPL % (hist_item_str, rel1_str, entity_str)
        rel2_str = self.translate_relation(path.relation2)
        return self.DIFF_RELATION_TPL % (hist_item_str, rel2_str, entity_str, entity_str, rel1_str)

    def translate_item(self, item: int) -> str:
        """Translate an item."""
        return self._translate('item', item)

    def translate_entity(self, entity: int) -> str:
        """Translate an entity."""
        return self._translate('entity', entity)

    def translate_relation(self, relation: int) -> str:
        """Translate a relation."""
        return self._translate('relation', relation)

    def _translate(self, obj_type, obj_id):
        """Translate an object."""
        try:
            return self._translator[obj_type][str(obj_id)]
        except KeyError:
            return f'[{obj_type}:{obj_id}]'
@ -0,0 +1,366 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""TB-Net Model."""
|
||||||
|
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore import ParameterTuple
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||||
|
|
||||||
|
from src.embedding import EmbeddingMatrix
|
||||||
|
|
||||||
|
|
||||||
|
class TBNet(nn.Cell):
    """
    TB-Net model architecture.

    Args:
        config (TBNetConfig): model configuration, expected to provide:
            num_entity (int): number of entities, depends on dataset
            num_relation (int): number of relations, depends on dataset
            embedding_dim (int): dimension of entity and relation embedding vectors
            kge_weight (float): weight of the KG Embedding loss term
            node_weight (float): weight of the node loss term (default=0.002)
            l2_weight (float): weight of the L2 regularization term (default=1e-7)
            lr (float): learning rate of model training (default=1e-4)
            batch_size (int): batch size (default=1024)
    """

    def __init__(self, config):
        super(TBNet, self).__init__()

        self._parse_config(config)
        self.matmul = C.matmul
        self.sigmoid = P.Sigmoid()
        self.embedding_initializer = "normal"

        self.entity_emb_matrix = EmbeddingMatrix(int(self.num_entity),
                                                 self.dim,
                                                 embedding_table=self.embedding_initializer)
        self.relation_emb_matrix = EmbeddingMatrix(int(self.num_relation),
                                                   embedding_size=(self.dim, self.dim),
                                                   embedding_table=self.embedding_initializer)

        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(3)
        self.abs = P.Abs()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()

        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.softmax = nn.Softmax()

    def _parse_config(self, config):
        """Read model hyper-parameters from the config object."""

        self.num_entity = config.num_entity
        self.num_relation = config.num_relation
        self.dim = config.embedding_dim
        self.kge_weight = config.kge_weight
        self.node_weight = config.node_weight
        self.l2_weight = config.l2_weight
        self.lr = config.lr
        self.batch_size = config.batch_size

    def construct(self, items, relation1, mid_entity, relation2, hist_item):
        """
        TB-Net main computation process.

        Args:
            items (Tensor): rated item IDs, int Tensor of shape [batch size, ].
            relation1 (Tensor): relation1 IDs, int Tensor of shape [batch size, per_item_num_paths].
            mid_entity (Tensor): middle entity IDs, int Tensor of shape [batch size, per_item_num_paths].
            relation2 (Tensor): relation2 IDs, int Tensor of shape [batch size, per_item_num_paths].
            hist_item (Tensor): historical item IDs, int Tensor of shape [batch size, per_item_num_paths].

        Returns:
            scores (Tensor): model prediction scores, float Tensor of shape [batch size, ].
            probs_exp (Tensor): path probability/importance, float Tensor of shape [batch size, per_item_num_paths].
            item_embeddings (Tensor): rated item embeddings, float Tensor of shape [batch size, dim].
            relation1_emb (Tensor): relation1 embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim, dim].
            mid_entity_emb (Tensor): middle entity embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim].
            relation2_emb (Tensor): relation2 embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim, dim].
            hist_item_emb (Tensor): historical item embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim].
        """
        item_embeddings = self.entity_emb_matrix(items)

        relation1_emb = self.relation_emb_matrix(relation1)
        mid_entity_emb = self.entity_emb_matrix(mid_entity)
        relation2_emb = self.relation_emb_matrix(relation2)
        hist_item_emb = self.entity_emb_matrix(hist_item)

        response, probs_exp = self._key_pathing(item_embeddings,
                                                relation1_emb,
                                                mid_entity_emb,
                                                relation2_emb,
                                                hist_item_emb)

        scores = P.Squeeze()(self._predict(item_embeddings, response))

        return scores, probs_exp, item_embeddings, relation1_emb, mid_entity_emb, relation2_emb, hist_item_emb

    def _key_pathing(self, item_embeddings, relation1_emb, mid_entity_emb, relation2_emb, hist_item_emb):
        """
        Compute the response and path probabilities from the item and entity embeddings.
        Path structure: (rated item, relation1, entity, relation2, historical item).

        Args:
            item_embeddings (Tensor): rated item embeddings, float Tensor of shape [batch size, dim].
            relation1_emb (Tensor): relation1 embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim, dim].
            mid_entity_emb (Tensor): middle entity embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim].
            relation2_emb (Tensor): relation2 embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim, dim].
            hist_item_emb (Tensor): historical item embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim].

        Returns:
            response (Tensor): user's response towards the middle entities, float Tensor of shape [batch size, dim].
            probs_exp (Tensor): path probability/importance, float Tensor of shape [batch size, per_item_num_paths].
        """

        hist_item_e_4d = self.expand_dims(hist_item_emb, 3)
        mul_r2_hist = self.squeeze(self.matmul(relation2_emb, hist_item_e_4d))
        # path_right shape: [batch size, per_item_num_paths, dim]
        path_right = self.abs(mul_r2_hist + self.reduce_sum(relation2_emb, 2))

        item_emb_3d = self.expand_dims(item_embeddings, 2)
        mul_r1_item = self.squeeze(self.matmul(relation1_emb, self.expand_dims(item_emb_3d, 1)))
        path_left = self.abs(mul_r1_item + self.reduce_sum(relation1_emb, 2))
        # path_left shape: [batch size, dim, per_item_num_paths]
        path_left = self.transpose(path_left, (0, 2, 1))

        probs = self.reduce_sum(self.matmul(path_right, path_left), 2)
        # probs_exp shape: [batch size, per_item_num_paths]
        probs_exp = self.softmax(probs)

        probs_3d = self.expand_dims(probs_exp, 2)
        # response shape: [batch size, dim]
        response = self.reduce_sum(mid_entity_emb * probs_3d, 1)

        return response, probs_exp
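
    # Shape walk-through of _key_pathing (B = batch size, P = per_item_num_paths, D = dim):
    #   relation2_emb [B, P, D, D] @ hist_item_emb [B, P, D, 1] -> mul_r2_hist [B, P, D]
    #   path_right = |mul_r2_hist + sum(relation2_emb, axis=2)|              -> [B, P, D]
    #   path_left  = |mul_r1_item + sum(relation1_emb, axis=2)|, transposed  -> [B, D, P]
    #   probs      = sum(path_right @ path_left, axis=2)                     -> [B, P]
    # The softmax over the P paths gives probs_exp, which weights the middle-entity
    # embeddings into the response and doubles as the per-path explanation score.
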
    def _predict(self, item_embeddings, response):
        """Score an item by the inner product of its embedding and the user response."""
        scores = self.reduce_sum(item_embeddings * response, 1)

        return scores

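
# A minimal forward-pass sketch (illustration only, not part of this file): with
# `cfg` parsed from the dataset's config.json and P the hypothetical paths-per-record,
#
#     import numpy as np
#     from mindspore import Tensor
#     net = TBNet(cfg)
#     items = Tensor(np.zeros((cfg.batch_size,), np.int32))
#     path_ids = Tensor(np.zeros((cfg.batch_size, P), np.int32))
#     scores, path_probs, *embs = net(items, path_ids, path_ids, path_ids, path_ids)
#
# `scores` has shape [batch_size, ] and `path_probs` [batch_size, P].
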
class NetWithLossClass(nn.Cell):
    """NetWithLossClass definition."""

    def __init__(self, network, config):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = P.SigmoidCrossEntropyWithLogits()
        self.matmul = C.matmul
        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(3)
        self.abs = P.Abs()
        self.maximum = P.Maximum()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.sigmoid = P.Sigmoid()

        self.kge_weight = config.kge_weight
        self.node_weight = config.node_weight
        self.l2_weight = config.l2_weight
        self.batch_size = config.batch_size
        self.dim = config.embedding_dim

        self.embedding_initializer = "normal"

    def construct(self, items, relation1, mid_entity, relation2, hist_item, labels):
        """
        Compute the training loss for a batch.

        Args:
            items (Tensor): rated item IDs, int Tensor of shape [batch size, ].
            relation1 (Tensor): relation1 IDs, int Tensor of shape [batch size, per_item_num_paths].
            mid_entity (Tensor): middle entity IDs, int Tensor of shape [batch size, per_item_num_paths].
            relation2 (Tensor): relation2 IDs, int Tensor of shape [batch size, per_item_num_paths].
            hist_item (Tensor): historical item IDs, int Tensor of shape [batch size, per_item_num_paths].
            labels (Tensor): labels of the rated item records, int Tensor of shape [batch size, ].

        Returns:
            loss (Tensor): scalar loss value.
        """
        scores, _, item_embeddings, relation1_emb, mid_entity_emb, relation2_emb, hist_item_emb = \
            self.network(items, relation1, mid_entity, relation2, hist_item)
        loss = self._loss_fun(item_embeddings, relation1_emb, mid_entity_emb,
                              relation2_emb, hist_item_emb, scores, labels)

        return loss

    def _loss_fun(self, item_embeddings, relation1_emb, mid_entity_emb, relation2_emb, hist_item_emb, scores, labels):
        """
        Loss function definition.

        Args:
            item_embeddings (Tensor): rated item embeddings, float Tensor of shape [batch size, dim].
            relation1_emb (Tensor): relation1 embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim, dim].
            mid_entity_emb (Tensor): middle entity embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim].
            relation2_emb (Tensor): relation2 embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim, dim].
            hist_item_emb (Tensor): historical item embeddings,
                float Tensor of shape [batch size, per_item_num_paths, dim].
            scores (Tensor): model prediction scores, float Tensor of shape [batch size, ].
            labels (Tensor): labels of the rated item records, int Tensor of shape [batch size, ].

        Returns:
            loss: the sum of four parts:
                pred_loss: cross entropy between the model prediction scores and the labels
                transr_loss: TransR KG Embedding loss
                node_loss: node matching loss
                l2_loss: L2 regularization loss
        """
        pred_loss = self.reduce_mean(self.loss(scores, labels))

        item_emb_3d = self.expand_dims(item_embeddings, 2)
        item_emb_4d = self.expand_dims(item_emb_3d, 1)

        mul_r1_item = self.squeeze(self.matmul(relation1_emb, item_emb_4d))

        hist_item_e_4d = self.expand_dims(hist_item_emb, 3)
        mul_r2_hist = self.squeeze(self.matmul(relation2_emb, hist_item_e_4d))

        relation1_3d = self.reduce_sum(relation1_emb, 2)
        relation2_3d = self.reduce_sum(relation2_emb, 2)

        path_left = self.reduce_sum(self.abs(mul_r1_item + relation1_3d), 2)
        path_right = self.reduce_sum(self.abs(mul_r2_hist + relation2_3d), 2)

        transr_loss = self.reduce_sum(self.maximum(self.abs(path_left - path_right), 0))
        transr_loss = self.reduce_mean(self.sigmoid(transr_loss))

        mid_entity_emb_4d = self.expand_dims(mid_entity_emb, 3)
        mul_r2_mid = self.squeeze(self.matmul(relation2_emb, mid_entity_emb_4d))
        path_r2_mid = self.abs(mul_r2_mid + relation2_3d)

        node_loss = self.reduce_sum(self.maximum(mul_r2_hist - path_r2_mid, 0))
        node_loss = self.reduce_mean(self.sigmoid(node_loss))

        l2_loss = self.reduce_mean(self.reduce_sum(relation1_emb * relation1_emb))
        l2_loss += self.reduce_mean(self.reduce_sum(mid_entity_emb * mid_entity_emb))
        l2_loss += self.reduce_mean(self.reduce_sum(relation2_emb * relation2_emb))
        l2_loss += self.reduce_mean(self.reduce_sum(hist_item_emb * hist_item_emb))

        transr_loss = self.kge_weight * transr_loss
        node_loss = self.node_weight * node_loss

        l2_loss = self.l2_weight * l2_loss

        loss = pred_loss + transr_loss + node_loss + l2_loss

        return loss

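
# The resulting objective is
#     loss = pred_loss + kge_weight * transr_loss + node_weight * node_loss + l2_weight * l2_loss
# so kge_weight, node_weight and l2_weight from config.json scale the auxiliary
# path-modelling and regularization terms against the prediction loss.
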
class TrainStepWrap(nn.Cell):
    """TrainStepWrap definition."""

    def __init__(self, network, lr, sens=1):
        super(TrainStepWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_train()
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.lr = lr
        self.optimizer = nn.Adam(self.weights,
                                 learning_rate=self.lr,
                                 beta1=0.9,
                                 beta2=0.999,
                                 eps=1e-8,
                                 loss_scale=sens)

        self.hyper_map = C.HyperMap()
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)

    def construct(self, items, relation1, mid_entity, relation2, hist_item, labels):
        """
        Run one training step: forward pass, backward pass and parameter update.

        Args:
            items (Tensor): rated item IDs, int Tensor of shape [batch size, ].
            relation1 (Tensor): relation1 IDs, int Tensor of shape [batch size, per_item_num_paths].
            mid_entity (Tensor): middle entity IDs, int Tensor of shape [batch size, per_item_num_paths].
            relation2 (Tensor): relation2 IDs, int Tensor of shape [batch size, per_item_num_paths].
            hist_item (Tensor): historical item IDs, int Tensor of shape [batch size, per_item_num_paths].
            labels (Tensor): labels of the rated item records, int Tensor of shape [batch size, ].

        Returns:
            loss (Tensor): the loss value, returned after the gradient update is applied.
        """
        weights = self.weights
        loss = self.network(items, relation1, mid_entity, relation2, hist_item, labels)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(items, relation1, mid_entity, relation2, hist_item, labels, sens)

        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)

        return F.depend(loss, self.optimizer(grads))
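
# Note: `sens` is used both as the gradient sensitivity seeded into GradOperation
# and as the optimizer's `loss_scale`, so the scaling cancels out inside Adam;
# with the default sens=1 both are effectively no-ops.
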
class PredictWithSigmoid(nn.Cell):
    """Predict method."""

    def __init__(self, network):
        super(PredictWithSigmoid, self).__init__(auto_prefix=False)
        self.network = network
        self.sigmoid = P.Sigmoid()

    def construct(self, items, relation1, mid_entity, relation2, hist_item, labels):
        """
        Predict with sigmoid definition.

        Args:
            items (Tensor): rated item IDs, int Tensor of shape [batch size, ].
            relation1 (Tensor): relation1 IDs, int Tensor of shape [batch size, per_item_num_paths].
            mid_entity (Tensor): middle entity IDs, int Tensor of shape [batch size, per_item_num_paths].
            relation2 (Tensor): relation2 IDs, int Tensor of shape [batch size, per_item_num_paths].
            hist_item (Tensor): historical item IDs, int Tensor of shape [batch size, per_item_num_paths].
            labels (Tensor): labels of the rated item records, int Tensor of shape [batch size, ].

        Returns:
            scores (Tensor): model prediction scores, float Tensor of shape [batch size, ].
            pred_probs (Tensor): prediction probabilities, float Tensor of shape [batch size, ].
            labels (Tensor): labels of the rated item records, int Tensor of shape [batch size, ].
            probs_exp (Tensor): path probability/importance, float Tensor of shape [batch size, per_item_num_paths].
        """

        scores, probs_exp, _, _, _, _, _ = self.network(items, relation1, mid_entity, relation2, hist_item)
        pred_probs = self.sigmoid(scores)

        return scores, pred_probs, labels, probs_exp
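
# Typical wiring, mirrored by the training script below: the TBNet backbone is
# shared, so the loss/train wrappers and the prediction wrapper update and read
# the same parameters:
#
#     network = TBNet(net_config)
#     train_net = TrainStepWrap(NetWithLossClass(network, net_config), net_config.lr)
#     eval_net = PredictWithSigmoid(network)
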

@@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TB-Net training."""

import os
import argparse

import numpy as np
from mindspore import context, Model, Tensor
from mindspore.train.serialization import save_checkpoint
from mindspore.train.callback import Callback, TimeMonitor

from src import tbnet, config, metrics, dataset


class MyLossMonitor(Callback):
    """Loss monitor definition."""

    def epoch_end(self, run_context):
        """Print the loss at the end of each epoch."""
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())
        print('loss:' + str(loss))
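
# Because it reads `net_outputs` in `epoch_end`, this monitor reports the loss of
# the last step of each epoch; accumulating per-step losses in a `step_end` hook
# would give a smoother per-epoch average (not done here).
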
def get_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Train TBNet.')

    parser.add_argument(
        '--dataset',
        type=str,
        required=False,
        default='steam',
        help="'steam' dataset is supported currently"
    )

    parser.add_argument(
        '--train_csv',
        type=str,
        required=False,
        default='train.csv',
        help="the train csv datafile inside the dataset folder"
    )

    parser.add_argument(
        '--test_csv',
        type=str,
        required=False,
        default='test.csv',
        help="the test csv datafile inside the dataset folder"
    )

    parser.add_argument(
        '--device_id',
        type=int,
        required=False,
        default=0,
        help="device id"
    )

    parser.add_argument(
        '--epochs',
        type=int,
        required=False,
        default=20,
        help="number of training epochs"
    )

    parser.add_argument(
        '--device_target',
        type=str,
        required=False,
        default='GPU',
        choices=['GPU'],
        help="run code on GPU"
    )

    parser.add_argument(
        '--run_mode',
        type=str,
        required=False,
        default='graph',
        choices=['graph', 'pynative'],
        help="run code in GRAPH mode or PYNATIVE mode"
    )

    return parser.parse_args()
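
# Example invocation (all flags are optional; the defaults above are shown):
#     python train.py --dataset steam --train_csv train.csv --test_csv test.csv \
#         --device_id 0 --epochs 20 --device_target GPU --run_mode graph
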
def train_tbnet():
    """Training process."""
    args = get_args()

    home = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(home, 'data', args.dataset, 'config.json')
    train_csv_path = os.path.join(home, 'data', args.dataset, args.train_csv)
    test_csv_path = os.path.join(home, 'data', args.dataset, args.test_csv)
    ckpt_path = os.path.join(home, 'checkpoints')

    context.set_context(device_id=args.device_id)
    if args.run_mode == 'graph':
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    else:
        context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)

    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_path)

    print(f"creating dataset from {train_csv_path}...")
    net_config = config.TBNetConfig(config_path)
    train_ds = dataset.create(train_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
    test_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)

    print("creating TBNet for training...")
    network = tbnet.TBNet(net_config)
    loss_net = tbnet.NetWithLossClass(network, net_config)
    train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
    train_net.set_train()
    eval_net = tbnet.PredictWithSigmoid(network)
    time_callback = TimeMonitor(data_size=train_ds.get_dataset_size())
    loss_callback = MyLossMonitor()
    model = Model(network=train_net, eval_network=eval_net, metrics={'auc': metrics.AUC(), 'acc': metrics.ACC()})
    print("training...")
    for i in range(args.epochs):
        print(f'===================== Epoch {i} =====================')
        model.train(epoch=1, train_dataset=train_ds, callbacks=[time_callback, loss_callback], dataset_sink_mode=False)
        train_out = model.eval(train_ds, dataset_sink_mode=False)
        test_out = model.eval(test_ds, dataset_sink_mode=False)
        print(f'Train AUC:{train_out["auc"]} ACC:{train_out["acc"]} Test AUC:{test_out["auc"]} ACC:{test_out["acc"]}')

        save_checkpoint(network, os.path.join(ckpt_path, f'tbnet_epoch{i}.ckpt'))


if __name__ == '__main__':
    train_tbnet()
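
# Each epoch writes its weights to ./checkpoints/tbnet_epoch{i}.ckpt and prints the
# train/test AUC and ACC, so the best-performing epoch can be picked afterwards.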