forked from mindspore-Ecosystem/mindspore
!16635 Support GCN on GPU
Merge pull request !16635 from yuruilee/master
This commit is contained in:
commit
3a9d3f2580
|
@ -75,16 +75,16 @@ Note that you can run the scripts based on the dataset mentioned in original pap
|
||||||
```buildoutcfg
|
```buildoutcfg
|
||||||
cd ./scripts
|
cd ./scripts
|
||||||
# SRC_PATH is the dataset file path you downloaded, DATASET_NAME is cora or citeseer
|
# SRC_PATH is the dataset file path you downloaded, DATASET_NAME is cora or citeseer
|
||||||
sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
|
bash run_process_data.sh [SRC_PATH] [DATASET_NAME]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Launch
|
### Launch
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
#Generate dataset in mindrecord format for cora
|
#Generate dataset in mindrecord format for cora
|
||||||
sh run_process_data.sh ./data cora
|
bash run_process_data.sh ./data cora
|
||||||
#Generate dataset in mindrecord format for citeseer
|
#Generate dataset in mindrecord format for citeseer
|
||||||
sh run_process_data.sh ./data citeseer
|
bash run_process_data.sh ./data citeseer
|
||||||
```
|
```
|
||||||
|
|
||||||
- Running on local with Ascend
|
- Running on local with Ascend
|
||||||
|
@ -149,12 +149,15 @@ sh run_train.sh [DATASET_NAME]
|
||||||
├─scripts
|
├─scripts
|
||||||
| ├─run_infer_310.sh # shell script for infer on Ascend 310
|
| ├─run_infer_310.sh # shell script for infer on Ascend 310
|
||||||
| ├─run_process_data.sh # Generate dataset in mindrecord format
|
| ├─run_process_data.sh # Generate dataset in mindrecord format
|
||||||
| └─run_train.sh # Launch training, now only Ascend backend is supported.
|
| ├─run_train_gpu.sh # Launch GPU training.
|
||||||
|
| ├─run_eval_gpu.sh # Launch GPU inference.
|
||||||
|
| └─run_train.sh # Launch Ascend training.
|
||||||
|
|
|
|
||||||
├─src
|
├─src
|
||||||
| ├─config.py # Parameter configuration
|
| ├─config.py # Parameter configuration
|
||||||
| ├─dataset.py # Data preprocessing
|
| ├─dataset.py # Data preprocessing
|
||||||
| ├─gcn.py # GCN backbone
|
| ├─gcn.py # GCN backbone
|
||||||
|
| ├─eval_callback.py # Callback function
|
||||||
| └─metrics.py # Loss and accuracy
|
| └─metrics.py # Loss and accuracy
|
||||||
|
|
|
|
||||||
├─default_config.py # Configurations
|
├─default_config.py # Configurations
|
||||||
|
@ -162,7 +165,8 @@ sh run_train.sh [DATASET_NAME]
|
||||||
├─mindspore_hub_conf.py # mindspore_hub_conf scripts
|
├─mindspore_hub_conf.py # mindspore_hub_conf scripts
|
||||||
├─postprocess.py # postprocess script
|
├─postprocess.py # postprocess script
|
||||||
├─preprocess.py # preprocess scripts
|
├─preprocess.py # preprocess scripts
|
||||||
└─train.py # Train net, evaluation is performed after every training epoch. After the verification result converges, the training stops, then testing is performed.
|
├─eval.py # Evaluation network; runs testing with a trained checkpoint.
|
||||||
|
└─train.py # Train net, evaluation is performed after every training epoch. After the verification result converges, the training stops.
|
||||||
```
|
```
|
||||||
|
|
||||||
## [Script Parameters](#contents)
|
## [Script Parameters](#contents)
|
||||||
|
@ -176,6 +180,14 @@ Parameters for training can be set in config.py.
|
||||||
"dropout": 0.5, # Dropout ratio for the first graph convolution layer
|
"dropout": 0.5, # Dropout ratio for the first graph convolution layer
|
||||||
"weight_decay": 5e-4, # Weight decay for the parameter of the first graph convolution layer
|
"weight_decay": 5e-4, # Weight decay for the parameter of the first graph convolution layer
|
||||||
"early_stopping": 10, # Tolerance for early stopping
|
"early_stopping": 10, # Tolerance for early stopping
|
||||||
|
"save_ckpt_steps": 549 # Step to save checkpoint
|
||||||
|
"keep_ckpt_max": 10 # Maximum number of checkpoint files to keep
|
||||||
|
"ckpt_dir": './ckpt' # The folder to save checkpoint
|
||||||
|
"best_ckpt_dir": './best_ckpt' # The folder to save the best checkpoint
|
||||||
|
"best_ckpt_name": 'best.ckpt' # The file name of the best checkpoint
|
||||||
|
"eval_start_epoch": 100 # Epoch at which evaluation starts
|
||||||
|
"save_best_ckpt": True # Save the best checkpoint or not
|
||||||
|
"eval_interval": 1 # The interval of eval
|
||||||
```
|
```
|
||||||
|
|
||||||
### [Training, Evaluation, Test Process](#contents)
|
### [Training, Evaluation, Test Process](#contents)
|
||||||
|
@ -183,14 +195,22 @@ Parameters for training can be set in config.py.
|
||||||
#### Usage
|
#### Usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# run train with cora or citeseer dataset, DATASET_NAME is cora or citeseer
|
# run train with cora or citeseer dataset on Ascend, DATASET_NAME is cora or citeseer
|
||||||
sh run_train.sh [DATASET_NAME]
|
bash run_train.sh [DATASET_NAME]
|
||||||
|
|
||||||
|
# run train with cora or citeseer dataset on GPU, DATASET_NAME is cora or citeseer
|
||||||
|
bash run_train_gpu.sh [DATASET_NAME]
|
||||||
|
|
||||||
|
# run inference with cora or citeseer dataset on GPU, DATASET_NAME is cora or citeseer
|
||||||
|
bash run_eval_gpu.sh [DATASET_NAME] [CKPT]
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Launch
|
#### Launch
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sh run_train.sh cora
|
bash run_train.sh cora
|
||||||
|
bash run_train_gpu.sh cora
|
||||||
|
bash run_eval_gpu.sh cora ckpt/gcn.ckpt
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Result
|
#### Result
|
||||||
|
@ -250,18 +270,18 @@ Test set results: accuracy= 0.81300
|
||||||
|
|
||||||
### [Performance](#contents)
|
### [Performance](#contents)
|
||||||
|
|
||||||
| Parameters | GCN |
|
| Parameters | GCN | GCN |
|
||||||
| -------------------------- | -------------------------------------------------------------- |
|
| ------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||||
| Resource | Ascend 910; OS Euler2.8 |
|
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
|
||||||
| uploaded Date | 06/09/2020 (month/day/year) |
|
| uploaded Date | 06/09/2020 (month/day/year) | 05/06/2021 (month/day/year) |
|
||||||
| MindSpore Version | 1.0.0 |
|
| MindSpore Version | 1.0.0 | 1.1.0 |
|
||||||
| Dataset | Cora/Citeseer |
|
| Dataset | Cora/Citeseer | Cora/Citeseer |
|
||||||
| Training Parameters | epoch=200 |
|
| Training Parameters | epoch=200 | epoch=200 |
|
||||||
| Optimizer | Adam |
|
| Optimizer | Adam | Adam |
|
||||||
| Loss Function | Softmax Cross Entropy |
|
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||||
| Accuracy | 81.5/70.3 |
|
| Accuracy | 81.5/70.3 | 87.5/76.9 |
|
||||||
| Parameters (B) | 92160/59344 |
|
| Parameters (B) | 92160/59344 | 92160/59344 |
|
||||||
| Scripts | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
|
| Scripts | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
|
||||||
|
|
||||||
## [Description of Random Situation](#contents)
|
## [Description of Random Situation](#contents)
|
||||||
|
|
||||||
|
|
|
@ -157,12 +157,15 @@ sh run_train.sh [DATASET_NAME]
|
||||||
├─scripts
|
├─scripts
|
||||||
| ├─run_infer_310.sh # Ascend310 推理shell脚本
|
| ├─run_infer_310.sh # Ascend310 推理shell脚本
|
||||||
| ├─run_process_data.sh # 生成MindRecord格式的数据集
|
| ├─run_process_data.sh # 生成MindRecord格式的数据集
|
||||||
|
| ├─run_train_gpu.sh # 启动GPU后端的训练
|
||||||
|
| ├─run_eval_gpu.sh # 启动GPU后端的推理
|
||||||
| └─run_train.sh # 启动训练,目前只支持Ascend后端
|
| └─run_train.sh # 启动Ascend后端的训练
|
||||||
|
|
|
|
||||||
├─src
|
├─src
|
||||||
| ├─config.py # 参数配置
|
| ├─config.py # 参数配置
|
||||||
| ├─dataset.py # 数据预处理
|
| ├─dataset.py # 数据预处理
|
||||||
| ├─gcn.py # GCN骨干
|
| ├─gcn.py # GCN骨干
|
||||||
|
| ├─eval_callback.py # 回调函数
|
||||||
| └─metrics.py # 损失和准确率
|
| └─metrics.py # 损失和准确率
|
||||||
|
|
|
|
||||||
├─default_config.py # 配置文件
|
├─default_config.py # 配置文件
|
||||||
|
@ -170,6 +173,7 @@ sh run_train.sh [DATASET_NAME]
|
||||||
├─mindspore_hub_conf.py # mindspore hub 脚本
|
├─mindspore_hub_conf.py # mindspore hub 脚本
|
||||||
├─postprocess.py # 后处理脚本
|
├─postprocess.py # 后处理脚本
|
||||||
├─preprocess.py # 预处理脚本
|
├─preprocess.py # 预处理脚本
|
||||||
|
├─eval.py # 推理网络,进行测试。
|
||||||
└─train.py # 训练网络,每个训练轮次后评估验证结果收敛后,训练停止,然后进行测试。
|
└─train.py # 训练网络,每个训练轮次后进行评估,验证结果收敛后训练停止。
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -184,6 +188,14 @@ sh run_train.sh [DATASET_NAME]
|
||||||
"dropout": 0.5, # 第一图卷积层dropout率
|
"dropout": 0.5, # 第一图卷积层dropout率
|
||||||
"weight_decay": 5e-4, # 第一图卷积层参数的权重衰减
|
"weight_decay": 5e-4, # 第一图卷积层参数的权重衰减
|
||||||
"early_stopping": 10, # 早停容限
|
"early_stopping": 10, # 早停容限
|
||||||
|
"save_ckpt_steps": 549 # 保存ckpt的步数
|
||||||
|
"keep_ckpt_max": 10 # 最多保留的ckpt文件数量
|
||||||
|
"ckpt_dir": './ckpt' # 保存ckpt的文件夹
|
||||||
|
"best_ckpt_dir": './best_ckpt' # 最好ckpt的文件夹
|
||||||
|
"best_ckpt_name": 'best.ckpt' # 最好ckpt的文件名
|
||||||
|
"eval_start_epoch": 100 # 从哪一步开始eval
|
||||||
|
"save_best_ckpt": True # 是否存储最好的ckpt
|
||||||
|
"eval_interval": 1 # eval间隔
|
||||||
```
|
```
|
||||||
|
|
||||||
### 培训、评估、测试过程
|
### 训练、评估、测试过程
|
||||||
|
@ -191,14 +203,22 @@ sh run_train.sh [DATASET_NAME]
|
||||||
#### 用法
|
#### 用法
|
||||||
|
|
||||||
```text
|
```text
|
||||||
# 使用Cora或Citeseer数据集进行训练,DATASET_NAME为Cora或Citeseer
|
# 在Ascend上使用Cora或Citeseer数据集进行训练,DATASET_NAME为Cora或Citeseer
|
||||||
sh run_train.sh [DATASET_NAME]
|
bash run_train.sh [DATASET_NAME]
|
||||||
|
|
||||||
|
# 在GPU上使用Cora或Citeseer数据集进行训练,DATASET_NAME为Cora或Citeseer
|
||||||
|
bash run_train_gpu.sh [DATASET_NAME]
|
||||||
|
|
||||||
|
# 在GPU上对Cora或Citeseer数据集进行测试
|
||||||
|
bash run_eval_gpu.sh [DATASET_NAME] [CKPT]
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 启动
|
#### 启动
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sh run_train.sh cora
|
bash run_train.sh cora
|
||||||
|
bash run_train_gpu.sh cora
|
||||||
|
bash run_eval_gpu.sh cora ckpt/gcn.ckpt
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 结果
|
#### 结果
|
||||||
|
@ -258,18 +278,18 @@ Test set results: accuracy= 0.81300
|
||||||
|
|
||||||
### 性能
|
### 性能
|
||||||
|
|
||||||
| 参数 | GCN |
|
| 参数 | GCN | GCN |
|
||||||
| -------------------------- | -------------------------------------------------------------- |
|
| -------------------------- | -------------------------------------------------------------- | -------------------------- |
|
||||||
| 资源 | Ascend 910;系统 Euler2.8 |
|
| 资源 | Ascend 910;系统 Euler2.8 | NV SMX3 V100-32G |
|
||||||
| 上传日期 | 2020-06-09 |
|
| 上传日期 | 2020-06-09 | 2021-05-06 |
|
||||||
| MindSpore版本 | 0.5.0-beta |
|
| MindSpore版本 | 0.5.0-beta | 1.1.0 |
|
||||||
| 数据集 | Cora/Citeseer |
|
| 数据集 | Cora/Citeseer | Cora/Citeseer |
|
||||||
| 训练参数 | epoch=200 |
|
| 训练参数 | epoch=200 | epoch=200 |
|
||||||
| 优化器 | Adam |
|
| 优化器 | Adam | Adam |
|
||||||
| 损失函数 | Softmax交叉熵 |
|
| 损失函数 | Softmax交叉熵 | Softmax交叉熵 |
|
||||||
| 准确率 | 81.5/70.3 |
|
| 准确率 | 81.5/70.3 | 87.5/76.9 |
|
||||||
| 参数(B) | 92160/59344 |
|
| 参数(B) | 92160/59344 | 92160/59344 |
|
||||||
| 脚本 | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn> |
|
| 脚本 | [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) | [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
|
||||||
|
|
||||||
## 随机情况说明
|
## 随机情况说明
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ eval_nodes_num: 500
|
||||||
test_nodes_num: 1000
|
test_nodes_num: 1000
|
||||||
save_TSNE: False
|
save_TSNE: False
|
||||||
save_ckptpath: "ckpts/"
|
save_ckptpath: "ckpts/"
|
||||||
|
train_with_eval: False
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
@ -29,4 +30,5 @@ data_dir: "Dataset directory"
|
||||||
train_nodes_num: "Nodes numbers for training"
|
train_nodes_num: "Nodes numbers for training"
|
||||||
eval_nodes_num: "Nodes numbers for evaluation"
|
eval_nodes_num: "Nodes numbers for evaluation"
|
||||||
test_nodes_num: "Nodes numbers for test"
|
test_nodes_num: "Nodes numbers for test"
|
||||||
save_TSNE: "Whether to save t-SNE graph"
|
save_TSNE: "Whether to save t-SNE graph"
|
||||||
|
train_with_eval: "Whether to train with evaluation"
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
GCN eval script.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.train.serialization import load_checkpoint
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import Model, context
|
||||||
|
|
||||||
|
from src.config import ConfigGCN
|
||||||
|
from src.dataset import get_adj_features_labels, get_mask
|
||||||
|
from src.metrics import Loss
|
||||||
|
from src.gcn import GCN
|
||||||
|
|
||||||
|
def run_gcn_infer():
    """
    Evaluate a saved GCN checkpoint and print the resulting metrics.

    Command-line options:
        --data_dir: dataset directory (default './data/cora/cora_mr').
        --test_nodes_num: number of trailing nodes used as the test split (default 1000).
        --model_ckpt: path of the checkpoint to load (required).
        --device_target: 'Ascend' or 'GPU' (default 'Ascend').
    """
    arg_parser = argparse.ArgumentParser(description='GCN')
    arg_parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
    arg_parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
    arg_parser.add_argument("--model_ckpt", type=str, required=True,
                            help="existed checkpoint address.")
    arg_parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                            help='device where the code will be implemented (default: Ascend)')
    args_opt = arg_parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target, save_graphs=False)
    config = ConfigGCN()

    adj_np, feature_np, label_onehot, _ = get_adj_features_labels(args_opt.data_dir)

    # The whole graph is fed as a single sample, so a leading axis of size 1 is added.
    dataset = ds.NumpySlicesDataset(data={
        "feature": np.expand_dims(feature_np, axis=0),
        "label": np.expand_dims(label_onehot, axis=0),
    })

    adj = Tensor(adj_np, dtype=mstype.float32)
    feature = Tensor(feature_np)

    nodes_num = label_onehot.shape[0]
    class_num = label_onehot.shape[1]
    input_dim = feature.shape[1]

    # The test split is the last `test_nodes_num` nodes.
    test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)

    gcn_net_test = GCN(config, input_dim, class_num, adj)
    load_checkpoint(args_opt.model_ckpt, net=gcn_net_test)

    # NOTE(review): test_mask is only passed to the loss; nn.Accuracy receives no mask here.
    # Confirm the reported accuracy is restricted to the test nodes as intended.
    criterion = Loss(test_mask, config.weight_decay, gcn_net_test.trainable_params()[0])
    model = Model(gcn_net_test, loss_fn=criterion, metrics={'Acc': nn.Accuracy()})

    print(model.eval(dataset, dataset_sink_mode=True))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_gcn_infer()
