forked from mindspore-Ecosystem/mindspore
!16635 Support GCN on GPU
Merge pull request !16635 from yuruilee/master
This commit is contained in:
commit
3a9d3f2580
|
@ -75,16 +75,16 @@ Note that you can run the scripts based on the dataset mentioned in original pap
|
|||
```buildoutcfg
|
||||
cd ./scripts
|
||||
# SRC_PATH is the dataset file path you downloaded, DATASET_NAME is cora or citeseer
|
||||
sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
|
||||
bash run_process_data.sh [SRC_PATH] [DATASET_NAME]
|
||||
```
|
||||
|
||||
### Launch
|
||||
|
||||
```bash
|
||||
#Generate dataset in mindrecord format for cora
|
||||
sh run_process_data.sh ./data cora
|
||||
bash run_process_data.sh ./data cora
|
||||
#Generate dataset in mindrecord format for citeseer
|
||||
sh run_process_data.sh ./data citeseer
|
||||
bash run_process_data.sh ./data citeseer
|
||||
```
|
||||
|
||||
- Running on local with Ascend
|
||||
|
@ -149,12 +149,15 @@ sh run_train.sh [DATASET_NAME]
|
|||
├─scripts
|
||||
| ├─run_infer_310.sh # shell script for infer on Ascend 310
|
||||
| ├─run_process_data.sh # Generate dataset in mindrecord format
|
||||
| └─run_train.sh # Launch training, now only Ascend backend is supported.
|
||||
| ├─run_train_gpu.sh # Launch GPU training.
|
||||
| ├─run_eval_gpu.sh # Launch GPU inference.
|
||||
| └─run_train.sh # Launch Ascend training.
|
||||
|
|
||||
├─src
|
||||
| ├─config.py # Parameter configuration
|
||||
| ├─dataset.py # Data preprocessin
|
||||
| ├─gcn.py # GCN backbone
|
||||
| ├─eval_callback.py # Callback function
|
||||
| └─metrics.py # Loss and accuracy
|
||||
|
|
||||
├─default_config.py # Configurations
|
||||
|
@ -162,7 +165,8 @@ sh run_train.sh [DATASET_NAME]
|
|||
├─mindspore_hub_conf.py # mindspore_hub_conf scripts
|
||||
├─postprocess.py # postprocess script
|
||||
├─preprocess.py # preprocess scripts
|
||||
└─train.py # Train net, evaluation is performed after every training epoch. After the verification result converges, the training stops, then testing is performed.
|
||||
|─eval.py # Evaluation net, testing is performed.
|
||||
└─train.py # Train net, evaluation is performed after every training epoch. After the verification result converges, the training stops.
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
@ -176,6 +180,14 @@ Parameters for training can be set in config.py.
|
|||
"dropout": 0.5, # Dropout ratio for the first graph convolution layer
|
||||
"weight_decay": 5e-4, # Weight decay for the parameter of the first graph convolution layer
|
||||
"early_stopping": 10, # Tolerance for early stopping
|
||||
"save_ckpt_steps": 549 # Step to save checkpoint
|
||||
"keep_ckpt_max": 10 # Maximum step to save checkpoint
|
||||
"ckpt_dir": './ckpt' # The folder to save checkpoint
|
||||
"best_ckpt_dir": './best_ckpt' # The folder to save the best checkpoint
|
||||
"best_ckpt_name": 'best.ckpt' # The file name of the best checkpoint
|
||||
"eval_start_epoch": 100 # Start step for eval
|
||||
"save_best_ckpt": True # Save the best checkpoint or not
|
||||
"eval_interval": 1 # The interval of eval
|
||||
```
|
||||
|
||||
### [Training, Evaluation, Test Process](#contents)
|
||||
|
@ -183,14 +195,22 @@ Parameters for training can be set in config.py.
|
|||
#### Usage
|
||||
|
||||
```bash
|
||||
# run train with cora or citeseer dataset, DATASET_NAME is cora or citeseer
|
||||
sh run_train.sh [DATASET_NAME]
|
||||
# run train with cora or citeseer dataset on Ascend, DATASET_NAME is cora or citeseer
|
||||
bash run_train.sh [DATASET_NAME]
|
||||
|
||||
# run train with cora or citeseer dataset on GPU, DATASET_NAME is cora or citeseer
|
||||
bash run_train_gpu.sh [DATASET_NAME]
|
||||
|
||||
# run inference with cora or citeseer dataset on GPU, DATASET_NAME is cora or citeseer
|
||||
bash run_eval_gpu.sh [DATASET_NAME] [CKPT]
|
||||
```
|
||||
|
||||
#### Launch
|
||||
|
||||
```bash
|
||||
sh run_train.sh cora
|
||||
bash run_train.sh cora
|
||||
bash run_train_gpu.sh cora
|
||||
bash run_eval_gpu.sh cora ckpt/gcn.ckpt
|
||||
```
|
||||
|
||||
#### Result
|
||||
|
@ -250,18 +270,18 @@ Test set results: accuracy= 0.81300
|
|||
|
||||
### [Performance](#contents)
|
||||
|
||||
| Parameters | GCN |
|
||||
| -------------------------- | -------------------------------------------------------------- |
|
||||
| Resource | Ascend 910; OS Euler2.8 |
|
||||
| uploaded Date | 06/09/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | Cora/Citeseer |
|
||||
| Training Parameters | epoch=200 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| Accuracy | 81.5/70.3 |
|
||||
| Parameters (B) | 92160/59344 |
|
||||
| Scripts | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
|
||||
| Parameters | GCN | GCN |
|
||||
| ------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
|
||||
| uploaded Date | 06/09/2020 (month/day/year) | 05/06/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 | 1.1.0 |
|
||||
| Dataset | Cora/Citeseer | Cora/Citeseer |
|
||||
| Training Parameters | epoch=200 | epoch=200 |
|
||||
| Optimizer | Adam | Adam |
|
||||
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||
| Accuracy | 81.5/70.3 | 87.5/76.9 |
|
||||
| Parameters (B) | 92160/59344 | 92160/59344 |
|
||||
| Scripts | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
|
||||
|
||||
## [Description of Random Situation](#contents)
|
||||
|
||||
|
|
|
@ -157,12 +157,15 @@ sh run_train.sh [DATASET_NAME]
|
|||
├─scripts
|
||||
| ├─run_infer_310.sh # Ascend310 推理shell脚本
|
||||
| ├─run_process_data.sh # 生成MindRecord格式的数据集
|
||||
| ├─run_train_gpu.sh # 启动GPU后端的训练
|
||||
| ├─run_eval_gpu.sh # 启动GPU后端的推理
|
||||
| └─run_train.sh # 启动训练,目前只支持Ascend后端
|
||||
|
|
||||
├─src
|
||||
| ├─config.py # 参数配置
|
||||
| ├─dataset.py # 数据预处理
|
||||
| ├─gcn.py # GCN骨干
|
||||
| ├─eval_callback.py # 回调函数
|
||||
| └─metrics.py # 损失和准确率
|
||||
|
|
||||
├─default_config.py # 配置文件
|
||||
|
@ -170,6 +173,7 @@ sh run_train.sh [DATASET_NAME]
|
|||
├─mindspore_hub_conf.py # mindspore hub 脚本
|
||||
├─postprocess.py # 后处理脚本
|
||||
├─preprocess.py # 预处理脚本
|
||||
|─eval.py # 推理网络,进行测试。
|
||||
└─train.py # 训练网络,每个训练轮次后评估验证结果收敛后,训练停止,然后进行测试。
|
||||
```
|
||||
|
||||
|
@ -184,6 +188,14 @@ sh run_train.sh [DATASET_NAME]
|
|||
"dropout": 0.5, # 第一图卷积层dropout率
|
||||
"weight_decay": 5e-4, # 第一图卷积层参数的权重衰减
|
||||
"early_stopping": 10, # 早停容限
|
||||
"save_ckpt_steps": 549 # 保存ckpt的步数
|
||||
"keep_ckpt_max": 10 # 保存ckpt的最大步数
|
||||
"ckpt_dir": './ckpt' # 保存ckpt的文件夹
|
||||
"best_ckpt_dir": './best_ckpt’ # 最好ckpt的文件夹
|
||||
"best_ckpt_name": 'best.ckpt' # 最好ckpt的文件名
|
||||
"eval_start_epoch": 100 # 从哪一步开始eval
|
||||
"save_best_ckpt": True # 是否存储最好的ckpt
|
||||
"eval_interval": 1 # eval间隔
|
||||
```
|
||||
|
||||
### 培训、评估、测试过程
|
||||
|
@ -191,14 +203,22 @@ sh run_train.sh [DATASET_NAME]
|
|||
#### 用法
|
||||
|
||||
```text
|
||||
# 使用Cora或Citeseer数据集进行训练,DATASET_NAME为Cora或Citeseer
|
||||
sh run_train.sh [DATASET_NAME]
|
||||
# 在Ascend上使用Cora或Citeseer数据集进行训练,DATASET_NAME为Cora或Citeseer
|
||||
bash run_train.sh [DATASET_NAME]
|
||||
|
||||
# 在GPU上使用Cora或Citeseer数据集进行训练,DATASET_NAME为Cora或Citeseer
|
||||
bash run_train_gpu.sh [DATASET_NAME]
|
||||
|
||||
# 在GPU上对Cora或Citeseer数据集进行测试
|
||||
bash run_eval_gpu.sh [DATASET_NAME] [CKPT]
|
||||
```
|
||||
|
||||
#### 启动
|
||||
|
||||
```bash
|
||||
sh run_train.sh cora
|
||||
bash run_train.sh cora
|
||||
bash run_train_gpu.sh cora
|
||||
bash run_eval_gpu.sh cora ckpt/gcn.ckpt
|
||||
```
|
||||
|
||||
#### 结果
|
||||
|
@ -258,18 +278,18 @@ Test set results: accuracy= 0.81300
|
|||
|
||||
### 性能
|
||||
|
||||
| 参数 | GCN |
|
||||
| -------------------------- | -------------------------------------------------------------- |
|
||||
| 资源 | Ascend 910;系统 Euler2.8 |
|
||||
| 上传日期 | 2020-06-09 |
|
||||
| MindSpore版本 | 0.5.0-beta |
|
||||
| 数据集 | Cora/Citeseer |
|
||||
| 训练参数 | epoch=200 |
|
||||
| 优化器 | Adam |
|
||||
| 损失函数 | Softmax交叉熵 |
|
||||
| 准确率 | 81.5/70.3 |
|
||||
| 参数(B) | 92160/59344 |
|
||||
| 脚本 | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn> |
|
||||
| 参数 | GCN | GCN |
|
||||
| -------------------------- | -------------------------------------------------------------- | -------------------------- |
|
||||
| 资源 | Ascend 910;系统 Euler2.8 | NV SMX3 V100-32G |
|
||||
| 上传日期 | 2020-06-09 | 2021-05-06 |
|
||||
| MindSpore版本 | 0.5.0-beta | 1.1.0 |
|
||||
| 数据集 | Cora/Citeseer | Cora/Citeseer |
|
||||
| 训练参数 | epoch=200 | epoch=200 |
|
||||
| 优化器 | Adam | Adam |
|
||||
| 损失函数 | Softmax交叉熵 | Softmax交叉熵 |
|
||||
| 准确率 | 81.5/70.3 | 87.5/76.9 |
|
||||
| 参数(B) | 92160/59344 | 92160/59344 |
|
||||
| 脚本 | [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) | [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
|
||||
|
||||
## 随机情况说明
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ eval_nodes_num: 500
|
|||
test_nodes_num: 1000
|
||||
save_TSNE: False
|
||||
save_ckptpath: "ckpts/"
|
||||
train_with_eval: False
|
||||
|
||||
|
||||
---
|
||||
|
@ -30,3 +31,4 @@ train_nodes_num: "Nodes numbers for training"
|
|||
eval_nodes_num: "Nodes numbers for evaluation"
|
||||
test_nodes_num: "Nodes numbers for test"
|
||||
save_TSNE: "Whether to save t-SNE graph"
|
||||
train_with_eval: "Whether to train with evaluation"
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
GCN eval script.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from mindspore import Tensor
|
||||
from mindspore import Model, context
|
||||
|
||||
from src.config import ConfigGCN
|
||||
from src.dataset import get_adj_features_labels, get_mask
|
||||
from src.metrics import Loss
|
||||
from src.gcn import GCN
|
||||
|
||||
def run_gcn_infer():
|
||||
"""
|
||||
Run gcn infer
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='GCN')
|
||||
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
|
||||
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
||||
parser.add_argument("--model_ckpt", type=str, required=True,
|
||||
help="existed checkpoint address.")
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=args_opt.device_target, save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
adj, feature, label_onehot, _ = get_adj_features_labels(args_opt.data_dir)
|
||||
feature_d = np.expand_dims(feature, axis=0)
|
||||
label_onehot_d = np.expand_dims(label_onehot, axis=0)
|
||||
data = {"feature": feature_d, "label": label_onehot_d}
|
||||
dataset = ds.NumpySlicesDataset(data=data)
|
||||
adj = Tensor(adj, dtype=mstype.float32)
|
||||
feature = Tensor(feature)
|
||||
nodes_num = label_onehot.shape[0]
|
||||
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
|
||||
class_num = label_onehot.shape[1]
|
||||
input_dim = feature.shape[1]
|
||||
gcn_net_test = GCN(config, input_dim, class_num, adj)
|
||||
load_checkpoint(args_opt.model_ckpt, net=gcn_net_test)
|
||||
eval_metrics = {'Acc': nn.Accuracy()}
|
||||
criterion = Loss(test_mask, config.weight_decay, gcn_net_test.trainable_params()[0])
|
||||
model = Model(gcn_net_test, loss_fn=criterion, metrics=eval_metrics)
|
||||
res = model.eval(dataset, dataset_sink_mode=True)
|
||||
print(res)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_gcn_infer()
|
|
@ -32,7 +32,7 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
|||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
if args.device_target == "Ascend" or args.device_target == "GPU":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_NAME=$1
|
||||
echo $DATASET_NAME
|
||||
MODEL_CKPT=$(get_real_path $2)
|
||||
echo $MODEL_CKPT
|
||||
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../model_utils ./eval
|
||||
cd ./eval || exit
|
||||
echo "start eval on standalone GPU"
|
||||
|
||||
if [ $DATASET_NAME == cora ]
|
||||
then
|
||||
python eval.py --data_dir=../data_mr/$DATASET_NAME --device_target="GPU" --model_ckpt $MODEL_CKPT &> log &
|
||||
fi
|
||||
|
||||
if [ $DATASET_NAME == citeseer ]
|
||||
then
|
||||
python eval.py --data_dir=../data_mr/$DATASET_NAME --device_target="GPU" --model_ckpt $MODEL_CKPT &> log &
|
||||
fi
|
||||
cd ..
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [SRC_PATH] [DATASET_NAME]"
|
||||
echo "Usage: sh run_process_data.sh [SRC_PATH] [DATASET_NAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -43,7 +43,7 @@ MINDRECORD_PATH=`pwd`/data_mr
|
|||
rm -f $MINDRECORD_PATH/$DATASET_NAME
|
||||
rm -f $MINDRECORD_PATH/$DATASET_NAME.db
|
||||
|
||||
cd ../../utils/graph_to_mindrecord || exit
|
||||
cd ../../../../utils/graph_to_mindrecord || exit
|
||||
|
||||
python writer.py --mindrecord_script $DATASET_NAME \
|
||||
--mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [DATASET_NAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET_NAME=$1
|
||||
echo $DATASET_NAME
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp ../*.yaml ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../model_utils ./train
|
||||
cd ./train || exit
|
||||
env > env.log
|
||||
echo "start training for standalone GPU"
|
||||
|
||||
|
||||
if [ $DATASET_NAME == cora ]
|
||||
then
|
||||
python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=140 --device_target="GPU" &> log &
|
||||
fi
|
||||
|
||||
if [ $DATASET_NAME == citeseer ]
|
||||
then
|
||||
python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=120 --device_target="GPU" &> log &
|
||||
fi
|
||||
cd ..
|
||||
|
|
@ -18,9 +18,20 @@ network config setting, will be used in train.py
|
|||
|
||||
|
||||
class ConfigGCN():
|
||||
"""
|
||||
Configuration of GCN
|
||||
"""
|
||||
learning_rate = 0.01
|
||||
epochs = 200
|
||||
hidden1 = 16
|
||||
dropout = 0.5
|
||||
weight_decay = 5e-4
|
||||
early_stopping = 50
|
||||
save_ckpt_steps = 549
|
||||
keep_ckpt_max = 10
|
||||
ckpt_dir = './ckpt'
|
||||
best_ckpt_dir = './best_ckpt'
|
||||
best_ckpt_name = 'best.ckpt'
|
||||
eval_start_epoch = 100
|
||||
save_best_ckpt = True
|
||||
eval_interval = 1
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation callback when training"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from mindspore import save_checkpoint
|
||||
from mindspore import log as logger
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
Evaluation callback when training.
|
||||
|
||||
Args:
|
||||
eval_function (function): evaluation function.
|
||||
eval_param_dict (dict): evaluation parameters' configure dict.
|
||||
interval (int): run evaluation interval, default is 1.
|
||||
eval_start_epoch (int): evaluation start epoch, default is 1.
|
||||
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
|
||||
besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
|
||||
metrics_name (str): evaluation metrics name, default is `acc`.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> EvalCallBack(eval_function, eval_param_dict)
|
||||
"""
|
||||
|
||||
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
|
||||
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
|
||||
super(EvalCallBack, self).__init__()
|
||||
self.eval_param_dict = eval_param_dict
|
||||
self.eval_function = eval_function
|
||||
self.eval_start_epoch = eval_start_epoch
|
||||
if interval < 1:
|
||||
raise ValueError("interval should >= 1.")
|
||||
self.interval = interval
|
||||
self.save_best_ckpt = save_best_ckpt
|
||||
self.best_res = 0
|
||||
self.best_epoch = 0
|
||||
if not os.path.isdir(ckpt_directory):
|
||||
os.makedirs(ckpt_directory)
|
||||
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
|
||||
self.metrics_name = metrics_name
|
||||
self.save_TSNE = save_TSNE
|
||||
|
||||
def remove_ckpoint_file(self, file_name):
|
||||
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
|
||||
try:
|
||||
os.chmod(file_name, stat.S_IWRITE)
|
||||
os.remove(file_name)
|
||||
except OSError:
|
||||
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
|
||||
except ValueError:
|
||||
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""Callback when epoch end."""
|
||||
cb_params = run_context.original_args()
|
||||
cur_epoch = cb_params.cur_epoch_num
|
||||
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
|
||||
res = self.eval_function(self.eval_param_dict)
|
||||
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
|
||||
if res >= self.best_res:
|
||||
self.best_res = res
|
||||
self.best_epoch = cur_epoch
|
||||
print("update best result: {}".format(res), flush=True)
|
||||
if self.save_best_ckpt:
|
||||
if os.path.exists(self.bast_ckpt_path):
|
||||
self.remove_ckpoint_file(self.bast_ckpt_path)
|
||||
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
|
||||
print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
|
||||
|
||||
def end(self, run_context):
|
||||
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
|
||||
self.best_res,
|
||||
self.best_epoch), flush=True)
|
||||
|
|
@ -92,12 +92,12 @@ class GCN(nn.Cell):
|
|||
output_dim (int): The number of output channels, equal to classes num.
|
||||
"""
|
||||
|
||||
def __init__(self, config, input_dim, output_dim):
|
||||
def __init__(self, config, input_dim, output_dim, adj):
|
||||
super(GCN, self).__init__()
|
||||
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
|
||||
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
|
||||
|
||||
def construct(self, adj, feature):
|
||||
output0 = self.layer0(adj, feature)
|
||||
output1 = self.layer1(adj, output0)
|
||||
self.adj = adj
|
||||
def construct(self, feature):
|
||||
output0 = self.layer0(self.adj, feature)
|
||||
output1 = self.layer1(self.adj, output0)
|
||||
return output1
|
||||
|
|
|
@ -17,16 +17,13 @@ from mindspore import nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.nn.metrics import Metric
|
||||
|
||||
|
||||
class Loss(nn.Cell):
|
||||
"""Softmax cross-entropy loss with masking."""
|
||||
def __init__(self, label, mask, weight_decay, param):
|
||||
def __init__(self, mask, weight_decay, param):
|
||||
super(Loss, self).__init__(auto_prefix=False)
|
||||
self.label = Tensor(label)
|
||||
self.mask = Tensor(mask)
|
||||
self.loss = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.one = Tensor(1.0, mstype.float32)
|
||||
|
@ -38,12 +35,12 @@ class Loss(nn.Cell):
|
|||
self.weight_decay = weight_decay
|
||||
self.param = param
|
||||
|
||||
def construct(self, preds):
|
||||
def construct(self, preds, label):
|
||||
"""Calculate loss"""
|
||||
param = self.l2_loss(self.param)
|
||||
loss = self.weight_decay * param
|
||||
preds = self.cast(preds, mstype.float32)
|
||||
loss = loss + self.loss(preds, self.label)[0]
|
||||
loss = loss + self.loss(preds, label)[0]
|
||||
mask = self.cast(self.mask, mstype.float32)
|
||||
mask_reduce = self.mean(mask)
|
||||
mask = mask / mask_reduce
|
||||
|
@ -51,138 +48,38 @@ class Loss(nn.Cell):
|
|||
loss = self.mean(loss)
|
||||
return loss
|
||||
|
||||
|
||||
class Accuracy(nn.Cell):
|
||||
"""Accuracy with masking."""
|
||||
def __init__(self, label, mask):
|
||||
super(Accuracy, self).__init__(auto_prefix=False)
|
||||
self.label = Tensor(label)
|
||||
class GCNAccuracy(Metric):
|
||||
"""
|
||||
Accuracy for GCN
|
||||
"""
|
||||
def __init__(self, mask):
|
||||
super(GCNAccuracy, self).__init__()
|
||||
self.mask = Tensor(mask)
|
||||
self.equal = P.Equal()
|
||||
self.argmax = P.Argmax()
|
||||
self.cast = P.Cast()
|
||||
self.mean = P.ReduceMean()
|
||||
self.accuracy_all = 0
|
||||
|
||||
def construct(self, preds):
|
||||
preds = self.cast(preds, mstype.float32)
|
||||
correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
|
||||
accuracy_all = self.cast(correct_prediction, mstype.float32)
|
||||
def clear(self):
|
||||
self.accuracy_all = 0
|
||||
|
||||
def update(self, *inputs):
|
||||
preds = self.cast(inputs[1], mstype.float32)
|
||||
correct_prediction = self.equal(self.argmax(preds), self.argmax(inputs[0]))
|
||||
self.accuracy_all = self.cast(correct_prediction, mstype.float32)
|
||||
mask = self.cast(self.mask, mstype.float32)
|
||||
mask_reduce = self.mean(mask)
|
||||
mask = mask / mask_reduce
|
||||
accuracy_all *= mask
|
||||
return self.mean(accuracy_all)
|
||||
self.accuracy_all *= mask
|
||||
|
||||
def eval(self):
|
||||
return float(self.mean(self.accuracy_all).asnumpy())
|
||||
|
||||
class LossAccuracyWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with loss and accuracy cell.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossAccuracyWrapper, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self, adj, feature):
|
||||
preds = self.network(adj, feature)
|
||||
loss = self.loss(preds)
|
||||
accuracy = self.accuracy(preds)
|
||||
return loss, accuracy
|
||||
|
||||
|
||||
class LossWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with loss.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training.
|
||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossWrapper, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
|
||||
def construct(self, adj, feature):
|
||||
preds = self.network(adj, feature)
|
||||
loss = self.loss(preds)
|
||||
return loss
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
r"""
|
||||
Network training package class.
|
||||
|
||||
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
||||
Backward graph will be created in the construct function to do parameter updating. Different
|
||||
parallel modes are available to run the training.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
||||
|
||||
Outputs:
|
||||
Tensor, a scalar Tensor with shape :math:`()`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
|
||||
def construct(self, adj, feature):
|
||||
weights = self.weights
|
||||
loss = self.network(adj, feature)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(adj, feature, sens)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class TrainNetWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with optimizer.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||
config (ConfigGCN): Configuration for GCN.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, config):
|
||||
super(TrainNetWrapper, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||
learning_rate=config.learning_rate)
|
||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self, adj, feature):
|
||||
loss = self.loss_train_net(adj, feature)
|
||||
accuracy = self.accuracy(self.network(adj, feature))
|
||||
return loss, accuracy
|
||||
def apply_eval(eval_param_dict):
|
||||
"""run Evaluation"""
|
||||
model = eval_param_dict["model"]
|
||||
dataset = eval_param_dict["dataset"]
|
||||
metrics_name = eval_param_dict["metrics_name"]
|
||||
eval_score = model.eval(dataset, dataset_sink_mode=False)[metrics_name]
|
||||
return eval_score
|
||||
|
|
|
@ -18,36 +18,25 @@ GCN training script.
|
|||
"""
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib import animation
|
||||
from sklearn import manifold
|
||||
from mindspore import context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
import numpy as np
|
||||
|
||||
from src.gcn import GCN
|
||||
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
||||
from src.metrics import Loss, GCNAccuracy, apply_eval
|
||||
from src.config import ConfigGCN
|
||||
from src.dataset import get_adj_features_labels, get_mask
|
||||
from src.eval_callback import EvalCallBack
|
||||
|
||||
from model_utils.config import config as default_args
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
|
||||
def t_SNE(out_feature, dim):
|
||||
t_sne = manifold.TSNE(n_components=dim, init='pca', random_state=0)
|
||||
return t_sne.fit_transform(out_feature)
|
||||
|
||||
|
||||
def update_graph(i, data, scat, plot):
|
||||
scat.set_offsets(data[i])
|
||||
plt.title('t-SNE visualization of Epoch:{0}'.format(i))
|
||||
return scat, plot
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
|
@ -112,85 +101,41 @@ def modelarts_pre_process():
|
|||
def run_train():
|
||||
"""Train model."""
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False)
|
||||
device_target=default_args.device_target, save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
adj, feature, label_onehot, label = get_adj_features_labels(default_args.data_dir)
|
||||
|
||||
if not os.path.exists(config.ckpt_dir):
|
||||
os.mkdir(config.ckpt_dir)
|
||||
adj, feature, label_onehot, _ = get_adj_features_labels(default_args.data_dir)
|
||||
feature_d = np.expand_dims(feature, axis=0)
|
||||
label_onehot_d = np.expand_dims(label_onehot, axis=0)
|
||||
data = {"feature": feature_d, "label": label_onehot_d}
|
||||
dataset = ds.NumpySlicesDataset(data=data)
|
||||
nodes_num = label_onehot.shape[0]
|
||||
train_mask = get_mask(nodes_num, 0, default_args.train_nodes_num)
|
||||
eval_mask = get_mask(nodes_num, default_args.train_nodes_num,
|
||||
default_args.train_nodes_num + default_args.eval_nodes_num)
|
||||
test_mask = get_mask(nodes_num, nodes_num - default_args.test_nodes_num, nodes_num)
|
||||
|
||||
class_num = label_onehot.shape[1]
|
||||
input_dim = feature.shape[1]
|
||||
gcn_net = GCN(config, input_dim, class_num)
|
||||
gcn_net.add_flags_recursive(fp16=True)
|
||||
|
||||
adj = Tensor(adj)
|
||||
feature = Tensor(feature)
|
||||
|
||||
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
|
||||
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
||||
|
||||
loss_list = []
|
||||
|
||||
if default_args.save_TSNE:
|
||||
out_feature = gcn_net()
|
||||
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
||||
graph_data = []
|
||||
graph_data.append(tsne_result)
|
||||
fig = plt.figure()
|
||||
scat = plt.scatter(tsne_result[:, 0], tsne_result[:, 1], s=2, c=label, cmap='rainbow')
|
||||
plt.title('t-SNE visualization of Epoch:0', fontsize='large', fontweight='bold', verticalalignment='center')
|
||||
|
||||
for epoch in range(config.epochs):
|
||||
t = time.time()
|
||||
|
||||
train_net.set_train()
|
||||
train_result = train_net(adj, feature)
|
||||
train_loss = train_result[0].asnumpy()
|
||||
train_accuracy = train_result[1].asnumpy()
|
||||
|
||||
eval_net.set_train(False)
|
||||
eval_result = eval_net(adj, feature)
|
||||
eval_loss = eval_result[0].asnumpy()
|
||||
eval_accuracy = eval_result[1].asnumpy()
|
||||
|
||||
loss_list.append(eval_loss)
|
||||
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
|
||||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
||||
|
||||
if default_args.save_TSNE:
|
||||
out_feature = gcn_net()
|
||||
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
||||
graph_data.append(tsne_result)
|
||||
|
||||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
||||
print("Early stopping...")
|
||||
break
|
||||
if not os.path.isdir(default_args.save_ckptpath):
|
||||
os.makedirs(default_args.save_ckptpath)
|
||||
ckpt_path = os.path.join(default_args.save_ckptpath, "gcn.ckpt")
|
||||
save_checkpoint(gcn_net, ckpt_path)
|
||||
gcn_net_test = GCN(config, input_dim, class_num)
|
||||
load_checkpoint(ckpt_path, net=gcn_net_test)
|
||||
gcn_net_test.add_flags_recursive(fp16=True)
|
||||
|
||||
test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
|
||||
t_test = time.time()
|
||||
test_net.set_train(False)
|
||||
test_result = test_net(adj, feature)
|
||||
test_loss = test_result[0].asnumpy()
|
||||
test_accuracy = test_result[1].asnumpy()
|
||||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
||||
"accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
|
||||
|
||||
if default_args.save_TSNE:
|
||||
ani = animation.FuncAnimation(fig, update_graph, frames=range(config.epochs + 1), fargs=(graph_data, scat, plt))
|
||||
ani.save('t-SNE_visualization.gif', writer='imagemagick')
|
||||
|
||||
adj = Tensor(adj, dtype=mstype.float32)
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
|
||||
keep_checkpoint_max=config.keep_ckpt_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='ckpt_gcn',
|
||||
directory=config.ckpt_dir,
|
||||
config=ckpt_config)
|
||||
gcn_net = GCN(config, input_dim, class_num, adj)
|
||||
cb = [TimeMonitor(), LossMonitor(), ckpoint_cb]
|
||||
opt = nn.Adam(gcn_net.trainable_params(), learning_rate=config.learning_rate)
|
||||
criterion = Loss(eval_mask, config.weight_decay, gcn_net.trainable_params()[0])
|
||||
model = Model(gcn_net, loss_fn=criterion, optimizer=opt, amp_level="O3")
|
||||
if default_args.train_with_eval:
|
||||
GCN_metric = GCNAccuracy(eval_mask)
|
||||
eval_model = Model(gcn_net, loss_fn=criterion, metrics={'GCNAccuracy': GCN_metric})
|
||||
eval_param_dict = {"model": eval_model, "dataset": dataset, "metrics_name": "GCNAccuracy"}
|
||||
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
|
||||
eval_start_epoch=default_args.eval_start_epoch, save_best_ckpt=config.save_best_ckpt,
|
||||
ckpt_directory=config.best_ckpt_dir, besk_ckpt_name=config.best_ckpt_name,
|
||||
metrics_name="GCNAccuracy")
|
||||
cb.append(eval_cb)
|
||||
model.train(config.epochs, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_train()
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import time
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from model_zoo.official.gnn.gcn.src.gcn import GCN
|
||||
from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
||||
from model_zoo.official.gnn.gcn.src.config import ConfigGCN
|
||||
from model_zoo.official.gnn.gcn.src.dataset import get_adj_features_labels, get_mask
|
||||
|
||||
|
||||
DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
|
||||
TRAIN_NODE_NUM = 140
|
||||
EVAL_NODE_NUM = 500
|
||||
TEST_NODE_NUM = 1000
|
||||
SEED = 20
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gcn():
|
||||
print("test_gcn begin")
|
||||
np.random.seed(SEED)
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
config.dropout = 0.0
|
||||
adj, feature, label_onehot, _ = get_adj_features_labels(DATA_DIR)
|
||||
|
||||
nodes_num = label_onehot.shape[0]
|
||||
train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
|
||||
eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
|
||||
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
|
||||
|
||||
class_num = label_onehot.shape[1]
|
||||
input_dim = feature.shape[1]
|
||||
gcn_net = GCN(config, input_dim, class_num)
|
||||
gcn_net.add_flags_recursive(fp16=True)
|
||||
|
||||
adj = Tensor(adj)
|
||||
feature = Tensor(feature)
|
||||
|
||||
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
|
||||
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
|
||||
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
||||
|
||||
loss_list = []
|
||||
for epoch in range(config.epochs):
|
||||
t = time.time()
|
||||
|
||||
train_net.set_train()
|
||||
train_result = train_net(adj, feature)
|
||||
train_loss = train_result[0].asnumpy()
|
||||
train_accuracy = train_result[1].asnumpy()
|
||||
|
||||
eval_net.set_train(False)
|
||||
eval_result = eval_net(adj, feature)
|
||||
eval_loss = eval_result[0].asnumpy()
|
||||
eval_accuracy = eval_result[1].asnumpy()
|
||||
|
||||
loss_list.append(eval_loss)
|
||||
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
|
||||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
||||
|
||||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
||||
print("Early stopping...")
|
||||
break
|
||||
|
||||
test_net.set_train(False)
|
||||
test_result = test_net(adj, feature)
|
||||
test_loss = test_result[0].asnumpy()
|
||||
test_accuracy = test_result[1].asnumpy()
|
||||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
||||
"accuracy=", "{:.5f}".format(test_accuracy))
|
||||
assert test_accuracy > 0.812
|
Loading…
Reference in New Issue