forked from mindspore-Ecosystem/mindspore
!2119 Add gat to model zoo
Merge pull request !2119 from zhangdengcheng/master
Commit 4485e0b55c

@@ -0,0 +1,166 @@

<!--TOC -->

- [Graph Attention Networks Description](#graph-attention-networks-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
    - [Data Preparation](#data-preparation)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Structure](#structure)
    - [Parameter configuration](#parameter-configuration)
- [Running the example](#running-the-example)
    - [Usage](#usage)
    - [Result](#result)
- [Description of random situation](#description-of-random-situation)
- [Others](#others)

<!--TOC -->

# Graph Attention Networks Description

Graph Attention Networks (GAT) was proposed in 2017 by Petar Veličković et al. By leveraging masked self-attentional layers to address the shortcomings of prior graph-based methods, GAT achieved or matched state-of-the-art performance on both transductive datasets such as Cora and inductive datasets such as PPI. This is an example of training GAT on the Cora dataset with MindSpore.

[Paper](https://arxiv.org/abs/1710.10903): Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., & Bengio, Y. (2017). Graph attention networks. arXiv preprint arXiv:1710.10903.

# Model architecture

An illustration of multi-head attention (with K = 3 heads) by node 1 on its neighborhood can be found below:

![](https://camo.githubusercontent.com/4fe1a90e67d17a2330d7cfcddc930d5f7501750c/68747470733a2f2f7777772e64726f70626f782e636f6d2f732f71327a703170366b37396a6a6431352f6761745f6c617965722e706e673f7261773d31)

Note that the node update function depends on whether the attention layer is the output layer of the network: hidden layers concatenate the outputs of the heads, while the output layer averages them.
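
For reference, the per-head update used by GAT, as formulated in the paper, is:

```latex
\alpha_{ij} = \mathrm{softmax}_j\!\left(\mathrm{LeakyReLU}\!\left(\mathbf{a}^{\top}[\mathbf{W}\vec{h}_i \,\Vert\, \mathbf{W}\vec{h}_j]\right)\right)

% hidden layers: concatenate the K head outputs
\vec{h}_i' = \big\Vert_{k=1}^{K} \sigma\!\left(\sum_{j \in \mathcal{N}_i} \alpha_{ij}^{k}\, \mathbf{W}^{k} \vec{h}_j\right)

% output layer: average the K head outputs before the final nonlinearity
\vec{h}_i' = \sigma\!\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j \in \mathcal{N}_i} \alpha_{ij}^{k}\, \mathbf{W}^{k} \vec{h}_j\right)
```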
# Dataset

Statistics of the datasets used are summarized below:

|                     |           Cora |       Citeseer |
| ------------------- | -------------: | -------------: |
| Task                |   Transductive |   Transductive |
| # Nodes             | 2708 (1 graph) | 3327 (1 graph) |
| # Edges             |           5429 |           4732 |
| # Features/Node     |           1433 |           3703 |
| # Classes           |              7 |              6 |
| # Training Nodes    |            140 |            120 |
| # Validation Nodes  |            500 |            500 |
| # Test Nodes        |           1000 |           1000 |

## Data Preparation

Download the Cora or Citeseer dataset provided by [kimiyoung/planetoid](https://github.com/kimiyoung/planetoid) on GitHub.

> Place the dataset in any path you want; the folder should contain the following files (using the Cora dataset as an example):

```
.
└─data
    ├─ind.cora.allx
    ├─ind.cora.ally
    ├─ind.cora.graph
    ├─ind.cora.test.index
    ├─ind.cora.tx
    ├─ind.cora.ty
    ├─ind.cora.x
    └─ind.cora.y
```

> Generate the dataset in MindRecord format for cora or citeseer.

>> Usage

```
cd ./scripts
# SRC_PATH is the path of the dataset file you downloaded, DATASET_NAME is cora or citeseer
sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
```

>> Launch

```
# Generate dataset in mindrecord format for cora
sh run_process_data.sh [SRC_PATH] cora
# Generate dataset in mindrecord format for citeseer
sh run_process_data.sh [SRC_PATH] citeseer
```
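
For a quick sanity check, the generated MindRecord can be opened with MindSpore's `GraphData` API, mirroring what `src/dataset.py` does during training. A minimal sketch; the path below assumes the script was launched from `./scripts` as above:

```python
import mindspore.dataset as ds

# data_mr/cora is created inside ./scripts by run_process_data.sh
g = ds.GraphData("./scripts/data_mr/cora")
nodes = g.get_all_nodes(0).tolist()                   # node type 0, as used in src/dataset.py
features, labels = g.get_node_feature(nodes, [1, 2])  # feature columns 1 (features) and 2 (labels)
print(len(nodes), features.shape, labels.shape)
```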

# Features

## Mixed Precision

To utilize the strong computation power of the Ascend chip and accelerate training, mixed precision training is used. MindSpore can cope with FP32 inputs and FP16 operators. In this GAT example, the model is set to FP16 mode except for the loss calculation part.
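
Concretely, `train.py` enables this by casting the whole network to FP16 with `add_flags_recursive`, while the loss wrapper in `src/utils.py` casts the logits back to FP32 before the softmax cross entropy. A minimal sketch with toy random inputs shaped like Cora (not the full training script):

```python
import numpy as np
import mindspore.context as context
from src.gat import GAT

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

features = np.random.rand(1, 2708, 1433).astype(np.float32)  # 1 graph, 2708 nodes, 1433 features
biases = np.random.rand(1, 2708, 2708).astype(np.float32)    # additive attention bias matrix

gat_net = GAT(features, biases, 1433, 7, 2708, [8], [8, 1],
              attn_drop=0.6, ftr_drop=0.6)
gat_net.add_flags_recursive(fp16=True)  # run every cell in FP16; the loss stays in FP32
```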

# Environment Requirements

- Hardware (Ascend)
- Install [MindSpore](https://www.mindspore.cn/install/en).

# Structure

```shell
.
└─gat
  ├─README.md
  ├─scripts
  | ├─run_process_data.sh  # Generate dataset in mindrecord format
  | └─run_train.sh         # Launch training
  |
  ├─src
  | ├─config.py            # Training configurations
  | ├─dataset.py           # Data preprocessing
  | ├─gat.py               # GAT model
  | └─utils.py             # Utils for training gat
  |
  └─train.py               # Train net
```

## Parameter configuration

Parameters for training can be set in `src/config.py`.

```
"learning_rate": 0.005,   # Learning rate
"num_epochs": 200,        # Number of training epochs
"hid_units": [8],         # Hidden units of the attention heads at each layer
"n_heads": [8, 1],        # Number of heads at each layer
"early_stopping": 100,    # Early stopping patience
"l2_coeff": 0.0005,       # l2 coefficient
"attn_dropout": 0.6,      # Attention dropout ratio
"feature_dropout": 0.6    # Feature dropout ratio
```
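
These values map onto the `GatConfig` class in `src/config.py` (the attribute names differ slightly, e.g. the learning rate is `lr`), so they can be inspected or overridden in Python before launching training. A small sketch:

```python
from src.config import GatConfig

print(GatConfig.lr, GatConfig.num_epochs, GatConfig.hid_units, GatConfig.n_heads)

# example override (an illustration, not a recommended setting)
GatConfig.early_stopping = 50
```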

# Running the example

## Usage

After the dataset has been generated correctly:

```
# run training with the cora dataset, DATASET_NAME is cora
sh run_train.sh [DATASET_NAME]
```

## Result

Training results are stored in the scripts path, in a folder whose name begins with "train". You can find results like the following in the log.

```
Epoch:0, train loss=1.98498 train acc=0.17143 | val loss=1.97946 val acc=0.27200
Epoch:1, train loss=1.98345 train acc=0.15000 | val loss=1.97233 val acc=0.32600
Epoch:2, train loss=1.96968 train acc=0.21429 | val loss=1.96747 val acc=0.37400
Epoch:3, train loss=1.97061 train acc=0.20714 | val loss=1.96410 val acc=0.47600
Epoch:4, train loss=1.96864 train acc=0.13571 | val loss=1.96066 val acc=0.59600
...
Epoch:195, train loss=1.45111 train_acc=0.56429 | val_loss=1.44325 val_acc=0.81200
Epoch:196, train loss=1.52476 train_acc=0.52143 | val_loss=1.43871 val_acc=0.81200
Epoch:197, train loss=1.35807 train_acc=0.62857 | val_loss=1.43364 val_acc=0.81400
Epoch:198, train loss=1.47566 train_acc=0.51429 | val_loss=1.42948 val_acc=0.81000
Epoch:199, train loss=1.56411 train_acc=0.55000 | val_loss=1.42632 val_acc=0.80600
Test loss=1.5366285, test acc=0.84199995
...
```

Results on the Cora dataset are shown in the table below:

|                                       | MindSpore + Ascend910 | Tensorflow + V100 |
| ------------------------------------- | --------------------: | ----------------: |
| Accuracy                              |           0.830933271 |       0.828649968 |
| Training Cost (200 epochs)            |          27.62298311s |        36.711862s |
| End-to-End Training Cost (200 epochs) |               39.074s |           50.894s |

# Description of random situation

The GAT model contains many dropout operations. If you want to disable dropout, set `attn_dropout` and `feature_dropout` to 0 in src/config.py. Note that doing so causes the accuracy to drop to approximately 80%.
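
For example, disabling dropout amounts to the following change in `src/config.py` (a sketch of the shipped class with the two dropout values zeroed):

```python
class GatConfig:
    lr = 0.005
    num_epochs = 200
    hid_units = [8]
    n_heads = [8, 1]
    early_stopping = 100
    l2_coeff = 0.0005
    attn_dropout = 0.0      # was 0.6; disables attention dropout
    feature_dropout = 0.0   # was 0.6; disables feature dropout
```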

# Others

The GAT model is verified on the Ascend environment only, not on CPU or GPU.

@@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [SRC_PATH] [DATASET_NAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
SRC_PATH=$(get_real_path $1)
|
||||
echo $SRC_PATH
|
||||
|
||||
DATASET_NAME=$2
|
||||
echo $DATASET_NAME
|
||||
|
||||
if [ ! -d data_mr ]; then
|
||||
mkdir data_mr
|
||||
else
|
||||
echo "data_mr exists"
|
||||
fi
|
||||
MINDRECORD_PATH=`pwd`/data_mr
|
||||
|
||||
rm -f $MINDRECORD_PATH/*
|
||||
|
||||
cd ../../../example/graph_to_mindrecord || exit
|
||||
|
||||
python writer.py --mindrecord_script $DATASET_NAME \
|
||||
--mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$SRC_PATH"
|
||||
|
||||
cd - || exit
|
|
@@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [DATASET_NAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET_NAME=$1
|
||||
echo $DATASET_NAME
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
env > env.log
|
||||
echo "start training for device $DEVICE_ID"
|
||||
|
||||
|
||||
if [ $DATASET_NAME == cora ]
|
||||
then
|
||||
python train.py --data_dir=../data_mr/$DATASET_NAME &> log &
|
||||
fi
|
||||
|
||||
if [ $DATASET_NAME == citeseer ]
|
||||
then
|
||||
python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=120 &> log &
|
||||
fi
|
||||
cd ..
|
|
@@ -0,0 +1,26 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train configs for training gat"""
|
||||
|
||||
|
||||
class GatConfig:
|
||||
lr = 0.005
|
||||
num_epochs = 200
|
||||
hid_units = [8]
|
||||
n_heads = [8, 1]
|
||||
early_stopping = 100
|
||||
l2_coeff = 0.0005
|
||||
attn_dropout = 0.6
|
||||
feature_dropout = 0.6
|
|
@@ -0,0 +1,87 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Preprocess data obtained for training"""
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
def adj_to_bias(adj):
|
||||
"""Add self loop to adj and make sure only one hop neighbors are engaged in computing"""
|
||||
num_graphs = adj.shape[0]
|
||||
adj_temp = np.empty(adj.shape)
|
||||
for i in range(num_graphs):
|
||||
adj_temp[i] = adj[i] + np.eye(adj.shape[1])
|
||||
return -1e9 * (1.0 - adj_temp)
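# Added illustration (not in the original file): adj_to_bias converts an adjacency
# matrix into an additive attention mask. Entries stay 0 for a node itself and its
# one-hop neighbors and become -1e9 elsewhere, so a softmax over the attention logits
# effectively ignores non-neighbors. For a single 3-node graph with one edge (0, 1):
#   adj = [[[0, 1, 0],
#           [1, 0, 0],
#           [0, 0, 0]]]
#   adj_to_bias(adj) -> [[[   0,    0, -1e9],
#                         [   0,    0, -1e9],
#                         [-1e9, -1e9,    0]]]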
|
||||
|
||||
|
||||
def get_biases_features_labels(data_dir):
|
||||
"""Get biases, features, labels from Dataset"""
|
||||
g = ds.GraphData(data_dir)
|
||||
nodes = g.get_all_nodes(0)
|
||||
nodes_list = nodes.tolist()
|
||||
row_tensor = g.get_node_feature(nodes_list, [1, 2])
|
||||
features = row_tensor[0]
|
||||
features = features[np.newaxis]
|
||||
|
||||
labels = row_tensor[1]
|
||||
|
||||
nodes_num = labels.shape[0]
|
||||
class_num = labels.max() + 1
|
||||
labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)
|
||||
|
||||
neighbor = g.get_all_neighbors(nodes_list, 0)
|
||||
node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
|
||||
adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
|
||||
for index, value in np.ndenumerate(neighbor):
|
||||
if value >= 0 and index[1] > 0:
|
||||
adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
|
||||
adj = adj[np.newaxis]
|
||||
biases = adj_to_bias(adj)
|
||||
|
||||
return biases, features, labels_onehot
|
||||
|
||||
|
||||
def get_mask(total, begin, end):
|
||||
"""Generate mask according to begin and end position"""
|
||||
mask = np.zeros([total]).astype(np.float32)
|
||||
mask[begin:end] = 1
|
||||
return np.array(mask, dtype=np.bool_)
|
||||
|
||||
|
||||
def load_and_process(data_dir, train_node_num, eval_node_num, test_node_num):
|
||||
"""Load cora dataset and preprocessing"""
|
||||
biases, feature, label = get_biases_features_labels(data_dir)
|
||||
# split training, validation and testing set
|
||||
nodes_num = label.shape[0]
|
||||
train_mask = get_mask(nodes_num, 0, train_node_num)
|
||||
eval_mask = get_mask(nodes_num, train_node_num, train_node_num + eval_node_num)
|
||||
test_mask = get_mask(nodes_num, nodes_num - test_node_num, nodes_num)
|
||||
|
||||
y_train = np.zeros(label.shape)
|
||||
y_val = np.zeros(label.shape)
|
||||
y_test = np.zeros(label.shape)
|
||||
|
||||
y_train[train_mask, :] = label[train_mask, :]
|
||||
y_val[eval_mask, :] = label[eval_mask, :]
|
||||
y_test[test_mask, :] = label[test_mask, :]
|
||||
|
||||
y_train = y_train[np.newaxis]
|
||||
y_val = y_val[np.newaxis]
|
||||
y_test = y_test[np.newaxis]
|
||||
train_mask = train_mask[np.newaxis]
|
||||
eval_mask = eval_mask[np.newaxis]
|
||||
test_mask = test_mask[np.newaxis]
|
||||
|
||||
return feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask
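# A hedged usage sketch (mirrors train.py, not part of the original file). The MindRecord
# path is an assumption based on where scripts/run_process_data.sh writes its output;
# 140/500/1000 are the default Cora splits listed in the README.
#
#   feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = \
#       load_and_process("./scripts/data_mr/cora", 140, 500, 1000)
#   print(feature.shape, biases.shape, y_train.shape)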
|
|
@@ -0,0 +1,496 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Aggregator."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import check_int_positive, check_bool
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
|
||||
|
||||
class GNNFeatureTransform(nn.Cell):
|
||||
r"""
|
||||
The GNN feature transform layer for inputs.
|
||||
|
||||
Applies linear transformation for the input feature. This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{outputs} = \text{inputs} * \text{kernel} + \text{bias},
|
||||
|
||||
where :math:`\text{kernel}` is a weight matrix with the same
data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
|
||||
with the same data type as the inputs created by the layer (only if has_bias is True).
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input space.
|
||||
out_channels (int): The number of channels in the output space.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
|
||||
Raises:
|
||||
ValueError: If weight_init or bias_init shape is incorrect.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*B, N, C)`,
|
||||
where :math:`*B` represents the batch size which can be multidimensional, :math:`N` and :math:`C` are the
|
||||
size of the last two dimensions. If `transpose_a` is True, its shape should be :math:`(*B, C, N)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Dense(3, 4)
|
||||
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
||||
>>> net(input)
|
||||
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
|
||||
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
|
||||
"""
|
||||
@cell_attr_register
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
has_bias=True):
|
||||
super(GNNFeatureTransform, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
||||
weight_init.shape()[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
def construct(self, x):
|
||||
tensor_shape = F.shape(x)
|
||||
input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2]))
|
||||
output = self.matmul(input_feature, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels))
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
|
||||
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
|
||||
if self.has_bias:
|
||||
str_info = str_info + ', bias={}'.format(self.bias)
|
||||
|
||||
return str_info
|
||||
|
||||
|
||||
class _BaseAggregator(nn.Cell):
|
||||
"""
|
||||
Base Aggregator of GNN
|
||||
|
||||
Args:
|
||||
feature_in_dim (int): Node or edge input feature dim.
|
||||
feature_out_dim (int): Node or edge output feature dim.
|
||||
use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None.
|
||||
activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.
|
||||
|
||||
Examples:
|
||||
>>> class MyAggregator(_BaseAggregator):
|
||||
>>> def __init__(self):
|
||||
>>> super(MyAggregator, self).__init__(self, feature_in_dim, feature_out_dim)
|
||||
>>> self.reduce_mean = P.ReduceSum()
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> return self.reduce_mean(x, 1)
|
||||
"""
|
||||
def __init__(self,
|
||||
feature_in_dim,
|
||||
feature_out_dim,
|
||||
use_fc=True,
|
||||
weight_init="normal",
|
||||
bias_init="zeros",
|
||||
has_bias=True,
|
||||
dropout_ratio=None,
|
||||
activation=None):
|
||||
super(_BaseAggregator, self).__init__()
|
||||
self.in_dim = feature_in_dim
|
||||
self.out_dim = feature_out_dim
|
||||
self.use_fc = use_fc
|
||||
if self.use_fc:
|
||||
self.weight_init = weight_init
|
||||
self.bias_init = bias_init
|
||||
self.has_bias = has_bias
|
||||
self.fc = GNNFeatureTransform(self.in_dim,
|
||||
self.out_dim,
|
||||
weight_init=self.weight_init,
|
||||
bias_init=self.bias_init,
|
||||
has_bias=self.has_bias)
|
||||
self.dropout_ratio = dropout_ratio
|
||||
if self.dropout_ratio is not None:
|
||||
self.dropout = nn.Dropout(keep_prob=self.dropout_ratio)
|
||||
self.dropout_flag = self.dropout_ratio is not None
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
|
||||
def construct(self, **kwargs):
|
||||
"""Must be overridden by all subclasses."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MeanAggregator(_BaseAggregator):
|
||||
"""
|
||||
Mean Aggregator of GNN
|
||||
|
||||
Args:
|
||||
feature_in_dim (int): Node or edge input feature dim.
|
||||
feature_out_dim (int): Node or edge output feature dim.
|
||||
use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None.
|
||||
activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.
|
||||
|
||||
Examples:
|
||||
>>> net = MeanAggregator(32, 64, activation="relu", dropout=0.5)
|
||||
>>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtype=np.float32))
|
||||
>>> output = net(input_data)
|
||||
"""
|
||||
def __init__(self,
|
||||
feature_in_dim,
|
||||
feature_out_dim,
|
||||
use_fc=True,
|
||||
weight_init="normal",
|
||||
bias_init="zeros",
|
||||
has_bias=True,
|
||||
dropout_ratio=None,
|
||||
activation=None):
|
||||
super(MeanAggregator, self).__init__(
|
||||
feature_in_dim,
|
||||
feature_out_dim,
|
||||
use_fc,
|
||||
weight_init,
|
||||
bias_init,
|
||||
has_bias,
|
||||
dropout_ratio,
|
||||
activation)
|
||||
self.reduce_mean = P.ReduceMean(keep_dims=False)
|
||||
|
||||
def construct(self, input_feature):
|
||||
if self.use_fc:
|
||||
input_feature = self.fc(input_feature)
|
||||
if self.dropout_flag:
|
||||
input_feature = self.dropout(input_feature)
|
||||
if self.activation_flag:
|
||||
input_feature = self.activation(input_feature)
|
||||
output_feature = self.reduce_mean(input_feature, 1)
|
||||
return output_feature
|
||||
|
||||
|
||||
class AttentionHead(nn.Cell):
|
||||
"""
|
||||
Attention Head for Graph Attention Networks.
|
||||
|
||||
Args:
|
||||
in_channel (int): The number of input channel, input feature dim.
|
||||
out_channel (int): The number of output channel, output feature dim.
|
||||
in_drop_ratio (float): Input feature dropout ratio, default 0.0.
|
||||
coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
|
||||
residual (bool): Whether to use residual connection, default False.
|
||||
coef_activation (Cell): The attention coefficient activation function,
|
||||
default nn.LeakyReLU().
|
||||
activation (Cell): The output activation function, default nn.ELU().
|
||||
|
||||
Inputs:
|
||||
- **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
|
||||
- **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes).
|
||||
|
||||
Examples:
|
||||
>>> head = AttentionHead(1433,
|
||||
8,
|
||||
in_drop_ratio=0.6,
|
||||
coef_drop_ratio=0.6,
|
||||
residual=False)
|
||||
>>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
>>> bias_mat = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
>>> output = head(input_data, bias_mat)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
in_drop_ratio=0.0,
|
||||
coef_drop_ratio=0.0,
|
||||
residual=False,
|
||||
coef_activation=nn.LeakyReLU(),
|
||||
activation=nn.ELU()):
|
||||
super(AttentionHead, self).__init__()
|
||||
self.in_channel = check_int_positive(in_channel)
|
||||
self.out_channel = check_int_positive(out_channel)
|
||||
self.in_drop_ratio = in_drop_ratio
|
||||
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
||||
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
||||
self.feature_transform = GNNFeatureTransform(
|
||||
in_channels=self.in_channel,
|
||||
out_channels=self.out_channel,
|
||||
has_bias=False,
|
||||
weight_init='XavierUniform')
|
||||
|
||||
self.f_1_transform = GNNFeatureTransform(
|
||||
in_channels=self.out_channel,
|
||||
out_channels=1,
|
||||
weight_init='XavierUniform')
|
||||
self.f_2_transform = GNNFeatureTransform(
|
||||
in_channels=self.out_channel,
|
||||
out_channels=1,
|
||||
weight_init='XavierUniform')
|
||||
self.softmax = nn.Softmax()
|
||||
|
||||
self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
|
||||
self.matmul = P.MatMul()
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
|
||||
self.residual = check_bool(residual)
|
||||
if self.residual:
|
||||
if in_channel != out_channel:
|
||||
self.residual_transform_flag = True
|
||||
self.residual_transform = GNNFeatureTransform(
|
||||
in_channels=self.in_channel,
|
||||
out_channels=self.out_channel)
|
||||
else:
# input and output dims already match, so construct() adds the input directly (identity residual)
self.residual_transform_flag = False
self.residual_transform = None
|
||||
self.coef_activation = coef_activation
|
||||
self.activation = activation
|
||||
|
||||
def construct(self, input_feature, bias_mat, training=True):
|
||||
if training is True:
|
||||
input_feature = self.in_drop(input_feature)
|
||||
|
||||
feature = self.feature_transform(input_feature)
|
||||
# self attention
|
||||
f_1 = self.f_1_transform(feature)
|
||||
f_2 = self.f_2_transform(feature)
|
||||
logits = f_1 + P.Transpose()(f_2, (0, 2, 1))
|
||||
logits = self.coef_activation(logits) + bias_mat
|
||||
coefs = self.softmax(logits)
|
||||
if training is True:
|
||||
coefs = self.coef_drop(coefs)
|
||||
feature = self.in_drop_2(feature)
|
||||
|
||||
coefs = P.Squeeze(0)(coefs)
|
||||
feature = P.Squeeze(0)(feature)
|
||||
|
||||
ret = self.matmul(coefs, feature)
|
||||
ret = self.bias_add(ret, self.bias)
|
||||
ret = P.ExpandDims()(ret, 0)
|
||||
# residual connection
|
||||
if self.residual:
|
||||
if self.residual_transform_flag:
|
||||
res = self.residual_transform(input_feature)
|
||||
ret = ret + res
|
||||
else:
|
||||
ret = ret + input_feature
|
||||
# activation
|
||||
if self.activation is not None:
|
||||
ret = self.activation(ret)
|
||||
return ret
|
||||
|
||||
|
||||
class AttentionAggregator(nn.Cell):
|
||||
"""
|
||||
Attention aggregator for Graph Attention Networks, which can be regarded as one
|
||||
GAT layer.
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
num_heads (int): Number of attention heads for this layer, default 1.
|
||||
in_drop_ratio (float): Input feature dropout ratio, default 0.0.
|
||||
coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
|
||||
activation (Cell): The output activation function, default nn.ELU().
|
||||
residual (bool): Whether to use residual connection, default False.
|
||||
output_transform (str['concat', 'sum']): output transform for a layer,
|
||||
default 'concat'
|
||||
|
||||
Inputs:
|
||||
- **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
|
||||
- **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes).
|
||||
|
||||
Examples:
|
||||
>>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
|
||||
>>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
|
||||
>>> net = AttentionAggregator(1433,
|
||||
8,
|
||||
8)
|
||||
>>> net(input_data, biases)
|
||||
"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_heads=1,
|
||||
in_drop=0.0,
|
||||
coef_drop=0.0,
|
||||
activation=nn.ELU(),
|
||||
residual=False,
|
||||
output_transform='concat'):
|
||||
super(AttentionAggregator, self).__init__()
|
||||
self.num_heads = num_heads
|
||||
self.attns = []
|
||||
for _ in range(num_heads):
|
||||
self.attns.append(AttentionHead(in_channels,
|
||||
out_channels,
|
||||
in_drop_ratio=in_drop,
|
||||
coef_drop_ratio=coef_drop,
|
||||
activation=activation,
|
||||
residual=residual))
|
||||
self.attns = nn.layer.CellList(self.attns)
|
||||
if output_transform == 'concat':
|
||||
self.out_trans = P.Concat(-1)
|
||||
elif output_transform == 'sum':
|
||||
self.out_trans = P.AddN()
|
||||
else:
|
||||
raise ValueError("output_transform must be either 'concat' or 'sum'")
|
||||
|
||||
def construct(self, input_data, bias_mat, training=True):
|
||||
res = ()
|
||||
for i in range(self.num_heads):
|
||||
res += (self.attns[i](input_data, bias_mat, training),)
|
||||
return self.out_trans(res)
|
||||
|
||||
|
||||
class GAT(nn.Cell):
|
||||
"""
|
||||
Graph Attention Network
|
||||
|
||||
Args:
|
||||
ftr_dims (int): Initial feature dimensions.
|
||||
num_class (int): Num of class to identify.
|
||||
num_nodes (int): Num of nodes in this graph.
|
||||
hidden_units (list[int]): Num of hidden units at each layer.
|
||||
num_heads (list[int]): Num of heads at each layer.
|
||||
attn_drop (float): Drop out ratio of attention coefficient,
|
||||
default 0.0.
|
||||
ftr_drop (float): Drop out ratio of feature, default 0.0.
|
||||
activation (Cell): Activation Function for output layer, default
|
||||
nn.Elu().
|
||||
residual (bool): Whether to use residual connection between
|
||||
intermediate layers, default False.
|
||||
|
||||
Examples:
|
||||
>>> ft_sizes = 1433
|
||||
>>> num_class = 7
|
||||
>>> num_nodes = 2708
|
||||
>>> hid_units = [8]
|
||||
>>> n_heads = [8, 1]
|
||||
>>> activation = nn.ELU()
|
||||
>>> residual = False
|
||||
>>> input_data = np.array(np.random.rand(1, 2708, 1433))
|
||||
>>> biases = np.array(np.random.rand(1, 2708, 2708))
|
||||
>>> net = GAT(ft_sizes,
|
||||
num_class,
|
||||
num_nodes,
|
||||
hidden_units=hid_units,
|
||||
num_heads=n_heads,
|
||||
attn_drop=0.6,
|
||||
ftr_drop=0.6,
|
||||
activation=activation,
|
||||
residual=residual)
|
||||
>>> output = net(input_data, biases)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
features,
|
||||
biases,
|
||||
ftr_dims,
|
||||
num_class,
|
||||
num_nodes,
|
||||
hidden_units,
|
||||
num_heads,
|
||||
attn_drop=0.0,
|
||||
ftr_drop=0.0,
|
||||
activation=nn.ELU(),
|
||||
residual=False):
|
||||
super(GAT, self).__init__()
|
||||
self.features = Tensor(features)
|
||||
self.biases = Tensor(biases)
|
||||
self.ftr_dims = check_int_positive(ftr_dims)
|
||||
self.num_class = check_int_positive(num_class)
|
||||
self.num_nodes = check_int_positive(num_nodes)
|
||||
self.hidden_units = hidden_units
|
||||
self.num_heads = num_heads
|
||||
self.attn_drop = attn_drop
|
||||
self.ftr_drop = ftr_drop
|
||||
self.activation = activation
|
||||
self.residual = check_bool(residual)
|
||||
self.layers = []
|
||||
# first layer
|
||||
self.layers.append(AttentionAggregator(
|
||||
self.ftr_dims,
|
||||
self.hidden_units[0],
|
||||
self.num_heads[0],
|
||||
self.ftr_drop,
|
||||
self.attn_drop,
|
||||
self.activation,
|
||||
residual=False))
|
||||
# intermediate layer
|
||||
for i in range(1, len(self.hidden_units)):
|
||||
self.layers.append(AttentionAggregator(
|
||||
self.hidden_units[i-1]*self.num_heads[i-1],
|
||||
self.hidden_units[i],
|
||||
self.num_heads[i],
|
||||
self.ftr_drop,
|
||||
self.attn_drop,
|
||||
self.activation,
|
||||
residual=self.residual))
|
||||
# output layer
|
||||
self.layers.append(AttentionAggregator(
|
||||
self.hidden_units[-1]*self.num_heads[-2],
|
||||
self.num_class,
|
||||
self.num_heads[-1],
|
||||
self.ftr_drop,
|
||||
self.attn_drop,
|
||||
activation=None,
|
||||
residual=False,
|
||||
output_transform='sum'))
|
||||
self.layers = nn.layer.CellList(self.layers)
|
||||
|
||||
def construct(self, training=True):
|
||||
input_data = self.features
|
||||
bias_mat = self.biases
|
||||
for cell in self.layers:
|
||||
input_data = cell(input_data, bias_mat, training)
|
||||
return input_data/self.num_heads[-1]
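# Added note (not in the original file): the output AttentionAggregator is built with
# output_transform='sum' (P.AddN over its heads), so dividing by num_heads[-1] here
# turns the summed head outputs into the head average used as classification logits,
# matching the "average" update described for the output layer in the GAT paper.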
|
|
@@ -0,0 +1,178 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for training gat"""
|
||||
from mindspore import nn
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class MaskedSoftMaxLoss(nn.Cell):
|
||||
"""Calculate masked softmax loss with l2 loss"""
|
||||
def __init__(self, num_class, label, mask, l2_coeff, params):
|
||||
super(MaskedSoftMaxLoss, self).__init__()
|
||||
self.num_class = num_class
|
||||
self.label = label
|
||||
self.mask = mask
|
||||
self.softmax = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.cast = P.Cast()
|
||||
self.l2_coeff = l2_coeff
|
||||
self.params = ParameterTuple(list(param for param in params if param.name[-4:] != 'bias'))
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.num_params = len(self.params)
|
||||
|
||||
def construct(self, logits):
|
||||
# calc l2 loss
|
||||
l2_loss = 0
|
||||
for i in range(self.num_params):
|
||||
l2_loss = l2_loss + self.l2_coeff * P.L2Loss()(self.params[i])
|
||||
|
||||
logits = P.Reshape()(logits, (-1, self.num_class))
|
||||
label = P.Reshape()(self.label, (-1, self.num_class))
|
||||
mask = P.Reshape()(self.mask, (-1,))
|
||||
|
||||
logits = self.cast(logits, mstype.float32)
|
||||
loss = self.softmax(logits, label)[0]
|
||||
mask /= self.reduce_mean(mask)
|
||||
loss *= mask
|
||||
loss = self.reduce_mean(loss)
|
||||
l2_loss = P.Cast()(l2_loss, mstype.float32)
|
||||
return loss+l2_loss
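# Added illustration (not in the original file): the mask trick above rescales the mask
# so that averaging over all nodes equals averaging over the labelled nodes only.
# A small NumPy sketch of the same arithmetic:
#
#   per_node_loss = np.array([0.9, 1.2, 0.7, 1.5], dtype=np.float32)
#   mask = np.array([1, 1, 0, 0], dtype=np.float32)   # only the first two nodes are labelled
#   mask = mask / mask.mean()                          # -> [2, 2, 0, 0]
#   (per_node_loss * mask).mean()                      # == per_node_loss[:2].mean() == 1.05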
|
||||
|
||||
|
||||
class MaskedAccuracy(nn.Cell):
|
||||
"""Calculate accuracy with mask"""
|
||||
def __init__(self, num_class, label, mask):
|
||||
super(MaskedAccuracy, self).__init__()
|
||||
self.argmax = P.Argmax(axis=1)
|
||||
self.cast = P.Cast()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.equal = P.Equal()
|
||||
self.num_class = num_class
|
||||
self.label = Tensor(label, dtype=mstype.float32)
|
||||
self.mask = Tensor(mask, dtype=mstype.float32)
|
||||
|
||||
def construct(self, logits):
|
||||
logits = P.Reshape()(logits, (-1, self.num_class))
|
||||
labels = P.Reshape()(self.label, (-1, self.num_class))
|
||||
mask = P.Reshape()(self.mask, (-1,))
|
||||
|
||||
labels = self.cast(labels, mstype.float32)
|
||||
|
||||
correct_prediction = self.equal(self.argmax(logits), self.argmax(labels))
|
||||
accuracy_all = self.cast(correct_prediction, mstype.float32)
|
||||
mask = self.cast(mask, mstype.float32)
|
||||
mask /= self.reduce_mean(mask)
|
||||
accuracy_all *= mask
|
||||
return self.reduce_mean(accuracy_all)
|
||||
|
||||
|
||||
class LossAccuracyWrapper(nn.Cell):
|
||||
"""
|
||||
Wrap the GAT model with loss and accuracy calculation; the loss includes an l2 regularization term.
|
||||
|
||||
Args:
|
||||
network (Cell): GAT network with logits calculation as output.
|
||||
num_class (int): num of class for classification.
|
||||
label (numpy.ndarray): Train Dataset label.
|
||||
mask (numpy.ndarray): Train Dataset mask.
|
||||
l2_coeff (float): l2 loss discount rate.
|
||||
"""
|
||||
def __init__(self, network, num_class, label, mask, l2_coeff):
|
||||
super(LossAccuracyWrapper, self).__init__()
|
||||
self.network = network
|
||||
label = Tensor(label, dtype=mstype.float32)
|
||||
mask = Tensor(mask, dtype=mstype.float32)
|
||||
self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params())
|
||||
self.acc_func = MaskedAccuracy(num_class, label, mask)
|
||||
|
||||
def construct(self):
|
||||
logits = self.network(training=False)
|
||||
loss = self.loss_func(logits)
|
||||
accuracy = self.acc_func(logits)
|
||||
return loss, accuracy
|
||||
|
||||
|
||||
class LossNetWrapper(nn.Cell):
|
||||
"""Wrap GAT model with loss calculation"""
|
||||
def __init__(self, network, num_class, label, mask, l2_coeff):
|
||||
super(LossNetWrapper, self).__init__()
|
||||
self.network = network
|
||||
label = Tensor(label, dtype=mstype.float32)
|
||||
mask = Tensor(mask, dtype=mstype.float32)
|
||||
params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias')
|
||||
self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params)
|
||||
|
||||
def construct(self):
|
||||
logits = self.network()
|
||||
loss = self.loss_func(logits)
|
||||
return loss
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
"""
|
||||
For network training. Wrap the loss net with the optimizer.
|
||||
|
||||
Args:
|
||||
network (Cell): GAT network with loss calculation as the output.
|
||||
optimizer (Cell): Optimizer for minimizing the loss.
|
||||
sens (Float): Backpropagation input number, default 1.0.
|
||||
"""
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=True)
|
||||
self.network = network
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
|
||||
def construct(self):
|
||||
weights = self.weights
|
||||
loss = self.network()
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(sens)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class TrainGAT(nn.Cell):
|
||||
"""
|
||||
Wrap the GAT model with everything needed for training, including the loss and optimizer.
|
||||
|
||||
Args:
|
||||
network (Cell): GAT network.
|
||||
num_class (int): num of class for classification.
|
||||
label (numpy.ndarray): Train Dataset label.
|
||||
mask (numpy.ndarray): Train Dataset mask.
|
||||
learning_rate (float): Learning rate.
|
||||
l2_coeff (float): l2 loss discount rate.
|
||||
"""
|
||||
def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff):
|
||||
super(TrainGAT, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff)
|
||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||
learning_rate=learning_rate)
|
||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
||||
self.accuracy_func = MaskedAccuracy(num_class, label, mask)
|
||||
|
||||
def construct(self):
|
||||
loss = self.loss_train_net()
|
||||
accuracy = self.accuracy_func(self.network())
|
||||
return loss, accuracy
|
|
@@ -0,0 +1,131 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Test train gat"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint
|
||||
|
||||
from src.config import GatConfig
|
||||
from src.dataset import load_and_process
|
||||
from src.gat import GAT
|
||||
from src.utils import LossAccuracyWrapper, TrainGAT
|
||||
|
||||
|
||||
def train():
|
||||
"""Train GAT model."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Data dir')
|
||||
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
|
||||
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
|
||||
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
||||
args = parser.parse_args()
|
||||
if not os.path.exists("ckpts"):
|
||||
os.mkdir("ckpts")
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False)
|
||||
# train parameters
|
||||
hid_units = GatConfig.hid_units
|
||||
n_heads = GatConfig.n_heads
|
||||
early_stopping = GatConfig.early_stopping
|
||||
lr = GatConfig.lr
|
||||
l2_coeff = GatConfig.l2_coeff
|
||||
num_epochs = GatConfig.num_epochs
|
||||
feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(args.data_dir,
|
||||
args.train_nodes_num,
|
||||
args.eval_nodes_num,
|
||||
args.test_nodes_num)
|
||||
feature_size = feature.shape[2]
|
||||
num_nodes = feature.shape[1]
|
||||
num_class = y_train.shape[2]
|
||||
|
||||
gat_net = GAT(feature,
|
||||
biases,
|
||||
feature_size,
|
||||
num_class,
|
||||
num_nodes,
|
||||
hid_units,
|
||||
n_heads,
|
||||
attn_drop=GatConfig.attn_dropout,
|
||||
ftr_drop=GatConfig.feature_dropout)
|
||||
gat_net.add_flags_recursive(fp16=True)
|
||||
|
||||
eval_net = LossAccuracyWrapper(gat_net,
|
||||
num_class,
|
||||
y_val,
|
||||
eval_mask,
|
||||
l2_coeff)
|
||||
|
||||
train_net = TrainGAT(gat_net,
|
||||
num_class,
|
||||
y_train,
|
||||
train_mask,
|
||||
lr,
|
||||
l2_coeff)
|
||||
|
||||
train_net.set_train(True)
|
||||
val_acc_max = 0.0
|
||||
val_loss_min = np.inf
|
||||
for _epoch in range(num_epochs):
|
||||
train_result = train_net()
|
||||
train_loss = train_result[0].asnumpy()
|
||||
train_acc = train_result[1].asnumpy()
|
||||
|
||||
eval_result = eval_net()
|
||||
eval_loss = eval_result[0].asnumpy()
|
||||
eval_acc = eval_result[1].asnumpy()
|
||||
|
||||
print("Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}".format(
|
||||
_epoch, train_loss, train_acc, eval_loss, eval_acc))
|
||||
if eval_acc >= val_acc_max or eval_loss < val_loss_min:
|
||||
if eval_acc >= val_acc_max and eval_loss < val_loss_min:
|
||||
val_acc_model = eval_acc
|
||||
val_loss_model = eval_loss
|
||||
_exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
|
||||
val_acc_max = np.max((val_acc_max, eval_acc))
|
||||
val_loss_min = np.min((val_loss_min, eval_loss))
|
||||
curr_step = 0
|
||||
else:
|
||||
curr_step += 1
|
||||
if curr_step == early_stopping:
|
||||
print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
|
||||
print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
|
||||
break
|
||||
gat_net_test = GAT(feature,
|
||||
biases,
|
||||
feature_size,
|
||||
num_class,
|
||||
num_nodes,
|
||||
hid_units,
|
||||
n_heads,
|
||||
attn_drop=0.0,
|
||||
ftr_drop=0.0)
|
||||
load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
|
||||
gat_net_test.add_flags_recursive(fp16=True)
|
||||
|
||||
test_net = LossAccuracyWrapper(gat_net_test,
|
||||
num_class,
|
||||
y_test,
|
||||
test_mask,
|
||||
l2_coeff)
|
||||
test_result = test_net()
|
||||
print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
|