|
|
@ -32,7 +32,7 @@ parser.add_argument("--device_target", type=str, default="Ascend",
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
if args.device_target == "Ascend":
|
if args.device_target == "Ascend" or args.device_target == "GPU":
|
||||||
context.set_context(device_id=args.device_id)
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Launch GCN evaluation on a single GPU.
# Usage: bash run_eval_gpu.sh [DATASET_NAME] [CKPT]
#   DATASET_NAME  cora or citeseer
#   CKPT          checkpoint file, absolute or relative to the current directory

# Print the absolute form of the path given as $1 (absolute paths are echoed unchanged).
get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m "$PWD/$1")"
  fi
}

# Both arguments are mandatory; fail early instead of running with empty values.
if [ $# != 2 ]
then
  echo "Usage: bash run_eval_gpu.sh [DATASET_NAME] [CKPT]"
  exit 1
fi

DATASET_NAME=$1
echo "$DATASET_NAME"
# Resolve the checkpoint path before changing directory below.
MODEL_CKPT=$(get_real_path "$2")
echo "$MODEL_CKPT"

# Start from a clean working copy in ./eval.
if [ -d "eval" ];
then
  rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.yaml ./eval
cp *.sh ./eval
cp -r ../src ./eval
cp -r ../model_utils ./eval
cd ./eval || exit
echo "start eval on standalone GPU"

# Evaluation runs in the background; its output goes to ./eval/log.
case "$DATASET_NAME" in
  cora|citeseer)
    python eval.py --data_dir="../data_mr/$DATASET_NAME" --device_target="GPU" --model_ckpt "$MODEL_CKPT" &> log &
    ;;
  *)
    echo "Unsupported DATASET_NAME '$DATASET_NAME': expected cora or citeseer" >&2
    ;;
