!16635 Support GCN on GPU

Merge pull request !16635 from yuruilee/master
This commit is contained in:
i-robot 2021-06-15 22:05:33 +08:00 committed by Gitee
commit 3a9d3f2580
14 changed files with 428 additions and 360 deletions

View File

@ -75,16 +75,16 @@ Note that you can run the scripts based on the dataset mentioned in original pap
```buildoutcfg
cd ./scripts
# SRC_PATH is the dataset file path you downloaded, DATASET_NAME is cora or citeseer
sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
bash run_process_data.sh [SRC_PATH] [DATASET_NAME]
```
### Launch
```bash
#Generate dataset in mindrecord format for cora
sh run_process_data.sh ./data cora
bash run_process_data.sh ./data cora
#Generate dataset in mindrecord format for citeseer
sh run_process_data.sh ./data citeseer
bash run_process_data.sh ./data citeseer
```
- Running on local with Ascend
@ -149,12 +149,15 @@ sh run_train.sh [DATASET_NAME]
├─scripts
| ├─run_infer_310.sh # shell script for infer on Ascend 310
| ├─run_process_data.sh # Generate dataset in mindrecord format
| └─run_train.sh # Launch training, now only Ascend backend is supported.
| ├─run_train_gpu.sh # Launch GPU training.
| ├─run_eval_gpu.sh # Launch GPU inference.
| └─run_train.sh # Launch Ascend training.
|
├─src
| ├─config.py # Parameter configuration
| ├─dataset.py # Data preprocessin
| ├─gcn.py # GCN backbone
| ├─eval_callback.py # Callback function
| └─metrics.py # Loss and accuracy
|
├─default_config.py # Configurations
@ -162,7 +165,8 @@ sh run_train.sh [DATASET_NAME]
├─mindspore_hub_conf.py # mindspore_hub_conf scripts
├─postprocess.py # postprocess script
├─preprocess.py # preprocess scripts
└─train.py # Train net, evaluation is performed after every training epoch. After the verification result converges, the training stops, then testing is performed.
|─eval.py # Evaluation net, testing is performed.
└─train.py # Train net, evaluation is performed after every training epoch. After the verification result converges, the training stops.
```
## [Script Parameters](#contents)
@ -176,6 +180,14 @@ Parameters for training can be set in config.py.
"dropout": 0.5, # Dropout ratio for the first graph convolution layer
"weight_decay": 5e-4, # Weight decay for the parameter of the first graph convolution layer
"early_stopping": 10, # Tolerance for early stopping
"save_ckpt_steps": 549 # Step to save checkpoint
"keep_ckpt_max": 10 # Maximum step to save checkpoint
"ckpt_dir": './ckpt' # The folder to save checkpoint
"best_ckpt_dir": './best_ckpt' # The folder to save the best checkpoint
"best_ckpt_name": 'best.ckpt' # The file name of the best checkpoint
"eval_start_epoch": 100 # Start step for eval
"save_best_ckpt": True # Save the best checkpoint or not
"eval_interval": 1 # The interval of eval
```
### [Training, Evaluation, Test Process](#contents)
@ -183,14 +195,22 @@ Parameters for training can be set in config.py.
#### Usage
```bash
# run train with cora or citeseer dataset, DATASET_NAME is cora or citeseer
sh run_train.sh [DATASET_NAME]
# run train with cora or citeseer dataset on Ascend, DATASET_NAME is cora or citeseer
bash run_train.sh [DATASET_NAME]
# run train with cora or citeseer dataset on GPU, DATASET_NAME is cora or citeseer
bash run_train_gpu.sh [DATASET_NAME]
# run inference with cora or citeseer dataset on GPU, DATASET_NAME is cora or citeseer
bash run_eval_gpu.sh [DATASET_NAME] [CKPT]
```
#### Launch
```bash
sh run_train.sh cora
bash run_train.sh cora
bash run_train_gpu.sh cora
bash run_eval_gpu.sh cora ckpt/gcn.ckpt
```
#### Result
@ -250,18 +270,18 @@ Test set results: accuracy= 0.81300
### [Performance](#contents)
| Parameters | GCN |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910; OS Euler2.8 |
| uploaded Date | 06/09/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | Cora/Citeseer |
| Training Parameters | epoch=200 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| Accuracy | 81.5/70.3 |
| Parameters (B) | 92160/59344 |
| Scripts | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
| Parameters | GCN | GCN |
| ------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
| uploaded Date | 06/09/2020 (month/day/year) | 05/06/2021 (month/day/year) |
| MindSpore Version | 1.0.0 | 1.1.0 |
| Dataset | Cora/Citeseer | Cora/Citeseer |
| Training Parameters | epoch=200 | epoch=200 |
| Optimizer | Adam | Adam |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| Accuracy | 81.5/70.3 | 87.5/76.9 |
| Parameters (B) | 92160/59344 | 92160/59344 |
| Scripts | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) | [GCN Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
## [Description of Random Situation](#contents)

View File

@ -157,12 +157,15 @@ sh run_train.sh [DATASET_NAME]
├─scripts
| ├─run_infer_310.sh # Ascend310 推理shell脚本
| ├─run_process_data.sh # 生成MindRecord格式的数据集
| ├─run_train_gpu.sh # 启动GPU后端的训练
| ├─run_eval_gpu.sh # 启动GPU后端的推理
| └─run_train.sh # 启动训练目前只支持Ascend后端
|
├─src
| ├─config.py # 参数配置
| ├─dataset.py # 数据预处理
| ├─gcn.py # GCN骨干
| ├─eval_callback.py # 回调函数
| └─metrics.py # 损失和准确率
|
├─default_config.py # 配置文件
@ -170,6 +173,7 @@ sh run_train.sh [DATASET_NAME]
├─mindspore_hub_conf.py # mindspore hub 脚本
├─postprocess.py # 后处理脚本
├─preprocess.py # 预处理脚本
|─eval.py # 推理网络,进行测试。
└─train.py # 训练网络,每个训练轮次后评估验证结果收敛后,训练停止,然后进行测试。
```
@ -184,6 +188,14 @@ sh run_train.sh [DATASET_NAME]
"dropout": 0.5, # 第一图卷积层dropout率
"weight_decay": 5e-4, # 第一图卷积层参数的权重衰减
"early_stopping": 10, # 早停容限
"save_ckpt_steps": 549 # 保存ckpt的步数
"keep_ckpt_max": 10 # 保存ckpt的最大步数
"ckpt_dir": './ckpt' # 保存ckpt的文件夹
"best_ckpt_dir": './best_ckpt # 最好ckpt的文件夹
"best_ckpt_name": 'best.ckpt' # 最好ckpt的文件名
"eval_start_epoch": 100 # 从哪一步开始eval
"save_best_ckpt": True # 是否存储最好的ckpt
"eval_interval": 1 # eval间隔
```
### 培训、评估、测试过程
@ -191,14 +203,22 @@ sh run_train.sh [DATASET_NAME]
#### 用法
```text
# 使用Cora或Citeseer数据集进行训练DATASET_NAME为Cora或Citeseer
sh run_train.sh [DATASET_NAME]
# 在Ascend上使用Cora或Citeseer数据集进行训练DATASET_NAME为Cora或Citeseer
bash run_train.sh [DATASET_NAME]
# 在GPU上使用Cora或Citeseer数据集进行训练DATASET_NAME为Cora或Citeseer
bash run_train_gpu.sh [DATASET_NAME]
# 在GPU上对Cora或Citeseer数据集进行测试
bash run_eval_gpu.sh [DATASET_NAME] [CKPT]
```
#### 启动
```bash
sh run_train.sh cora
bash run_train.sh cora
bash run_train_gpu.sh cora
bash run_eval_gpu.sh cora ckpt/gcn.ckpt
```
#### 结果
@ -258,18 +278,18 @@ Test set results: accuracy= 0.81300
### 性能
| 参数 | GCN |
| -------------------------- | -------------------------------------------------------------- |
| 资源 | Ascend 910系统 Euler2.8 |
| 上传日期 | 2020-06-09 |
| MindSpore版本 | 0.5.0-beta |
| 数据集 | Cora/Citeseer |
| 训练参数 | epoch=200 |
| 优化器 | Adam |
| 损失函数 | Softmax交叉熵 |
| 准确率 | 81.5/70.3 |
| 参数(B) | 92160/59344 |
| 脚本 | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn> |
| 参数 | GCN | GCN |
| -------------------------- | -------------------------------------------------------------- | -------------------------- |
| 资源 | Ascend 910系统 Euler2.8 | NV SMX3 V100-32G |
| 上传日期 | 2020-06-09 | 2021-05-06 |
| MindSpore版本 | 0.5.0-beta | 1.1.0 |
| 数据集 | Cora/Citeseer | Cora/Citeseer |
| 训练参数 | epoch=200 | epoch=200 |
| 优化器 | Adam | Adam |
| 损失函数 | Softmax交叉熵 | Softmax交叉熵 |
| 准确率 | 81.5/70.3 | 87.5/76.9 |
| 参数(B) | 92160/59344 | 92160/59344 |
| 脚本 | [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) | [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn) |
## 随机情况说明

View File

@ -20,6 +20,7 @@ eval_nodes_num: 500
test_nodes_num: 1000
save_TSNE: False
save_ckptpath: "ckpts/"
train_with_eval: False
---
@ -29,4 +30,5 @@ data_dir: "Dataset directory"
train_nodes_num: "Nodes numbers for training"
eval_nodes_num: "Nodes numbers for evaluation"
test_nodes_num: "Nodes numbers for test"
save_TSNE: "Whether to save t-SNE graph"
save_TSNE: "Whether to save t-SNE graph"
train_with_eval: "Whether to train with evaluation"

View File

@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GCN eval script.
"""
import argparse
import numpy as np
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint
from mindspore import Tensor
from mindspore import Model, context
from src.config import ConfigGCN
from src.dataset import get_adj_features_labels, get_mask
from src.metrics import Loss
from src.gcn import GCN
def run_gcn_infer():
"""
Run gcn infer
"""
parser = argparse.ArgumentParser(description='GCN')
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
parser.add_argument("--model_ckpt", type=str, required=True,
help="existed checkpoint address.")
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE,
device_target=args_opt.device_target, save_graphs=False)
config = ConfigGCN()
adj, feature, label_onehot, _ = get_adj_features_labels(args_opt.data_dir)
feature_d = np.expand_dims(feature, axis=0)
label_onehot_d = np.expand_dims(label_onehot, axis=0)
data = {"feature": feature_d, "label": label_onehot_d}
dataset = ds.NumpySlicesDataset(data=data)
adj = Tensor(adj, dtype=mstype.float32)
feature = Tensor(feature)
nodes_num = label_onehot.shape[0]
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
class_num = label_onehot.shape[1]
input_dim = feature.shape[1]
gcn_net_test = GCN(config, input_dim, class_num, adj)
load_checkpoint(args_opt.model_ckpt, net=gcn_net_test)
eval_metrics = {'Acc': nn.Accuracy()}
criterion = Loss(test_mask, config.weight_decay, gcn_net_test.trainable_params()[0])
model = Model(gcn_net_test, loss_fn=criterion, metrics=eval_metrics)
res = model.eval(dataset, dataset_sink_mode=True)
print(res)
if __name__ == '__main__':
run_gcn_infer()

View File

@ -32,7 +32,7 @@ parser.add_argument("--device_target", type=str, default="Ascend",
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
if args.device_target == "Ascend" or args.device_target == "GPU":
context.set_context(device_id=args.device_id)
if __name__ == "__main__":

View File

@ -0,0 +1,53 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_NAME=$1
echo $DATASET_NAME
MODEL_CKPT=$(get_real_path $2)
echo $MODEL_CKPT
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.yaml ./eval
cp *.sh ./eval
cp -r ../src ./eval
cp -r ../model_utils ./eval
cd ./eval || exit
echo "start eval on standalone GPU"
if [ $DATASET_NAME == cora ]
then
python eval.py --data_dir=../data_mr/$DATASET_NAME --device_target="GPU" --model_ckpt $MODEL_CKPT &> log &
fi
if [ $DATASET_NAME == citeseer ]
then
python eval.py --data_dir=../data_mr/$DATASET_NAME --device_target="GPU" --model_ckpt $MODEL_CKPT &> log &
fi
cd ..

View File

@ -16,7 +16,7 @@
if [ $# != 2 ]
then
echo "Usage: sh run_train.sh [SRC_PATH] [DATASET_NAME]"
echo "Usage: sh run_process_data.sh [SRC_PATH] [DATASET_NAME]"
exit 1
fi
@ -43,7 +43,7 @@ MINDRECORD_PATH=`pwd`/data_mr
rm -f $MINDRECORD_PATH/$DATASET_NAME
rm -f $MINDRECORD_PATH/$DATASET_NAME.db
cd ../../utils/graph_to_mindrecord || exit
cd ../../../../utils/graph_to_mindrecord || exit
python writer.py --mindrecord_script $DATASET_NAME \
--mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \

View File

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_train.sh [DATASET_NAME]"
exit 1
fi
DATASET_NAME=$1
echo $DATASET_NAME
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp *.sh ./train
cp -r ../src ./train
cp -r ../model_utils ./train
cd ./train || exit
env > env.log
echo "start training for standalone GPU"
if [ $DATASET_NAME == cora ]
then
python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=140 --device_target="GPU" &> log &
fi
if [ $DATASET_NAME == citeseer ]
then
python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=120 --device_target="GPU" &> log &
fi
cd ..

View File

@ -18,9 +18,20 @@ network config setting, will be used in train.py
class ConfigGCN():
"""
Configuration of GCN
"""
learning_rate = 0.01
epochs = 200
hidden1 = 16
dropout = 0.5
weight_decay = 5e-4
early_stopping = 50
save_ckpt_steps = 549
keep_ckpt_max = 10
ckpt_dir = './ckpt'
best_ckpt_dir = './best_ckpt'
best_ckpt_name = 'best.ckpt'
eval_start_epoch = 100
save_best_ckpt = True
eval_interval = 1

View File

@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import os
import stat
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
self.metrics_name = metrics_name
self.save_TSNE = save_TSNE
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
res = self.eval_function(self.eval_param_dict)
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.bast_ckpt_path):
self.remove_ckpoint_file(self.bast_ckpt_path)
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)

View File

@ -92,12 +92,12 @@ class GCN(nn.Cell):
output_dim (int): The number of output channels, equal to classes num.
"""
def __init__(self, config, input_dim, output_dim):
def __init__(self, config, input_dim, output_dim, adj):
super(GCN, self).__init__()
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
def construct(self, adj, feature):
output0 = self.layer0(adj, feature)
output1 = self.layer1(adj, output0)
self.adj = adj
def construct(self, feature):
output0 = self.layer0(self.adj, feature)
output1 = self.layer1(self.adj, output0)
return output1

View File

@ -17,16 +17,13 @@ from mindspore import nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.nn.metrics import Metric
class Loss(nn.Cell):
"""Softmax cross-entropy loss with masking."""
def __init__(self, label, mask, weight_decay, param):
def __init__(self, mask, weight_decay, param):
super(Loss, self).__init__(auto_prefix=False)
self.label = Tensor(label)
self.mask = Tensor(mask)
self.loss = P.SoftmaxCrossEntropyWithLogits()
self.one = Tensor(1.0, mstype.float32)
@ -38,12 +35,12 @@ class Loss(nn.Cell):
self.weight_decay = weight_decay
self.param = param
def construct(self, preds):
def construct(self, preds, label):
"""Calculate loss"""
param = self.l2_loss(self.param)
loss = self.weight_decay * param
preds = self.cast(preds, mstype.float32)
loss = loss + self.loss(preds, self.label)[0]
loss = loss + self.loss(preds, label)[0]
mask = self.cast(self.mask, mstype.float32)
mask_reduce = self.mean(mask)
mask = mask / mask_reduce
@ -51,138 +48,38 @@ class Loss(nn.Cell):
loss = self.mean(loss)
return loss
class Accuracy(nn.Cell):
"""Accuracy with masking."""
def __init__(self, label, mask):
super(Accuracy, self).__init__(auto_prefix=False)
self.label = Tensor(label)
class GCNAccuracy(Metric):
"""
Accuracy for GCN
"""
def __init__(self, mask):
super(GCNAccuracy, self).__init__()
self.mask = Tensor(mask)
self.equal = P.Equal()
self.argmax = P.Argmax()
self.cast = P.Cast()
self.mean = P.ReduceMean()
self.accuracy_all = 0
def construct(self, preds):
preds = self.cast(preds, mstype.float32)
correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
accuracy_all = self.cast(correct_prediction, mstype.float32)
def clear(self):
self.accuracy_all = 0
def update(self, *inputs):
preds = self.cast(inputs[1], mstype.float32)
correct_prediction = self.equal(self.argmax(preds), self.argmax(inputs[0]))
self.accuracy_all = self.cast(correct_prediction, mstype.float32)
mask = self.cast(self.mask, mstype.float32)
mask_reduce = self.mean(mask)
mask = mask / mask_reduce
accuracy_all *= mask
return self.mean(accuracy_all)
self.accuracy_all *= mask
def eval(self):
return float(self.mean(self.accuracy_all).asnumpy())
class LossAccuracyWrapper(nn.Cell):
"""
Wraps the GCN model with loss and accuracy cell.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training, evaluation or test.
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
"""
def __init__(self, network, label, mask, weight_decay):
super(LossAccuracyWrapper, self).__init__(auto_prefix=False)
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
self.accuracy = Accuracy(label, mask)
def construct(self, adj, feature):
preds = self.network(adj, feature)
loss = self.loss(preds)
accuracy = self.accuracy(preds)
return loss, accuracy
class LossWrapper(nn.Cell):
"""
Wraps the GCN model with loss.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training.
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
"""
def __init__(self, network, label, mask, weight_decay):
super(LossWrapper, self).__init__(auto_prefix=False)
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
def construct(self, adj, feature):
preds = self.network(adj, feature)
loss = self.loss(preds)
return loss
class TrainOneStepCell(nn.Cell):
r"""
Network training package class.
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
Backward graph will be created in the construct function to do parameter updating. Different
parallel modes are available to run the training.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
def construct(self, adj, feature):
weights = self.weights
loss = self.network(adj, feature)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(adj, feature, sens)
return F.depend(loss, self.optimizer(grads))
class TrainNetWrapper(nn.Cell):
"""
Wraps the GCN model with optimizer.
Args:
network (Cell): GCN network.
label (numpy.ndarray): Dataset labels.
mask (numpy.ndarray): Mask for training, evaluation or test.
config (ConfigGCN): Configuration for GCN.
"""
def __init__(self, network, label, mask, config):
super(TrainNetWrapper, self).__init__(auto_prefix=False)
self.network = network
loss_net = LossWrapper(network, label, mask, config.weight_decay)
optimizer = nn.Adam(loss_net.trainable_params(),
learning_rate=config.learning_rate)
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy = Accuracy(label, mask)
def construct(self, adj, feature):
loss = self.loss_train_net(adj, feature)
accuracy = self.accuracy(self.network(adj, feature))
return loss, accuracy
def apply_eval(eval_param_dict):
"""run Evaluation"""
model = eval_param_dict["model"]
dataset = eval_param_dict["dataset"]
metrics_name = eval_param_dict["metrics_name"]
eval_score = model.eval(dataset, dataset_sink_mode=False)[metrics_name]
return eval_score

View File

@ -18,36 +18,25 @@ GCN training script.
"""
import os
import time
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation
from sklearn import manifold
from mindspore import context
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.common.dtype as mstype
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint, load_checkpoint
import numpy as np
from src.gcn import GCN
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
from src.metrics import Loss, GCNAccuracy, apply_eval
from src.config import ConfigGCN
from src.dataset import get_adj_features_labels, get_mask
from src.eval_callback import EvalCallBack
from model_utils.config import config as default_args
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num
def t_SNE(out_feature, dim):
t_sne = manifold.TSNE(n_components=dim, init='pca', random_state=0)
return t_sne.fit_transform(out_feature)
def update_graph(i, data, scat, plot):
scat.set_offsets(data[i])
plt.title('t-SNE visualization of Epoch:{0}'.format(i))
return scat, plot
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
@ -112,85 +101,41 @@ def modelarts_pre_process():
def run_train():
"""Train model."""
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False)
device_target=default_args.device_target, save_graphs=False)
config = ConfigGCN()
adj, feature, label_onehot, label = get_adj_features_labels(default_args.data_dir)
if not os.path.exists(config.ckpt_dir):
os.mkdir(config.ckpt_dir)
adj, feature, label_onehot, _ = get_adj_features_labels(default_args.data_dir)
feature_d = np.expand_dims(feature, axis=0)
label_onehot_d = np.expand_dims(label_onehot, axis=0)
data = {"feature": feature_d, "label": label_onehot_d}
dataset = ds.NumpySlicesDataset(data=data)
nodes_num = label_onehot.shape[0]
train_mask = get_mask(nodes_num, 0, default_args.train_nodes_num)
eval_mask = get_mask(nodes_num, default_args.train_nodes_num,
default_args.train_nodes_num + default_args.eval_nodes_num)
test_mask = get_mask(nodes_num, nodes_num - default_args.test_nodes_num, nodes_num)
class_num = label_onehot.shape[1]
input_dim = feature.shape[1]
gcn_net = GCN(config, input_dim, class_num)
gcn_net.add_flags_recursive(fp16=True)
adj = Tensor(adj)
feature = Tensor(feature)
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
loss_list = []
if default_args.save_TSNE:
out_feature = gcn_net()
tsne_result = t_SNE(out_feature.asnumpy(), 2)
graph_data = []
graph_data.append(tsne_result)
fig = plt.figure()
scat = plt.scatter(tsne_result[:, 0], tsne_result[:, 1], s=2, c=label, cmap='rainbow')
plt.title('t-SNE visualization of Epoch:0', fontsize='large', fontweight='bold', verticalalignment='center')
for epoch in range(config.epochs):
t = time.time()
train_net.set_train()
train_result = train_net(adj, feature)
train_loss = train_result[0].asnumpy()
train_accuracy = train_result[1].asnumpy()
eval_net.set_train(False)
eval_result = eval_net(adj, feature)
eval_loss = eval_result[0].asnumpy()
eval_accuracy = eval_result[1].asnumpy()
loss_list.append(eval_loss)
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
if default_args.save_TSNE:
out_feature = gcn_net()
tsne_result = t_SNE(out_feature.asnumpy(), 2)
graph_data.append(tsne_result)
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
print("Early stopping...")
break
if not os.path.isdir(default_args.save_ckptpath):
os.makedirs(default_args.save_ckptpath)
ckpt_path = os.path.join(default_args.save_ckptpath, "gcn.ckpt")
save_checkpoint(gcn_net, ckpt_path)
gcn_net_test = GCN(config, input_dim, class_num)
load_checkpoint(ckpt_path, net=gcn_net_test)
gcn_net_test.add_flags_recursive(fp16=True)
test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
t_test = time.time()
test_net.set_train(False)
test_result = test_net(adj, feature)
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
"accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
if default_args.save_TSNE:
ani = animation.FuncAnimation(fig, update_graph, frames=range(config.epochs + 1), fargs=(graph_data, scat, plt))
ani.save('t-SNE_visualization.gif', writer='imagemagick')
adj = Tensor(adj, dtype=mstype.float32)
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
keep_checkpoint_max=config.keep_ckpt_max)
ckpoint_cb = ModelCheckpoint(prefix='ckpt_gcn',
directory=config.ckpt_dir,
config=ckpt_config)
gcn_net = GCN(config, input_dim, class_num, adj)
cb = [TimeMonitor(), LossMonitor(), ckpoint_cb]
opt = nn.Adam(gcn_net.trainable_params(), learning_rate=config.learning_rate)
criterion = Loss(eval_mask, config.weight_decay, gcn_net.trainable_params()[0])
model = Model(gcn_net, loss_fn=criterion, optimizer=opt, amp_level="O3")
if default_args.train_with_eval:
GCN_metric = GCNAccuracy(eval_mask)
eval_model = Model(gcn_net, loss_fn=criterion, metrics={'GCNAccuracy': GCN_metric})
eval_param_dict = {"model": eval_model, "dataset": dataset, "metrics_name": "GCNAccuracy"}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
eval_start_epoch=default_args.eval_start_epoch, save_best_ckpt=config.save_best_ckpt,
ckpt_directory=config.best_ckpt_dir, besk_ckpt_name=config.best_ckpt_name,
metrics_name="GCNAccuracy")
cb.append(eval_cb)
model.train(config.epochs, dataset, callbacks=cb, dataset_sink_mode=True)
if __name__ == '__main__':
run_train()

