forked from mindspore-Ecosystem/mindspore
add eval script
This commit is contained in:
parent
daf6739b22
commit
36f9f59fd0
|
@ -12,6 +12,8 @@
|
|||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Description of random situation](#description-of-random-situation)
|
||||
|
@ -88,6 +90,9 @@ After installing MindSpore via the official website and Dataset is correctly gen
|
|||
```
|
||||
# run training example with Amazon-Beauty dataset
|
||||
sh run_train_ascend.sh
|
||||
|
||||
# run evaluation example with Amazon-Beauty dataset
|
||||
sh run_eval_ascend.sh
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
@ -99,6 +104,7 @@ After installing MindSpore via the official website and Dataset is correctly gen
|
|||
└─bgcf
|
||||
├─README.md
|
||||
├─scripts
|
||||
| ├─run_eval_ascend.sh # Launch evaluation
|
||||
| ├─run_process_data_ascend.sh # Generate dataset in mindrecord format
|
||||
| └─run_train_ascend.sh # Launch training
|
||||
|
|
||||
|
@ -110,6 +116,7 @@ After installing MindSpore via the official website and Dataset is correctly gen
|
|||
| ├─metrics.py # Recommendation metrics
|
||||
| └─utils.py # Utils for training bgcf
|
||||
|
|
||||
├─eval.py # Evaluation net
|
||||
└─train.py # Train net
|
||||
```
|
||||
|
||||
|
@ -118,7 +125,7 @@ After installing MindSpore via the official website and Dataset is correctly gen
|
|||
Parameters for both training and evaluation can be set in config.py.
|
||||
|
||||
- config for BGCF dataset
|
||||
|
||||
|
||||
```python
|
||||
"learning_rate": 0.001, # Learning rate
|
||||
"num_epochs": 600, # Epoch sizes for training
|
||||
|
@ -130,6 +137,7 @@ Parameters for both training and evaluation can be set in config.py.
|
|||
"neighbor_dropout": [0.0, 0.2, 0.3]# Dropout ratio for different aggregation layer
|
||||
"num_graphs":5 # Num of sample graph
|
||||
```
|
||||
config.py for more configuration.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
|
@ -154,27 +162,54 @@ Parameters for both training and evaluation can be set in config.py.
|
|||
Epoch 598 iter 12 loss 3640.7612
|
||||
Epoch 599 iter 12 loss 3654.9087
|
||||
Epoch 600 iter 12 loss 3632.4585
|
||||
epoch:600, recall_@10:0.10393, recall_@20:0.15669, ndcg_@10:0.07564, ndcg_@20:0.09343,
|
||||
sedp_@10:0.01936, sedp_@20:0.01544, nov_@10:7.58599, nov_@20:7.79782
|
||||
...
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### Evaluation
|
||||
|
||||
- Evaluation on Ascend
|
||||
```python
|
||||
sh run_eval_ascend.sh
|
||||
```
|
||||
|
||||
Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the
|
||||
followings in log.
|
||||
|
||||
```python
|
||||
epoch:020, recall_@10:0.07345, recall_@20:0.11193, ndcg_@10:0.05293, ndcg_@20:0.06613,
|
||||
sedp_@10:0.01393, sedp_@20:0.01126, nov_@10:6.95106, nov_@20:7.22280
|
||||
epoch:040, recall_@10:0.07410, recall_@20:0.11537, ndcg_@10:0.05387, ndcg_@20:0.06801,
|
||||
sedp_@10:0.01445, sedp_@20:0.01168, nov_@10:7.34799, nov_@20:7.58883
|
||||
epoch:060, recall_@10:0.07654, recall_@20:0.11987, ndcg_@10:0.05530, ndcg_@20:0.07015,
|
||||
sedp_@10:0.01474, sedp_@20:0.01206, nov_@10:7.46553, nov_@20:7.69436
|
||||
|
||||
...
|
||||
epoch:560, recall_@10:0.09825, recall_@20:0.14877, ndcg_@10:0.07176, ndcg_@20:0.08883,
|
||||
sedp_@10:0.01882, sedp_@20:0.01501, nov_@10:7.58045, nov_@20:7.79586
|
||||
epoch:580, recall_@10:0.09917, recall_@20:0.14970, ndcg_@10:0.07337, ndcg_@20:0.09037,
|
||||
sedp_@10:0.01896, sedp_@20:0.01504, nov_@10:7.57995, nov_@20:7.79439
|
||||
epoch:600, recall_@10:0.09926, recall_@20:0.15080, ndcg_@10:0.07283, ndcg_@20:0.09016,
|
||||
sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038
|
||||
...
|
||||
```
|
||||
# [Model Description](#contents)
|
||||
## [Performance](#contents)
|
||||
|
||||
| Parameter | BGCF |
|
||||
| ------------------------------------ | ----------------------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| uploaded Date | 09/04/2020(month/day/year) |
|
||||
| MindSpore Version | 1.0 |
|
||||
| uploaded Date | |
|
||||
| MindSpore Version | |
|
||||
| Dataset | Amazon-Beauty |
|
||||
| Training Parameter | epoch=600 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | BPR loss |
|
||||
| Recall@20 | 0.1534 |
|
||||
| NDCG@20 | 0.0912 |
|
||||
| Total time | 30min |
|
||||
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf |
|
||||
| Training Cost | 25min |
|
||||
| Scripts | |
|
||||
|
||||
# [Description of random situation](#contents)
|
||||
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
BGCF evaluation script.
|
||||
"""
|
||||
import os
|
||||
import datetime
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
|
||||
from src.bgcf import BGCF
|
||||
from src.utils import BGCFLogger
|
||||
from src.config import parser_args
|
||||
from src.metrics import BGCFEvaluate
|
||||
from src.callback import ForwardBGCF, TestBGCF
|
||||
from src.dataset import TestGraphDataset, load_graph
|
||||
|
||||
|
||||
def evaluation():
|
||||
"""evaluation"""
|
||||
num_user = train_graph.graph_info()["node_num"][0]
|
||||
num_item = train_graph.graph_info()["node_num"][1]
|
||||
|
||||
eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)
|
||||
for _epoch in range(parser.eval_interval, parser.num_epoch+1, parser.eval_interval):
|
||||
bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
|
||||
parser.embedded_dimension,
|
||||
parser.activation,
|
||||
[0.0, 0.0, 0.0],
|
||||
num_user,
|
||||
num_item,
|
||||
parser.input_dim)
|
||||
|
||||
load_checkpoint(parser.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch), net=bgcfnet_test)
|
||||
|
||||
forward_net = ForwardBGCF(bgcfnet_test)
|
||||
user_reps, item_reps = TestBGCF(forward_net, num_user, num_item, parser.input_dim, test_graph_dataset)
|
||||
|
||||
test_recall_bgcf, test_ndcg_bgcf, \
|
||||
test_sedp, test_nov = eval_class.eval_with_rep(user_reps, item_reps, parser)
|
||||
|
||||
if parser.log_name:
|
||||
log.write(
|
||||
'epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
|
||||
'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
|
||||
test_recall_bgcf[1],
|
||||
test_recall_bgcf[2],
|
||||
test_ndcg_bgcf[1],
|
||||
test_ndcg_bgcf[2],
|
||||
test_sedp[0],
|
||||
test_sedp[1],
|
||||
test_nov[1],
|
||||
test_nov[2]))
|
||||
else:
|
||||
print('epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
|
||||
'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
|
||||
test_recall_bgcf[1],
|
||||
test_recall_bgcf[2],
|
||||
test_ndcg_bgcf[1],
|
||||
test_ndcg_bgcf[2],
|
||||
test_sedp[0],
|
||||
test_sedp[1],
|
||||
test_nov[1],
|
||||
test_nov[2]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False)
|
||||
|
||||
parser = parser_args()
|
||||
os.environ['DEVICE_ID'] = parser.device
|
||||
|
||||
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
|
||||
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,
|
||||
num_bgcn_neigh=parser.gnew_neighs,
|
||||
num_neg=parser.num_neg)
|
||||
|
||||
if parser.log_name:
|
||||
now = datetime.datetime.now().strftime("%b_%d_%H_%M_%S")
|
||||
name = "bgcf" + '-' + parser.log_name + '-' + parser.dataset
|
||||
log_save_path = './log-files/' + name + '/' + now
|
||||
log = BGCFLogger(logname=name, now=now, foldername='log-files', copy=False)
|
||||
log.open(log_save_path + '/log.train.txt', mode='a')
|
||||
for arg in vars(parser):
|
||||
log.write(arg + '=' + str(getattr(parser, arg)) + '\n')
|
||||
else:
|
||||
for arg in vars(parser):
|
||||
print(arg + '=' + str(getattr(parser, arg)))
|
||||
|
||||
evaluation()
|
|
@ -0,0 +1,38 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
|
||||
python eval.py --datapath=../data_mr --ckptpath=../ckpts &> log &
|
||||
|
||||
cd ..
|
|
@ -25,14 +25,20 @@ then
|
|||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
|
||||
if [ -d "ckpts" ];
|
||||
then
|
||||
rm -rf ./ckpts
|
||||
fi
|
||||
mkdir ./ckpts
|
||||
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
mkdir ./ckpts
|
||||
env > env.log
|
||||
echo "start training for device $DEVICE_ID"
|
||||
|
||||
python train.py --datapath=../data_mr &> log &
|
||||
python train.py --datapath=../data_mr --ckptpath=../ckpts &> log &
|
||||
|
||||
cd ..
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
network config setting
|
||||
"""
|
||||
import argparse
|
||||
|
||||
|
@ -21,37 +21,38 @@ import argparse
|
|||
def parser_args():
|
||||
"""Config for BGCF"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--dataset", type=str, default="Beauty")
|
||||
parser.add_argument("-dpath", "--datapath", type=str, default="./scripts/data_mr")
|
||||
parser.add_argument("-de", "--device", type=str, default='0')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--Ks', type=list, default=[5, 10, 20, 100])
|
||||
parser.add_argument('--test_ratio', type=float, default=0.2)
|
||||
parser.add_argument('--val_ratio', type=float, default=None)
|
||||
parser.add_argument('-w', '--workers', type=int, default=10)
|
||||
parser.add_argument("-d", "--dataset", type=str, default="Beauty", help="choose which dataset")
|
||||
parser.add_argument("-dpath", "--datapath", type=str, default="./scripts/data_mr", help="minddata path")
|
||||
parser.add_argument("-de", "--device", type=str, default='0', help="device id")
|
||||
parser.add_argument('--seed', type=int, default=0, help="random seed")
|
||||
parser.add_argument('--Ks', type=list, default=[5, 10, 20, 100], help="top K")
|
||||
parser.add_argument('--test_ratio', type=float, default=0.2, help="test ratio")
|
||||
parser.add_argument('-w', '--workers', type=int, default=8, help="number of process")
|
||||
parser.add_argument("-ckpt", "--ckptpath", type=str, default="./ckpts", help="checkpoint path")
|
||||
|
||||
parser.add_argument("-eps", "--epsilon", type=float, default=1e-8)
|
||||
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3)
|
||||
parser.add_argument("-l2", "--l2", type=float, default=0.03)
|
||||
parser.add_argument("-wd", "--weight_decay", type=float, default=0.01)
|
||||
parser.add_argument("-act", "--activation", type=str, default='tanh', choices=['relu', 'tanh'])
|
||||
parser.add_argument("-ndrop", "--neighbor_dropout", type=list, default=[0.0, 0.2, 0.3])
|
||||
parser.add_argument("-log", "--log_name", type=str, default='test')
|
||||
parser.add_argument("-eps", "--epsilon", type=float, default=1e-8, help="optimizer parameter")
|
||||
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3, help="learning rate")
|
||||
parser.add_argument("-l2", "--l2", type=float, default=0.03, help="l2 coefficient")
|
||||
parser.add_argument("-wd", "--weight_decay", type=float, default=0.01, help="weight decay")
|
||||
parser.add_argument("-act", "--activation", type=str, default='tanh', choices=['relu', 'tanh'],
|
||||
help="activation function")
|
||||
parser.add_argument("-ndrop", "--neighbor_dropout", type=list, default=[0.0, 0.2, 0.3],
|
||||
help="dropout ratio for different aggregation layer")
|
||||
parser.add_argument("-log", "--log_name", type=str, default='test', help="log name")
|
||||
|
||||
parser.add_argument("-e", "--num_epoch", type=int, default=600)
|
||||
parser.add_argument('-input', '--input_dim', type=int, default=64, choices=[64, 128])
|
||||
parser.add_argument("-b", "--batch_pairs", type=int, default=5000)
|
||||
parser.add_argument('--eval_interval', type=int, default=20)
|
||||
parser.add_argument("-e", "--num_epoch", type=int, default=600, help="epoch sizes for training")
|
||||
parser.add_argument('-input', '--input_dim', type=int, default=64, choices=[64, 128],
|
||||
help="user and item embedding dimension")
|
||||
parser.add_argument("-b", "--batch_pairs", type=int, default=5000, help="batch size")
|
||||
parser.add_argument('--eval_interval', type=int, default=20, help="evaluation interval")
|
||||
|
||||
parser.add_argument("-neg", "--num_neg", type=int, default=10)
|
||||
parser.add_argument('-max', '--max_degree', type=str, default='[128,128]')
|
||||
parser.add_argument("-g1", "--raw_neighs", type=int, default=40)
|
||||
parser.add_argument("-g2", "--gnew_neighs", type=int, default=20)
|
||||
parser.add_argument("-emb", "--embedded_dimension", type=int, default=64)
|
||||
parser.add_argument('-dist', '--distance', type=str, default='iou')
|
||||
parser.add_argument('--dist_reg', type=float, default=0.003)
|
||||
parser.add_argument("-neg", "--num_neg", type=int, default=10, help="negative sampling rate ")
|
||||
parser.add_argument("-g1", "--raw_neighs", type=int, default=40, help="num of sampling neighbors in raw graph")
|
||||
parser.add_argument("-g2", "--gnew_neighs", type=int, default=20, help="num of sampling neighbors in sample graph")
|
||||
parser.add_argument("-emb", "--embedded_dimension", type=int, default=64, help="output embedding dim")
|
||||
parser.add_argument('--dist_reg', type=float, default=0.003, help="distance loss coefficient")
|
||||
|
||||
parser.add_argument('-ng', '--num_graphs', type=int, default=5)
|
||||
parser.add_argument('-geps', '--graph_epsilon', type=float, default=0.01)
|
||||
parser.add_argument('-ng', '--num_graphs', type=int, default=5, help="num of sample graph")
|
||||
parser.add_argument('-geps', '--graph_epsilon', type=float, default=0.01, help="node copy parameter")
|
||||
|
||||
return parser.parse_args()
|
||||
|
|
|
@ -175,8 +175,8 @@ def load_graph(data_path):
|
|||
return train_graph, test_graph, sampled_graph_list
|
||||
|
||||
|
||||
def create_dataset(train_graph, sampled_graph_list, batch_size=32, repeat_size=1, num_samples=40, num_bgcn_neigh=20,
|
||||
num_neg=10):
|
||||
def create_dataset(train_graph, sampled_graph_list, num_workers, batch_size=32, repeat_size=1,
|
||||
num_samples=40, num_bgcn_neigh=20, num_neg=10):
|
||||
"""Data generator for training"""
|
||||
edge_num = train_graph.graph_info()['edge_num'][0]
|
||||
out_column_names = ["users", "items", "neg_item_id", "pos_users", "pos_items", "u_group_nodes", "u_neighs",
|
||||
|
@ -185,7 +185,7 @@ def create_dataset(train_graph, sampled_graph_list, batch_size=32, repeat_size=1
|
|||
train_graph_dataset = TrainGraphDataset(
|
||||
train_graph, sampled_graph_list, batch_size, num_samples, num_bgcn_neigh, num_neg)
|
||||
dataset = ds.GeneratorDataset(source=train_graph_dataset, column_names=out_column_names,
|
||||
sampler=RandomBatchedSampler(edge_num, batch_size), num_parallel_workers=8)
|
||||
sampler=RandomBatchedSampler(edge_num, batch_size), num_parallel_workers=num_workers)
|
||||
dataset = dataset.repeat(repeat_size)
|
||||
|
||||
return dataset
|
||||
|
|
|
@ -17,23 +17,21 @@ BGCF training script.
|
|||
"""
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
||||
from src.bgcf import BGCF
|
||||
from src.metrics import BGCFEvaluate
|
||||
from src.config import parser_args
|
||||
from src.utils import BGCFLogger, convert_item_id
|
||||
from src.callback import ForwardBGCF, TrainBGCF, TestBGCF
|
||||
from src.dataset import load_graph, create_dataset, TestGraphDataset
|
||||
from src.utils import convert_item_id
|
||||
from src.callback import TrainBGCF
|
||||
from src.dataset import load_graph, create_dataset
|
||||
|
||||
|
||||
def train_and_eval():
|
||||
"""Train and eval"""
|
||||
def train():
|
||||
"""Train"""
|
||||
num_user = train_graph.graph_info()["node_num"][0]
|
||||
num_item = train_graph.graph_info()["node_num"][1]
|
||||
num_pairs = train_graph.graph_info()['edge_num'][0]
|
||||
|
@ -50,8 +48,6 @@ def train_and_eval():
|
|||
parser.epsilon, parser.dist_reg)
|
||||
train_net.set_train(True)
|
||||
|
||||
eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)
|
||||
|
||||
itr = train_ds.create_dict_iterator(parser.num_epoch, output_numpy=True)
|
||||
num_iter = int(num_pairs / parser.batch_pairs)
|
||||
|
||||
|
@ -102,49 +98,7 @@ def train_and_eval():
|
|||
iter_num += 1
|
||||
|
||||
if _epoch % parser.eval_interval == 0:
|
||||
if os.path.exists("ckpts/bgcf.ckpt"):
|
||||
os.remove("ckpts/bgcf.ckpt")
|
||||
save_checkpoint(bgcfnet, "ckpts/bgcf.ckpt")
|
||||
|
||||
bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
|
||||
parser.embedded_dimension,
|
||||
parser.activation,
|
||||
[0.0, 0.0, 0.0],
|
||||
num_user,
|
||||
num_item,
|
||||
parser.input_dim)
|
||||
|
||||
load_checkpoint("ckpts/bgcf.ckpt", net=bgcfnet_test)
|
||||
|
||||
forward_net = ForwardBGCF(bgcfnet_test)
|
||||
user_reps, item_reps = TestBGCF(forward_net, num_user, num_item, parser.input_dim, test_graph_dataset)
|
||||
|
||||
test_recall_bgcf, test_ndcg_bgcf, \
|
||||
test_sedp, test_nov = eval_class.eval_with_rep(user_reps, item_reps, parser)
|
||||
|
||||
if parser.log_name:
|
||||
log.write(
|
||||
'epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
|
||||
'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
|
||||
test_recall_bgcf[1],
|
||||
test_recall_bgcf[2],
|
||||
test_ndcg_bgcf[1],
|
||||
test_ndcg_bgcf[2],
|
||||
test_sedp[0],
|
||||
test_sedp[1],
|
||||
test_nov[1],
|
||||
test_nov[2]))
|
||||
else:
|
||||
print('epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
|
||||
'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
|
||||
test_recall_bgcf[1],
|
||||
test_recall_bgcf[2],
|
||||
test_ndcg_bgcf[1],
|
||||
test_ndcg_bgcf[2],
|
||||
test_sedp[0],
|
||||
test_sedp[1],
|
||||
test_nov[1],
|
||||
test_nov[2]))
|
||||
save_checkpoint(bgcfnet, parser.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -153,23 +107,9 @@ if __name__ == "__main__":
|
|||
save_graphs=False)
|
||||
|
||||
parser = parser_args()
|
||||
os.environ['DEVICE_ID'] = parser.device
|
||||
train_graph, _, sampled_graph_list = load_graph(parser.datapath)
|
||||
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,
|
||||
num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg)
|
||||
|
||||
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
|
||||
train_ds = create_dataset(train_graph, sampled_graph_list, batch_size=parser.batch_pairs)
|
||||
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,
|
||||
num_bgcn_neigh=parser.gnew_neighs,
|
||||
num_neg=parser.num_neg)
|
||||
|
||||
if parser.log_name:
|
||||
now = datetime.datetime.now().strftime("%b_%d_%H_%M_%S")
|
||||
name = "bgcf" + '-' + parser.log_name + '-' + parser.dataset
|
||||
log_save_path = './log-files/' + name + '/' + now
|
||||
log = BGCFLogger(logname=name, now=now, foldername='log-files', copy=False)
|
||||
log.open(log_save_path + '/log.train.txt', mode='a')
|
||||
for arg in vars(parser):
|
||||
log.write(arg + '=' + str(getattr(parser, arg)) + '\n')
|
||||
else:
|
||||
for arg in vars(parser):
|
||||
print(arg + '=' + str(getattr(parser, arg)))
|
||||
|
||||
train_and_eval()
|
||||
train()
|
||||
|
|
|
@ -258,7 +258,7 @@ def trans(src_path, data_name, out_path):
|
|||
|
||||
test_ratio = 0.2
|
||||
train_set, test_set = split_data_randomly(
|
||||
inner_data_records, test_ratio=test_ratio, seed=0)
|
||||
inner_data_records, test_ratio, seed=0)
|
||||
train_matrix = generate_rating_matrix(
|
||||
train_set, len(user_mapping), len(item_mapping))
|
||||
u_adj_list, v_adj_list = create_adj_matrix(train_matrix)
|
||||
|
@ -329,7 +329,7 @@ def trans(src_path, data_name, out_path):
|
|||
for i in range(num_graphs):
|
||||
print('=== info: sampling graph {} / {}'.format(i + 1, num_graphs))
|
||||
sampled_user_graph = sample_graph_copying(node_neighbors_dict=u_adj_list,
|
||||
distances=user_distances)
|
||||
distances=user_distances, epsilon=0.01)
|
||||
|
||||
print('avg. sampled user-item graph degree: ',
|
||||
np.mean([len(x) for x in [*sampled_user_graph.values()]]))
|
||||
|
|
Loading…
Reference in New Issue