esac
cd ..
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
if [ $# != 2 ]
|
if [ $# != 2 ]
|
||||||
then
|
then
|
||||||
echo "Usage: sh run_train.sh [SRC_PATH] [DATASET_NAME]"
|
echo "Usage: sh run_process_data.sh [SRC_PATH] [DATASET_NAME]"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ MINDRECORD_PATH=`pwd`/data_mr
|
||||||
rm -f $MINDRECORD_PATH/$DATASET_NAME
|
rm -f $MINDRECORD_PATH/$DATASET_NAME
|
||||||
rm -f $MINDRECORD_PATH/$DATASET_NAME.db
|
rm -f $MINDRECORD_PATH/$DATASET_NAME.db
|
||||||
|
|
||||||
cd ../../utils/graph_to_mindrecord || exit
|
cd ../../../../utils/graph_to_mindrecord || exit
|
||||||
|
|
||||||
python writer.py --mindrecord_script $DATASET_NAME \
|
python writer.py --mindrecord_script $DATASET_NAME \
|
||||||
--mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \
|
--mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Launch GCN training on a single GPU.
# Usage: bash run_train_gpu.sh [DATASET_NAME]
#   DATASET_NAME  cora or citeseer
if [ $# != 1 ]
then
  echo "Usage: bash run_train_gpu.sh [DATASET_NAME]"
  exit 1
fi

DATASET_NAME=$1
echo "$DATASET_NAME"

# Start from a clean working copy in ./train.
if [ -d "train" ];
then
  rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp *.sh ./train
cp -r ../src ./train
cp -r ../model_utils ./train
cd ./train || exit
env > env.log
echo "start training for standalone GPU"

# Each dataset uses its own number of training nodes.
# Training runs in the background; its output goes to ./train/log.
case "$DATASET_NAME" in
  cora)
    python train.py --data_dir="../data_mr/$DATASET_NAME" --train_nodes_num=140 --device_target="GPU" &> log &
    ;;
  citeseer)
    python train.py --data_dir="../data_mr/$DATASET_NAME" --train_nodes_num=120 --device_target="GPU" &> log &
    ;;
  *)
    echo "Unsupported DATASET_NAME '$DATASET_NAME': expected cora or citeseer" >&2
    ;;