View File

@ -1,93 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import pytest
import numpy as np
from mindspore import context
from mindspore import Tensor
from model_zoo.official.gnn.gcn.src.gcn import GCN
from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
from model_zoo.official.gnn.gcn.src.config import ConfigGCN
from model_zoo.official.gnn.gcn.src.dataset import get_adj_features_labels, get_mask
DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
TRAIN_NODE_NUM = 140
EVAL_NODE_NUM = 500
TEST_NODE_NUM = 1000
SEED = 20
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gcn():
print("test_gcn begin")
np.random.seed(SEED)
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False)
config = ConfigGCN()
config.dropout = 0.0
adj, feature, label_onehot, _ = get_adj_features_labels(DATA_DIR)
nodes_num = label_onehot.shape[0]
train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
class_num = label_onehot.shape[1]
input_dim = feature.shape[1]
gcn_net = GCN(config, input_dim, class_num)
gcn_net.add_flags_recursive(fp16=True)
adj = Tensor(adj)
feature = Tensor(feature)
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
loss_list = []
for epoch in range(config.epochs):
t = time.time()
train_net.set_train()
train_result = train_net(adj, feature)
train_loss = train_result[0].asnumpy()
train_accuracy = train_result[1].asnumpy()
eval_net.set_train(False)
eval_result = eval_net(adj, feature)
eval_loss = eval_result[0].asnumpy()
eval_accuracy = eval_result[1].asnumpy()
loss_list.append(eval_loss)
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
print("Early stopping...")
break
test_net.set_train(False)
test_result = test_net(adj, feature)
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
"accuracy=", "{:.5f}".format(test_accuracy))
assert test_accuracy > 0.812