add bgcf gpu

牛黄解毒片 2021-01-27 22:18:16 +08:00 committed by panfengfeng
parent 54b8d53780
commit 1bff5f043d
7 changed files with 344 additions and 126 deletions

View File

@ -4,24 +4,26 @@
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Description of random situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!--TOC -->
# [Bayesian Graph Collaborative Filtering](#contents)
Bayesian Graph Collaborative Filtering (BGCF) was proposed in 2020 by Sun J, Guo W, Zhang D, et al. By naturally incorporating the
uncertainty in the user-item interaction graph, BGCF shows excellent performance on the Amazon recommendation dataset. This is an example of
training BGCF with the Amazon-Beauty dataset in MindSpore. More importantly, this is the first open-source version of BGCF.
@ -33,36 +35,40 @@ Specifically, BGCF contains two main modules. The first is sampling, which produces
aggregate the neighbors sampled from nodes, consisting of a mean aggregator and an attention aggregator.
# [Dataset](#contents)
Note that you can run the scripts with the dataset mentioned in the original paper or a dataset widely used in the relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
Statistics of the dataset used are summarized below:
|                 | Amazon-Beauty  |
| --------------- | -------------- |
| Task            | Recommendation |
| # Users         | 7068 (1 graph) |
| # Items         | 3570           |
| # Interactions  | 79506          |
| # Training Data | 60818          |
| # Test Data     | 18688          |
| Density         | 0.315%         |
- Data Preparation
- Place the dataset in any path you want; the folder should include the files below (we use the Amazon-Beauty dataset as an example):
```text
.
└─data
├─ratings_Beauty.csv
```
- Generate the dataset in MindRecord format for Amazon-Beauty.
```shell
cd ./scripts
# SRC_PATH is the path of the dataset file you downloaded.
sh run_process_data_ascend.sh [SRC_PATH]
```
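The raw ratings file is a plain CSV of user-item interactions; a quick way to peek at it before conversion (a sketch — the (user, item, rating, timestamp) column order is an assumption based on the standard Amazon ratings dump, so verify against your download):
```python
import csv

# Print the first five rows of the raw ratings file.
with open("data/ratings_Beauty.csv", newline="") as f:
    for i, row in enumerate(csv.reader(f)):
        print(row)
        if i == 4:
            break
```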
# [Features](#contents)
## Mixed Precision
@ -71,12 +77,12 @@ To utilize the strong computation power of Ascend chip, and accelerate the trai
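As a minimal sketch of the general mechanism (not this repo's exact wiring), MindSpore can enable mixed precision through the `Model` wrapper's `amp_level` argument; the tiny network, loss, and optimizer below are illustrative stand-ins:
```python
import mindspore.nn as nn
from mindspore.train.model import Model

net = nn.Dense(8, 1)   # stand-in network, not BGCF
loss = nn.MSELoss()
opt = nn.Adam(net.trainable_params(), learning_rate=0.001)

# amp_level="O3" runs the forward network in float16; how BGCF's training
# cell configures precision may differ (assumption for illustration).
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O3")
```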
# [Environment Requirements](#contents)
- Hardware (Ascend/GPU)
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
@ -84,26 +90,38 @@ After installing MindSpore via the official website and the dataset is correctly gen
- running on Ascend
```shell
# run training example with Amazon-Beauty dataset
sh run_train_ascend.sh
# run evaluation example with Amazon-Beauty dataset
sh run_eval_ascend.sh
```
- running on GPU
```shell
# run training example with Amazon-Beauty dataset
sh run_train_gpu.sh 0 dataset_path
# run evaluation example with Amazon-Beauty dataset
sh run_eval_gpu.sh 0 dataset_path
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```shell
.
└─bgcf
├─README.md
├─scripts
| ├─run_eval_ascend.sh # Launch evaluation on Ascend
| ├─run_eval_gpu.sh # Launch evaluation on GPU
| ├─run_process_data_ascend.sh # Generate dataset in MindRecord format
| ├─run_train_ascend.sh # Launch training on Ascend
| └─run_train_gpu.sh # Launch training on GPU
|
├─src
| ├─bgcf.py # BGCF model
@ -122,31 +140,33 @@ After installing MindSpore via the official website and the dataset is correctly gen
Parameters for both training and evaluation can be set in config.py.
- config for BGCF dataset
```python
"learning_rate": 0.001, # Learning rate
"num_epoch": 600, # Epoch sizes for training
"num_neg": 10, # Negative sampling rate
"raw_neighs": 40, # Num of sampling neighbors in raw graph
"gnew_neighs": 20, # Num of sampling neighbors in sample graph
"input_dim": 64, # User and item embedding dimension
"l2": 0.03 # l2 coefficient
"neighbor_dropout": [0.0, 0.2, 0.3]# Dropout ratio for different aggregation layer
"learning_rate": 0.001, # Learning rate
"num_epoch": 600, # Epoch sizes for training
"num_neg": 10, # Negative sampling rate
"raw_neighs": 40, # Num of sampling neighbors in raw graph
"gnew_neighs": 20, # Num of sampling neighbors in sample graph
"input_dim": 64, # User and item embedding dimension
"l2": 0.03 # l2 coefficient
"neighbor_dropout": [0.0, 0.2, 0.3] # Dropout ratio for different aggregation layer
```
See config.py for more configuration options.
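As a rough sketch of how values like these feed the training setup (illustrative only; mapping `l2` to Adam's `weight_decay` is an assumption — see train.py for the actual wiring):
```python
import mindspore.nn as nn

def build_optimizer(network, learning_rate=0.001, l2=0.03):
    # Adam over the trainable parameters, with l2 modeled as weight decay
    # (assumption; the repo may apply l2 inside its loss instead).
    return nn.Adam(network.trainable_params(),
                   learning_rate=learning_rate,
                   weight_decay=l2)
```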
## [Training Process](#contents)
### Training
- running on Ascend
```shell
sh run_train_ascend.sh
```
Training results will be stored in the scripts path, in a folder whose name begins with "train". You can find results like the following in the log.
```text
Epoch 001 iter 12 loss 34696.242
Epoch 002 iter 12 loss 34275.508
@ -160,70 +180,105 @@ Parameters for both training and evaluation can be set in config.py.
Epoch 600 iter 12 loss 3632.4585
...
```
- running on GPU
```shell
sh run_train_gpu.sh 0 dataset_path
```
Training results will be stored in the scripts path, in a folder whose name begins with "train". You can find results like the following in the log.
```text
Epoch 001 iter 12 loss 34696.242
Epoch 002 iter 12 loss 34275.508
Epoch 003 iter 12 loss 30620.635
Epoch 004 iter 12 loss 21628.908
```
## [Evaluation Process](#contents)
### Evaluation
- Evaluation on Ascend
```shell
sh run_eval_ascend.sh
```
Evaluation results will be stored in the scripts path, in a folder whose name begins with "eval". You can find results like the following in the log.
```text
epoch:020, recall_@10:0.07345, recall_@20:0.11193, ndcg_@10:0.05293, ndcg_@20:0.06613,
sedp_@10:0.01393, sedp_@20:0.01126, nov_@10:6.95106, nov_@20:7.22280
epoch:040, recall_@10:0.07410, recall_@20:0.11537, ndcg_@10:0.05387, ndcg_@20:0.06801,
sedp_@10:0.01445, sedp_@20:0.01168, nov_@10:7.34799, nov_@20:7.58883
epoch:060, recall_@10:0.07654, recall_@20:0.11987, ndcg_@10:0.05530, ndcg_@20:0.07015,
sedp_@10:0.01474, sedp_@20:0.01206, nov_@10:7.46553, nov_@20:7.69436
...
epoch:560, recall_@10:0.09825, recall_@20:0.14877, ndcg_@10:0.07176, ndcg_@20:0.08883,
sedp_@10:0.01882, sedp_@20:0.01501, nov_@10:7.58045, nov_@20:7.79586
epoch:580, recall_@10:0.09917, recall_@20:0.14970, ndcg_@10:0.07337, ndcg_@20:0.09037,
sedp_@10:0.01896, sedp_@20:0.01504, nov_@10:7.57995, nov_@20:7.79439
epoch:600, recall_@10:0.09926, recall_@20:0.15080, ndcg_@10:0.07283, ndcg_@20:0.09016,
sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038
...
```
- Evaluation on GPU
```shell
sh run_eval_gpu.sh 0 dataset_path
```
Evaluation results will be stored in the scripts path, in a folder whose name begins with "eval". You can find results like the following in the log.
```text
epoch:680, recall_@10:0.10383, recall_@20:0.15524, ndcg_@10:0.07503, ndcg_@20:0.09249,
sedp_@10:0.01926, sedp_@20:0.01547, nov_@10:7.60851, nov_@20:7.81969
```
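For reference, the `recall_@K` and `ndcg_@K` figures in these logs follow the standard top-K ranking definitions; below is a minimal NumPy sketch, not the actual implementation in src/metrics.py:
```python
import numpy as np

def recall_ndcg_at_k(ranked_items, ground_truth, k):
    """Recall@K and NDCG@K for one user's ranked recommendation list."""
    truth = set(ground_truth)
    hits = [1.0 if item in truth else 0.0 for item in ranked_items[:k]]
    recall = sum(hits) / len(truth)
    dcg = sum(h / np.log2(i + 2) for i, h in enumerate(hits))
    idcg = sum(1.0 / np.log2(i + 2) for i in range(min(len(truth), k)))
    ndcg = dcg / idcg if idcg > 0 else 0.0
    return recall, ndcg
```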
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameter | BGCF Ascend | BGCF GPU |
| ------------------------------ | ------------------------------------------ | ------------------------------------------ |
| Model Version                  | BGCF                                        | BGCF                                        |
| Resource | Ascend 910 | Tesla V100-PCIE |
| uploaded Date | 09/23/2020(month/day/year) | 01/27/2021(month/day/year) |
| MindSpore Version | 1.0.0 | 1.1.0 |
| Dataset | Amazon-Beauty | Amazon-Beauty |
| Training Parameter | epoch=600,steps=12,batch_size=5000,lr=0.001| epoch=680,steps=12,batch_size=5000,lr=0.001|
| Optimizer | Adam | Adam |
| Loss Function | BPR loss | BPR loss |
| Training Cost | 25min | 60min |
| Scripts | [bgcf script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) | [bgcf script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) |
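The BPR loss row above refers to the Bayesian Personalized Ranking objective, which pushes each observed user-item pair to score higher than a sampled negative pair. A minimal NumPy sketch of the standard form (the repo's full training loss may add regularization such as the `l2` and `dist_reg` coefficients from config.py):
```python
import numpy as np

def bpr_loss(pos_scores, neg_scores):
    """Standard BPR loss: mean of -log(sigmoid(pos - neg)) over sampled pairs."""
    x = np.asarray(pos_scores) - np.asarray(neg_scores)
    return np.mean(np.logaddexp(0.0, -x))  # log(1 + exp(-x)) == -log(sigmoid(x))
```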
### Inference Performance
| Parameter | BGCF Ascend | BGCF GPU |
| ------------------------------ | ---------------------------- | ---------------------------- |
| Model Version                 | BGCF                         | BGCF                         |
| Resource | Ascend 910 | Tesla V100-PCIE |
| uploaded Date | 09/23/2020(month/day/year) | 01/28/2021(month/day/year) |
| MindSpore Version | 1.0.0 | Master(4b3e53b4) |
| Dataset | Amazon-Beauty | Amazon-Beauty |
| Batch_size | 5000 | 5000 |
| Output | probability | probability |
| Recall@20 | 0.1534 | 0.15524 |
| NDCG@20 | 0.0912 | 0.09249 |
# [Description of random situation](#contents)
The BGCF model contains many dropout operations; if you want to disable dropout, set neighbor_dropout to [0.0, 0.0, 0.0] in src/config.py.
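That is, in src/config.py:
```python
"neighbor_dropout": [0.0, 0.0, 0.0],  # disable dropout in every aggregation layer
```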
# [ModelZoo Homepage](#contents)
Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -40,39 +40,46 @@ BGCF contains two main modules. The first is sampling, which generates node copies based on
Statistics of the dataset used are summarized below:
|                 | Amazon-Beauty  |
| --------------- | -------------- |
| Task            | Recommendation |
| # Users         | 7068 (1 graph) |
| # Items         | 3570           |
| # Interactions  | 79506          |
| # Training Data | 60818          |
| # Test Data     | 18688          |
| Density         | 0.315%         |
- Data Preparation
- Place the dataset in any path you want; the folder should include the files below (we use the Amazon-Beauty dataset as an example):
```text
.
└─data
├─ratings_Beauty.csv
```
- Generate the dataset in MindRecord format for Amazon-Beauty.
```shell
cd ./scripts
# SRC_PATH is the path of the dataset file you downloaded.
sh run_process_data_ascend.sh [SRC_PATH]
```
- Launch
```shell
# Generate the dataset in MindRecord format for Amazon-Beauty
sh ./run_process_data_ascend.sh ./data
```
# Features
## Mixed Precision
@ -81,7 +88,7 @@ BGCF contains two main modules. The first is sampling, which generates node copies based on
# Environment Requirements
- Hardware (Ascend/GPU)
- Framework
- [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
@ -104,6 +111,18 @@ BGCF contains two main modules. The first is sampling, which generates node copies based on
```
- Running on GPU
```shell
# Run the training example with the Amazon-Beauty dataset
sh run_train_gpu.sh 0 dataset_path
# Run the evaluation example with the Amazon-Beauty dataset
sh run_eval_gpu.sh 0 dataset_path
```
# Script Description
## Script and Sample Code
@ -113,9 +132,11 @@ BGCF contains two main modules. The first is sampling, which generates node copies based on
└─bgcf
├─README.md
├─scripts
| ├─run_eval_ascend.sh # Launch evaluation on Ascend
| ├─run_eval_gpu.sh # Launch evaluation on GPU
| ├─run_process_data_ascend.sh # Generate dataset in MindRecord format
| ├─run_train_ascend.sh # Launch training on Ascend
| └─run_train_gpu.sh # Launch training on GPU
|
├─src
| ├─bgcf.py # BGCF model
@ -178,7 +199,25 @@ BGCF contains two main modules. The first is sampling, which generates node copies based on
Epoch 598 iter 12 loss 3640.7612
Epoch 599 iter 12 loss 3654.9087
Epoch 600 iter 12 loss 3632.4585
...
```
- Running on GPU
```shell
sh run_train_gpu.sh 0 dataset_path
```
Training results will be stored in the scripts path, in a folder whose name begins with "train". You can find the results in the log, as shown below.
```text
Epoch 001 iter 12 loss 34696.242
Epoch 002 iter 12 loss 34275.508
Epoch 003 iter 12 loss 30620.635
Epoch 004 iter 12 loss 21628.908
```
@ -212,7 +251,23 @@ BGCF contains two main modules. The first is sampling, which generates node copies based on
sedp_@10:0.01896, sedp_@20:0.01504, nov_@10:7.57995, nov_@20:7.79439
epoch:600, recall_@10:0.09926, recall_@20:0.15080, ndcg_@10:0.07283, ndcg_@20:0.09016,
sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038
...
```
- Evaluation on GPU
```shell
sh run_eval_gpu.sh 0 dataset_path
```
Evaluation results will be stored in the scripts path, in a folder whose name begins with "eval". You can find the results in the log, as shown below.
```text
epoch:680, recall_@10:0.10383, recall_@20:0.15524, ndcg_@10:0.07503, ndcg_@20:0.09249,
sedp_@10:0.01926, sedp_@20:0.01547, nov_@10:7.60851, nov_@20:7.81969
```
@ -220,19 +275,19 @@ BGCF contains two main modules. The first is sampling, which generates node copies based on
## Performance
| Parameter           | BGCF Ascend                                 | BGCF GPU                                    |
| ------------------- | ------------------------------------------- | ------------------------------------------- |
| Resource            | Ascend 910                                  | Tesla V100-PCIE                             |
| Uploaded Date       | 09/23/2020 (month/day/year)                 | 01/28/2021 (month/day/year)                 |
| MindSpore Version   | 1.0.0                                       | Master(4b3e53b4)                            |
| Dataset             | Amazon-Beauty                               | Amazon-Beauty                               |
| Training Parameters | epoch=600,steps=12,batch_size=5000,lr=0.001 | epoch=680,steps=12,batch_size=5000,lr=0.001 |
| Optimizer           | Adam                                        | Adam                                        |
| Loss Function       | BPR loss                                    | BPR loss                                    |
| Recall@20           | 0.1534                                      | 0.15524                                     |
| NDCG@20             | 0.0912                                      | 0.09249                                     |
| Training Cost       | 25min                                       | 60min                                       |
| Script              | [bgcf script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) | [bgcf script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) |
# Description of Random Situation

View File

@ -19,6 +19,7 @@ import datetime
import mindspore.context as context
from mindspore.train.serialization import load_checkpoint
from mindspore.common import set_seed
from src.bgcf import BGCF
from src.utils import BGCFLogger
@ -27,6 +28,7 @@ from src.metrics import BGCFEvaluate
from src.callback import ForwardBGCF, TestBGCF
from src.dataset import TestGraphDataset, load_graph
set_seed(1)
def evaluation():
"""evaluation"""
@ -34,7 +36,8 @@ def evaluation():
num_item = train_graph.graph_info()["node_num"][1]
eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)
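# On Ascend, evaluate a checkpoint every eval_interval epochs; on GPU, only
# the final epoch's checkpoint is evaluated.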
for _epoch in range(parser.eval_interval, parser.num_epoch+1, parser.eval_interval) \
if parser.device_target == "Ascend" else range(parser.num_epoch, parser.num_epoch+1):
bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
parser.embedded_dimension,
parser.activation,
@ -79,9 +82,10 @@ def evaluation():
if __name__ == "__main__":
parser = parser_args()
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=int(parser.device))
device_target=parser.device_target,
save_graphs=False)
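# GPU runs pick the device via CUDA_VISIBLE_DEVICES in the launch script;
# only Ascend needs an explicit device_id here.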
if parser.device_target == "Ascend":
context.set_context(device_id=int(parser.device))
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,

View File

@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
if [ $# -lt 2 ]
then
echo "Usage: sh run_eval_gpu.sh [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi
export DEVICE_NUM=1
DATASET_PATH=$2
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation"
export CUDA_VISIBLE_DEVICES="$1"
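# Run evaluation in the background; output is written to ./eval/log.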
python eval.py --datapath=$DATASET_PATH --ckptpath=../ckpts \
--device_target='GPU' --num_epoch=680 \
--dist_reg=0 > log 2>&1 &
cd ..

View File

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: sh run_train_gpu.sh [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi
export DEVICE_NUM=1
DATASET_PATH=$2
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
if [ -d "ckpts" ];
then
rm -rf ./ckpts
fi
mkdir ./ckpts
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
env > env.log
echo "start training"
export CUDA_VISIBLE_DEVICES="$1"
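# Run training in the background; output is written to ./train/log.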
python train.py --datapath=$DATASET_PATH --ckptpath=../ckpts \
--device_target='GPU' --num_epoch=680 \
--dist_reg=0 > log 2>&1 &
cd ..

View File

@ -49,4 +49,5 @@ def parser_args():
parser.add_argument("-emb", "--embedded_dimension", type=int, default=64, help="output embedding dim")
parser.add_argument('--dist_reg', type=float, default=0.003, help="distance loss coefficient")
parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target')
return parser.parse_args()

View File

@ -21,6 +21,7 @@ from mindspore import Tensor
import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import save_checkpoint
from mindspore.common import set_seed
from src.bgcf import BGCF
from src.config import parser_args
@ -28,6 +29,7 @@ from src.utils import convert_item_id
from src.callback import TrainBGCF
from src.dataset import load_graph, create_dataset
set_seed(1)
def train():
"""Train"""
@ -102,10 +104,13 @@ def train():
if __name__ == "__main__":
parser = parser_args()
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=int(parser.device))
device_target=parser.device_target,
save_graphs=False)
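# As in eval.py, only Ascend needs an explicit device_id; GPU device
# selection comes from CUDA_VISIBLE_DEVICES in the launch script.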
if parser.device_target == "Ascend":
context.set_context(device_id=int(parser.device))
train_graph, _, sampled_graph_list = load_graph(parser.datapath)
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,