esac
cd ..
|
||||||
|
|
|
@ -18,9 +18,20 @@ network config setting, will be used in train.py
|
||||||
|
|
||||||
|
|
||||||
class ConfigGCN():
|
class ConfigGCN():
|
||||||
|
"""
|
||||||
|
Configuration of GCN
|
||||||
|
"""
|
||||||
learning_rate = 0.01
|
learning_rate = 0.01
|
||||||
epochs = 200
|
epochs = 200
|
||||||
hidden1 = 16
|
hidden1 = 16
|
||||||
dropout = 0.5
|
dropout = 0.5
|
||||||
weight_decay = 5e-4
|
weight_decay = 5e-4
|
||||||
early_stopping = 50
|
early_stopping = 50
|
||||||
|
save_ckpt_steps = 549
|
||||||
|
keep_ckpt_max = 10
|
||||||
|
ckpt_dir = './ckpt'
|
||||||
|
best_ckpt_dir = './best_ckpt'
|
||||||
|
best_ckpt_name = 'best.ckpt'
|
||||||
|
eval_start_epoch = 100
|
||||||
|
save_best_ckpt = True
|
||||||
|
eval_interval = 1
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Evaluation callback when training"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
from mindspore import save_checkpoint
|
||||||
|
from mindspore import log as logger
|
||||||
|
from mindspore.train.callback import Callback
|
||||||
|
|
||||||
|
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Runs `eval_function` at the end of selected epochs, tracks the best result and
    optionally saves the network as the best checkpoint.

    Args:
        eval_function (function): evaluation function, called as `eval_function(eval_param_dict)`.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory for the best checkpoint, created if missing, default is `./`.
        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
            (The parameter name's spelling is kept for backward compatibility.)
        metrics_name (str): evaluation metrics name, default is `acc`.
        save_TSNE (bool): stored on the instance as `save_TSNE`; not read by this class,
            default is False.

    Raises:
        ValueError: if `interval` is less than 1.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc", save_TSNE=False):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        os.makedirs(ckpt_directory, exist_ok=True)
        # Attribute name (including its spelling) is kept so existing readers keep working.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name
        # Previously read from an undefined name, which raised NameError on construction.
        self.save_TSNE = save_TSNE

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from the directory; failures are logged, not raised."""
        try:
            # Make the file writable first so that it can be deleted.
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end: evaluate and, on a new best result, update the best checkpoint."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch < self.eval_start_epoch:
            return
        if (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return

        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        # Ties count as an update, so the most recent best epoch is recorded.
        if res < self.best_res:
            return

        self.best_res = res
        self.best_epoch = cur_epoch
        print("update best result: {}".format(res), flush=True)
        if self.save_best_ckpt:
            if os.path.exists(self.bast_ckpt_path):
                self.remove_ckpoint_file(self.bast_ckpt_path)
            save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
            print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best result and the epoch it was reached."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
|
||||||
|
|
|
@ -92,12 +92,12 @@ class GCN(nn.Cell):
|
||||||
output_dim (int): The number of output channels, equal to classes num.
|
output_dim (int): The number of output channels, equal to classes num.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, input_dim, output_dim):
|
def __init__(self, config, input_dim, output_dim, adj):
|
||||||
super(GCN, self).__init__()
|
super(GCN, self).__init__()
|
||||||
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
|
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
|
||||||
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
|
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
|
||||||
|
self.adj = adj
|
||||||
def construct(self, adj, feature):
|
def construct(self, feature):
|
||||||
output0 = self.layer0(adj, feature)
|
output0 = self.layer0(self.adj, feature)
|
||||||
output1 = self.layer1(adj, output0)
|
output1 = self.layer1(self.adj, output0)
|
||||||
return output1
|
return output1
|
||||||
|
|
|
@ -17,16 +17,13 @@ from mindspore import nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.common.parameter import ParameterTuple
|
from mindspore.nn.metrics import Metric
|
||||||
from mindspore.ops import composite as C
|
|
||||||
from mindspore.ops import functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class Loss(nn.Cell):
|
class Loss(nn.Cell):
|
||||||
"""Softmax cross-entropy loss with masking."""
|
"""Softmax cross-entropy loss with masking."""
|
||||||
def __init__(self, label, mask, weight_decay, param):
|
def __init__(self, mask, weight_decay, param):
|
||||||
super(Loss, self).__init__(auto_prefix=False)
|
super(Loss, self).__init__(auto_prefix=False)
|
||||||
self.label = Tensor(label)
|
|
||||||
self.mask = Tensor(mask)
|
self.mask = Tensor(mask)
|
||||||
self.loss = P.SoftmaxCrossEntropyWithLogits()
|
self.loss = P.SoftmaxCrossEntropyWithLogits()
|
||||||
self.one = Tensor(1.0, mstype.float32)
|
self.one = Tensor(1.0, mstype.float32)
|
||||||
|
@ -38,12 +35,12 @@ class Loss(nn.Cell):
|
||||||
self.weight_decay = weight_decay
|
self.weight_decay = weight_decay
|
||||||
self.param = param
|
self.param = param
|
||||||
|
|
||||||
def construct(self, preds):
|
def construct(self, preds, label):
|
||||||
"""Calculate loss"""
|
"""Calculate loss"""
|
||||||
param = self.l2_loss(self.param)
|
param = self.l2_loss(self.param)
|
||||||
loss = self.weight_decay * param
|
loss = self.weight_decay * param
|
||||||
preds = self.cast(preds, mstype.float32)
|
preds = self.cast(preds, mstype.float32)
|
||||||
loss = loss + self.loss(preds, self.label)[0]
|
loss = loss + self.loss(preds, label)[0]
|
||||||
mask = self.cast(self.mask, mstype.float32)
|
mask = self.cast(self.mask, mstype.float32)
|
||||||
mask_reduce = self.mean(mask)
|
mask_reduce = self.mean(mask)
|
||||||
mask = mask / mask_reduce
|
mask = mask / mask_reduce
|
||||||
|
@ -51,138 +48,38 @@ class Loss(nn.Cell):
|
||||||
loss = self.mean(loss)
|
loss = self.mean(loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
class GCNAccuracy(Metric):
|
||||||
class Accuracy(nn.Cell):
|
"""
|
||||||
"""Accuracy with masking."""
|
Accuracy for GCN
|
||||||
def __init__(self, label, mask):
|
"""
|
||||||
super(Accuracy, self).__init__(auto_prefix=False)
|
def __init__(self, mask):
|
||||||
self.label = Tensor(label)
|
super(GCNAccuracy, self).__init__()
|
||||||
self.mask = Tensor(mask)
|
self.mask = Tensor(mask)
|
||||||
self.equal = P.Equal()
|
self.equal = P.Equal()
|
||||||
self.argmax = P.Argmax()
|
self.argmax = P.Argmax()
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
self.mean = P.ReduceMean()
|
self.mean = P.ReduceMean()
|
||||||
|
self.accuracy_all = 0
|
||||||
|
|
||||||
def construct(self, preds):
|
def clear(self):
|
||||||
preds = self.cast(preds, mstype.float32)
|
self.accuracy_all = 0
|
||||||
correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
|
|
||||||
accuracy_all = self.cast(correct_prediction, mstype.float32)
|
def update(self, *inputs):
|
||||||
|
preds = self.cast(inputs[1], mstype.float32)
|
||||||
|
correct_prediction = self.equal(self.argmax(preds), self.argmax(inputs[0]))
|
||||||
|
self.accuracy_all = self.cast(correct_prediction, mstype.float32)
|
||||||
mask = self.cast(self.mask, mstype.float32)
|
mask = self.cast(self.mask, mstype.float32)
|
||||||
mask_reduce = self.mean(mask)
|
mask_reduce = self.mean(mask)
|
||||||
mask = mask / mask_reduce
|
mask = mask / mask_reduce
|
||||||
accuracy_all *= mask
|
self.accuracy_all *= mask
|
||||||
return self.mean(accuracy_all)
|
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
return float(self.mean(self.accuracy_all).asnumpy())
|
||||||
|
|
||||||
class LossAccuracyWrapper(nn.Cell):
|
def apply_eval(eval_param_dict):
|
||||||
"""
|
"""run Evaluation"""
|
||||||
Wraps the GCN model with loss and accuracy cell.
|
model = eval_param_dict["model"]
|
||||||
|
dataset = eval_param_dict["dataset"]
|
||||||
Args:
|
metrics_name = eval_param_dict["metrics_name"]
|
||||||
network (Cell): GCN network.
|
eval_score = model.eval(dataset, dataset_sink_mode=False)[metrics_name]
|
||||||
label (numpy.ndarray): Dataset labels.
|
return eval_score
|
||||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
|
||||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, label, mask, weight_decay):
|
|
||||||
super(LossAccuracyWrapper, self).__init__(auto_prefix=False)
|
|
||||||
self.network = network
|
|
||||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
|
||||||
self.accuracy = Accuracy(label, mask)
|
|
||||||
|
|
||||||
def construct(self, adj, feature):
|
|
||||||
preds = self.network(adj, feature)
|
|
||||||
loss = self.loss(preds)
|
|
||||||
accuracy = self.accuracy(preds)
|
|
||||||
return loss, accuracy
|
|
||||||
|
|
||||||
|
|
||||||
class LossWrapper(nn.Cell):
|
|
||||||
"""
|
|
||||||
Wraps the GCN model with loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): GCN network.
|
|
||||||
label (numpy.ndarray): Dataset labels.
|
|
||||||
mask (numpy.ndarray): Mask for training.
|
|
||||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, label, mask, weight_decay):
|
|
||||||
super(LossWrapper, self).__init__(auto_prefix=False)
|
|
||||||
self.network = network
|
|
||||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
|
||||||
|
|
||||||
def construct(self, adj, feature):
|
|
||||||
preds = self.network(adj, feature)
|
|
||||||
loss = self.loss(preds)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class TrainOneStepCell(nn.Cell):
|
|
||||||
r"""
|
|
||||||
Network training package class.
|
|
||||||
|
|
||||||
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
|
||||||
Backward graph will be created in the construct function to do parameter updating. Different
|
|
||||||
parallel modes are available to run the training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): The training network.
|
|
||||||
optimizer (Cell): Optimizer for updating the weights.
|
|
||||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, a scalar Tensor with shape :math:`()`.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> net = Net()
|
|
||||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
||||||
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
||||||
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
|
||||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, optimizer, sens=1.0):
|
|
||||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
|
||||||
self.network = network
|
|
||||||
self.network.set_grad()
|
|
||||||
self.network.add_flags(defer_inline=True)
|
|
||||||
self.weights = ParameterTuple(network.trainable_params())
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
|
||||||
self.sens = sens
|
|
||||||
|
|
||||||
def construct(self, adj, feature):
|
|
||||||
weights = self.weights
|
|
||||||
loss = self.network(adj, feature)
|
|
||||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
|
||||||
grads = self.grad(self.network, weights)(adj, feature, sens)
|
|
||||||
return F.depend(loss, self.optimizer(grads))
|
|
||||||
|
|
||||||
|
|
||||||
class TrainNetWrapper(nn.Cell):
|
|
||||||
"""
|
|
||||||
Wraps the GCN model with optimizer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): GCN network.
|
|
||||||
label (numpy.ndarray): Dataset labels.
|
|
||||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
|
||||||
config (ConfigGCN): Configuration for GCN.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, label, mask, config):
|
|
||||||
super(TrainNetWrapper, self).__init__(auto_prefix=False)
|
|
||||||
self.network = network
|
|
||||||
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
|
||||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
|
||||||
learning_rate=config.learning_rate)
|
|
||||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
|
||||||
self.accuracy = Accuracy(label, mask)
|
|
||||||
|
|
||||||
def construct(self, adj, feature):
|
|
||||||
loss = self.loss_train_net(adj, feature)
|
|
||||||
accuracy = self.accuracy(self.network(adj, feature))
|
|
||||||
return loss, accuracy
|
|
||||||
|
|
|
@ -18,36 +18,25 @@ GCN training script.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import mindspore.nn as nn
|
||||||
import numpy as np
|
import mindspore.dataset as ds
|
||||||
from matplotlib import pyplot as plt
|
import mindspore.common.dtype as mstype
|
||||||
from matplotlib import animation
|
from mindspore import Model, context
|
||||||
from sklearn import manifold
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
|
||||||
from mindspore import context
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
import numpy as np
|
||||||
|
|
||||||
from src.gcn import GCN
|
from src.gcn import GCN
|
||||||
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
from src.metrics import Loss, GCNAccuracy, apply_eval
|
||||||
from src.config import ConfigGCN
|
from src.config import ConfigGCN
|
||||||
from src.dataset import get_adj_features_labels, get_mask
|
from src.dataset import get_adj_features_labels, get_mask
|
||||||
|
from src.eval_callback import EvalCallBack
|
||||||
|
|
||||||
from model_utils.config import config as default_args
|
from model_utils.config import config as default_args
|
||||||
from model_utils.moxing_adapter import moxing_wrapper
|
from model_utils.moxing_adapter import moxing_wrapper
|
||||||
from model_utils.device_adapter import get_device_id, get_device_num
|
from model_utils.device_adapter import get_device_id, get_device_num
|
||||||
|
|
||||||
|
|
||||||
def t_SNE(out_feature, dim):
|
|
||||||
t_sne = manifold.TSNE(n_components=dim, init='pca', random_state=0)
|
|
||||||
return t_sne.fit_transform(out_feature)
|
|
||||||
|
|
||||||
|
|
||||||
def update_graph(i, data, scat, plot):
|
|
||||||
scat.set_offsets(data[i])
|
|
||||||
plt.title('t-SNE visualization of Epoch:{0}'.format(i))
|
|
||||||
return scat, plot
|
|
||||||
|
|
||||||
|
|
||||||
def modelarts_pre_process():
|
def modelarts_pre_process():
|
||||||
'''modelarts pre process function.'''
|
'''modelarts pre process function.'''
|
||||||
def unzip(zip_file, save_dir):
|
def unzip(zip_file, save_dir):
|
||||||
|
@ -112,85 +101,41 @@ def modelarts_pre_process():
|
||||||
def run_train():
|
def run_train():
|
||||||
"""Train model."""
|
"""Train model."""
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
device_target="Ascend", save_graphs=False)
|
device_target=default_args.device_target, save_graphs=False)
|
||||||
config = ConfigGCN()
|
config = ConfigGCN()
|
||||||
adj, feature, label_onehot, label = get_adj_features_labels(default_args.data_dir)
|
if not os.path.exists(config.ckpt_dir):
|
||||||
|
os.mkdir(config.ckpt_dir)
|
||||||
|
adj, feature, label_onehot, _ = get_adj_features_labels(default_args.data_dir)
|
||||||
|
feature_d = np.expand_dims(feature, axis=0)
|
||||||
|
label_onehot_d = np.expand_dims(label_onehot, axis=0)
|
||||||
|
data = {"feature": feature_d, "label": label_onehot_d}
|
||||||
|
dataset = ds.NumpySlicesDataset(data=data)
|
||||||
nodes_num = label_onehot.shape[0]
|
nodes_num = label_onehot.shape[0]
|
||||||
train_mask = get_mask(nodes_num, 0, default_args.train_nodes_num)
|
|
||||||
eval_mask = get_mask(nodes_num, default_args.train_nodes_num,
|
eval_mask = get_mask(nodes_num, default_args.train_nodes_num,
|
||||||
default_args.train_nodes_num + default_args.eval_nodes_num)
|
default_args.train_nodes_num + default_args.eval_nodes_num)
|
||||||
test_mask = get_mask(nodes_num, nodes_num - default_args.test_nodes_num, nodes_num)
|
|
||||||
|
|
||||||
class_num = label_onehot.shape[1]
|
class_num = label_onehot.shape[1]
|
||||||
input_dim = feature.shape[1]
|
input_dim = feature.shape[1]
|
||||||
gcn_net = GCN(config, input_dim, class_num)
|
adj = Tensor(adj, dtype=mstype.float32)
|
||||||
gcn_net.add_flags_recursive(fp16=True)
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
|
||||||
|
keep_checkpoint_max=config.keep_ckpt_max)
|
||||||
adj = Tensor(adj)
|
ckpoint_cb = ModelCheckpoint(prefix='ckpt_gcn',
|
||||||
feature = Tensor(feature)
|
directory=config.ckpt_dir,
|
||||||
|
config=ckpt_config)
|
||||||
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
|
gcn_net = GCN(config, input_dim, class_num, adj)
|
||||||
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
cb = [TimeMonitor(), LossMonitor(), ckpoint_cb]
|
||||||
|
opt = nn.Adam(gcn_net.trainable_params(), learning_rate=config.learning_rate)
|
||||||
loss_list = []
|
criterion = Loss(eval_mask, config.weight_decay, gcn_net.trainable_params()[0])
|
||||||
|
model = Model(gcn_net, loss_fn=criterion, optimizer=opt, amp_level="O3")
|
||||||
if default_args.save_TSNE:
|
if default_args.train_with_eval:
|
||||||
out_feature = gcn_net()
|
GCN_metric = GCNAccuracy(eval_mask)
|
||||||
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
eval_model = Model(gcn_net, loss_fn=criterion, metrics={'GCNAccuracy': GCN_metric})
|
||||||
graph_data = []
|
eval_param_dict = {"model": eval_model, "dataset": dataset, "metrics_name": "GCNAccuracy"}
|
||||||
graph_data.append(tsne_result)
|
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
|
||||||
fig = plt.figure()
|
eval_start_epoch=default_args.eval_start_epoch, save_best_ckpt=config.save_best_ckpt,
|
||||||
scat = plt.scatter(tsne_result[:, 0], tsne_result[:, 1], s=2, c=label, cmap='rainbow')
|
ckpt_directory=config.best_ckpt_dir, besk_ckpt_name=config.best_ckpt_name,
|
||||||
plt.title('t-SNE visualization of Epoch:0', fontsize='large', fontweight='bold', verticalalignment='center')
|
metrics_name="GCNAccuracy")
|
||||||
|
cb.append(eval_cb)
|
||||||
for epoch in range(config.epochs):
|
model.train(config.epochs, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||||
t = time.time()
|
|
||||||
|
|
||||||
train_net.set_train()
|
|
||||||
train_result = train_net(adj, feature)
|
|
||||||
train_loss = train_result[0].asnumpy()
|
|
||||||
train_accuracy = train_result[1].asnumpy()
|
|
||||||
|
|
||||||
eval_net.set_train(False)
|
|
||||||
eval_result = eval_net(adj, feature)
|
|
||||||
eval_loss = eval_result[0].asnumpy()
|
|
||||||
eval_accuracy = eval_result[1].asnumpy()
|
|
||||||
|
|
||||||
loss_list.append(eval_loss)
|
|
||||||
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
|
|
||||||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
|
||||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
|
||||||
|
|
||||||
if default_args.save_TSNE:
|
|
||||||
out_feature = gcn_net()
|
|
||||||
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
|
||||||
graph_data.append(tsne_result)
|
|
||||||
|
|
||||||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
|
||||||
print("Early stopping...")
|
|
||||||
break
|
|
||||||
if not os.path.isdir(default_args.save_ckptpath):
|
|
||||||
os.makedirs(default_args.save_ckptpath)
|
|
||||||
ckpt_path = os.path.join(default_args.save_ckptpath, "gcn.ckpt")
|
|
||||||
save_checkpoint(gcn_net, ckpt_path)
|
|
||||||
gcn_net_test = GCN(config, input_dim, class_num)
|
|
||||||
load_checkpoint(ckpt_path, net=gcn_net_test)
|
|
||||||
gcn_net_test.add_flags_recursive(fp16=True)
|
|
||||||
|
|
||||||
test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
|
|
||||||
t_test = time.time()
|
|
||||||
test_net.set_train(False)
|
|
||||||
test_result = test_net(adj, feature)
|
|
||||||
test_loss = test_result[0].asnumpy()
|
|
||||||
test_accuracy = test_result[1].asnumpy()
|
|
||||||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
|
||||||
"accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
|
|
||||||
|
|
||||||
if default_args.save_TSNE:
|
|
||||||
ani = animation.FuncAnimation(fig, update_graph, frames=range(config.epochs + 1), fargs=(graph_data, scat, plt))
|
|
||||||
ani.save('t-SNE_visualization.gif', writer='imagemagick')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_train()
|
run_train()
|
||||||
|
|
|
@ -1,93 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
import time
|
|
||||||
import pytest
|
|
||||||
import numpy as np
|
|
||||||
from mindspore import context
|
|
||||||
from mindspore import Tensor
|
|
||||||
from model_zoo.official.gnn.gcn.src.gcn import GCN
|
|
||||||
from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
|
||||||
from model_zoo.official.gnn.gcn.src.config import ConfigGCN
|
|
||||||
from model_zoo.official.gnn.gcn.src.dataset import get_adj_features_labels, get_mask
|
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
|
|
||||||
TRAIN_NODE_NUM = 140
|
|
||||||
EVAL_NODE_NUM = 500
|
|
||||||
TEST_NODE_NUM = 1000
|
|
||||||
SEED = 20
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_arm_ascend_training
|
|
||||||
@pytest.mark.platform_x86_ascend_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
def test_gcn():
|
|
||||||
print("test_gcn begin")
|
|
||||||
np.random.seed(SEED)
|
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
|
||||||
device_target="Ascend", save_graphs=False)
|
|
||||||
config = ConfigGCN()
|
|
||||||
config.dropout = 0.0
|
|
||||||
adj, feature, label_onehot, _ = get_adj_features_labels(DATA_DIR)
|
|
||||||
|
|
||||||
nodes_num = label_onehot.shape[0]
|
|
||||||
train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
|
|
||||||
eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
|
|
||||||
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
|
|
||||||
|
|
||||||
class_num = label_onehot.shape[1]
|
|
||||||
input_dim = feature.shape[1]
|
|
||||||
gcn_net = GCN(config, input_dim, class_num)
|
|
||||||
gcn_net.add_flags_recursive(fp16=True)
|
|
||||||
|
|
||||||
adj = Tensor(adj)
|
|
||||||
feature = Tensor(feature)
|
|
||||||
|
|
||||||
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
|
|
||||||
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
|
|
||||||
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
|
||||||
|
|
||||||
loss_list = []
|
|
||||||
for epoch in range(config.epochs):
|
|
||||||
t = time.time()
|
|
||||||
|
|
||||||
train_net.set_train()
|
|
||||||
train_result = train_net(adj, feature)
|
|
||||||
train_loss = train_result[0].asnumpy()
|
|
||||||
train_accuracy = train_result[1].asnumpy()
|
|
||||||
|
|
||||||
eval_net.set_train(False)
|
|
||||||
eval_result = eval_net(adj, feature)
|
|
||||||
eval_loss = eval_result[0].asnumpy()
|
|
||||||
eval_accuracy = eval_result[1].asnumpy()
|
|
||||||
|
|
||||||
loss_list.append(eval_loss)
|
|
||||||
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
|
|
||||||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
|
||||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
|
||||||
|
|
||||||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
|
||||||
print("Early stopping...")
|
|
||||||
break
|
|
||||||
|
|
||||||
test_net.set_train(False)
|
|
||||||
test_result = test_net(adj, feature)
|
|
||||||
test_loss = test_result[0].asnumpy()
|
|
||||||
test_accuracy = test_result[1].asnumpy()
|
|
||||||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
|
||||||
"accuracy=", "{:.5f}".format(test_accuracy))
|
|
||||||
assert test_accuracy > 0.812
|
|
Loading…
Reference in New Issue