!2119 Add gat to model zoo

Merge pull request !2119 from zhangdengcheng/master
This commit is contained in:
mindspore-ci-bot 2020-06-16 20:38:57 +08:00 committed by Gitee
commit 4485e0b55c
10 changed files with 1193 additions and 1 deletions

166
model_zoo/gat/README.md Normal file
View File

@ -0,0 +1,166 @@
<!--TOC -->
- [Graph Attention Networks Description](#graph-attention-networks-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Data Preparation](#data-preparation)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Structure](#structure)
- [Parameter configuration](#parameter-configuration)
- [Running the example](#running-the-example)
- [Usage](#usage)
- [Result](#result)
- [Description of random situation](#description-of-random-situation)
- [Others](#others)
<!--TOC -->
# Graph Attention Networks Description
Graph Attention Networks(GAT) was proposed in 2017 by Petar Veličković et al. By leveraging masked self-attentional layers to address shortcomings of prior graph based method, GAT achieved or matched state of the art performance on both transductive datasets like Cora and inductive dataset like PPI. This is an example of training GAT with Cora dataset in MindSpore.
[Paper](https://arxiv.org/abs/1710.10903): Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., & Bengio, Y. (2017). Graph attention networks. arXiv preprint arXiv:1710.10903.
# Model architecture
An illustration of multi- head attention (with K = 3 heads) by node 1 on its neighborhood can be found below:
![](https://camo.githubusercontent.com/4fe1a90e67d17a2330d7cfcddc930d5f7501750c/68747470733a2f2f7777772e64726f70626f782e636f6d2f732f71327a703170366b37396a6a6431352f6761745f6c617965722e706e673f7261773d31)
Note that according to whether this attention layer is the output layer of the network or not, the node update function can be concatenate or average.
# Dataset
Statistics of dataset used are summerized as below:
| | Cora | Citeseer |
| ------------------ | -------------: | -------------: |
| Task | Transductive | Transductive |
| # Nodes | 2708 (1 graph) | 3327 (1 graph) |
| # Edges | 5429 | 4732 |
| # Features/Node | 1433 | 3703 |
| # Classes | 7 | 6 |
| # Training Nodes | 140 | 120 |
| # Validation Nodes | 500 | 500 |
| # Test Nodes | 1000 | 1000 |
## Data Preparation
Download the dataset Cora or Citeseer provided by /kimiyoung/planetoid from github.
> Place the dataset to any path you want, the folder should include files as follows(we use Cora dataset as an example):
```
.
└─data
├─ind.cora.allx
├─ind.cora.ally
├─ind.cora.graph
├─ind.cora.test.index
├─ind.cora.tx
├─ind.cora.ty
├─ind.cora.x
└─ind.cora.y
```
> Generate dataset in mindrecord format for cora or citeseer.
>> Usage
```buildoutcfg
cd ./scripts
# SRC_PATH is the dataset file path you downloaded, DATASET_NAME is cora or citeseer
sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
```
>> Launch
```
#Generate dataset in mindrecord format for cora
sh run_process_data.sh cora
#Generate dataset in mindrecord format for citeseer
sh run_process_data.sh citeseer
```
# Features
## Mixed Precision
To ultilize the strong computation power of Ascend chip, and accelerate the training process, the mixed training method is used. MindSpore is able to cope with FP32 inputs and FP16 operators. In GAT example, the model is set to FP16 mode except for the loss calculation part.
# Environment Requirements
- Hardward (Ascend)
- Install [MindSpore](https://www.mindspore.cn/install/en).
# Structure
```shell
.
└─gat
├─README.md
├─scripts
| ├─run_process_data.sh # Generate dataset in mindrecord format
| └─run_train.sh # Launch training
|
├─src
| ├─config.py # Training configurations
| ├─dataset.py # Data preprocessing
| ├─gat.py # GAT model
| └─utils.py # Utils for training gat
|
└─train.py # Train net
```
## Parameter configuration
Parameters for training can be set in config.py.
```
"learning_rate": 0.005, # Learning rate
"num_epochs": 200, # Epoch sizes for training
"hid_units": [8], # Hidden units for attention head at each layer
"n_heads": [8, 1], # Num heads for each layer
"early_stopping": 100, # Early stop patience
"l2_coeff": 0.0005 # l2 coefficient
"attn_dropout": 0.6 # Attention dropout ratio
"feature_dropout":0.6 # Feature dropout ratio
```
# Running the example
## Usage
After Dataset is correctly generated.
```
# run train with cora dataset, DATASET_NAME is cora
sh run_train.sh [DATASET_NAME]
```
## Result
Training result will be stored in the scripts path, whose folder name begins with "train". You can find the result like the followings in log.
```
Epoch:0, train loss=1.98498 train acc=0.17143 | val loss=1.97946 val acc=0.27200
Epoch:1, train loss=1.98345 train acc=0.15000 | val loss=1.97233 val acc=0.32600
Epoch:2, train loss=1.96968 train acc=0.21429 | val loss=1.96747 val acc=0.37400
Epoch:3, train loss=1.97061 train acc=0.20714 | val loss=1.96410 val acc=0.47600
Epoch:4, train loss=1.96864 train acc=0.13571 | val loss=1.96066 val acc=0.59600
...
Epoch:195, train loss=1.45111 train_acc=0.56429 | val_loss=1.44325 val_acc=0.81200
Epoch:196, train loss=1.52476 train_acc=0.52143 | val_loss=1.43871 val_acc=0.81200
Epoch:197, train loss=1.35807 train_acc=0.62857 | val_loss=1.43364 val_acc=0.81400
Epoch:198, train loss=1.47566 train_acc=0.51429 | val_loss=1.42948 val_acc=0.81000
Epoch:199, train loss=1.56411 train_acc=0.55000 | val_loss=1.42632 val_acc=0.80600
Test loss=1.5366285, test acc=0.84199995
...
```
Results on Cora dataset is shown by table below:
| | MindSpore + Ascend910 | Tensorflow + V100 |
| ------------------------------------ | --------------------: | ----------------: |
| Accuracy | 0.830933271 | 0.828649968 |
| Training Cost(200 epochs) | 27.62298311s | 36.711862s |
| End to End Training Cost(200 epochs) | 39.074s | 50.894s |
# Description of random situation
GAT model contains lots of dropout operations, if you want to disable dropout, set the attn_dropout and feature_dropout to 0 in src/config.py. Note that this operation will cause the accuracy drop to approximately 80%.
# Others
GAT model is verified on Ascend environment, not on CPU or GPU.

View File

@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_train.sh [SRC_PATH] [DATASET_NAME]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
SRC_PATH=$(get_real_path $1)
echo $SRC_PATH
DATASET_NAME=$2
echo $DATASET_NAME
if [ ! -d data_mr ]; then
mkdir data_mr
else
echo data_mr exist
fi
MINDRECORD_PATH=`pwd`/data_mr
rm -f $MINDRECORD_PATH/*
cd ../../../example/graph_to_mindrecord || exit
python writer.py --mindrecord_script $DATASET_NAME \
--mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \
--mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 18 \
--mindrecord_page_size_by_bit 20 \
--graph_api_args "$SRC_PATH"
cd - || exit

View File

@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_train.sh [DATASET_NAME]"
exit 1
fi
DATASET_NAME=$1
echo $DATASET_NAME
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
env > env.log
echo "start training for device $DEVICE_ID"
if [ $DATASET_NAME == cora ]
then
python train.py --data_dir=../data_mr/$DATASET_NAME &> log &
fi
if [ $DATASET_NAME == citeseer ]
then
python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=120 &> log &
fi
cd ..

View File

View File

@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train configs for training gat"""
class GatConfig():
lr = 0.005
num_epochs = 200
hid_units = [8]
n_heads = [8, 1]
early_stopping = 100
l2_coeff = 0.0005
attn_dropout = 0.6
feature_dropout = 0.6

View File

@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Preprocess data obtained for training"""
import numpy as np
import mindspore.dataset as ds
def adj_to_bias(adj):
"""Add self loop to adj and make sure only one hop neighbors are engaged in computing"""
num_graphs = adj.shape[0]
adj_temp = np.empty(adj.shape)
for i in range(num_graphs):
adj_temp[i] = adj[i] + np.eye(adj.shape[1])
return -1e9 * (1.0 - adj_temp)
def get_biases_features_labels(data_dir):
"""Get biases, features, labels from Dataset"""
g = ds.GraphData(data_dir)
nodes = g.get_all_nodes(0)
nodes_list = nodes.tolist()
row_tensor = g.get_node_feature(nodes_list, [1, 2])
features = row_tensor[0]
features = features[np.newaxis]
labels = row_tensor[1]
nodes_num = labels.shape[0]
class_num = labels.max() + 1
labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)
neighbor = g.get_all_neighbors(nodes_list, 0)
node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
for index, value in np.ndenumerate(neighbor):
if value >= 0 and index[1] > 0:
adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
adj = adj[np.newaxis]
biases = adj_to_bias(adj)
return biases, features, labels_onehot
def get_mask(total, begin, end):
"""Generate mask according to begin and end position"""
mask = np.zeros([total]).astype(np.float32)
mask[begin:end] = 1
return np.array(mask, dtype=np.bool)
def load_and_process(data_dir, train_node_num, eval_node_num, test_node_num):
"""Load cora dataset and preprocessing"""
biases, feature, label = get_biases_features_labels(data_dir)
# split training, validation and testing set
nodes_num = label.shape[0]
train_mask = get_mask(nodes_num, 0, train_node_num)
eval_mask = get_mask(nodes_num, train_node_num, train_node_num + eval_node_num)
test_mask = get_mask(nodes_num, nodes_num - test_node_num, nodes_num)
y_train = np.zeros(label.shape)
y_val = np.zeros(label.shape)
y_test = np.zeros(label.shape)
y_train[train_mask, :] = label[train_mask, :]
y_val[eval_mask, :] = label[eval_mask, :]
y_test[test_mask, :] = label[test_mask, :]
y_train = y_train[np.newaxis]
y_val = y_val[np.newaxis]
y_test = y_test[np.newaxis]
train_mask = train_mask[np.newaxis]
eval_mask = eval_mask[np.newaxis]
test_mask = test_mask[np.newaxis]
return feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask

496
model_zoo/gat/src/gat.py Normal file
View File

@ -0,0 +1,496 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Aggregator."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore._extends import cell_attr_register
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import check_int_positive, check_bool
from mindspore.nn.layer.activation import get_activation
class GNNFeatureTransform(nn.Cell):
r"""
The GNN featuren transform layer for input.
Applies linear transformation for the input feature. This layer implements the operation as:
.. math::
\text{outputs} = \text{inputs} * \text{kernel} + \text{bias},
where :math:`\text{activation}` is the activation function passed as the activation
argument (if passed in),:math:`\text{activation}` is a weight matrix with the same
data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
with the same data type as the inputs created by the layer (only if has_bias is True).
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
Raises:
ValueError: If weight_init or bias_init shape is incorrect.
Inputs:
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*B, N, C)`,
where :math:`*B` represents the batch size which can be multidimensional, :math:`N` and :math:`C` are the
size of the last two dimensions. If `transpose_a` is True, its shape should be :math:`(*B, C, N)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
Examples:
>>> net = nn.Dense(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
"""
@cell_attr_register
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True):
super(GNNFeatureTransform, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape()[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
def construct(self, x):
tensor_shape = F.shape(x)
input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2]))
output = self.matmul(input_feature, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels))
return output
def extend_repr(self):
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
if self.has_bias:
str_info = str_info + ', bias={}'.format(self.bias)
return str_info
class _BaseAggregator(nn.Cell):
"""
Base Aggregator of GNN
Args:
feature_in_dim (int): Node or edge input feature dim.
feature_out_dim (int): Node or edge outpout feature dim.
use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
Examples:
>>> class MyAggregator(_BaseAggregator):
>>> def __init__(self):
>>> super(MyAggregator, self).__init__(self, feature_in_dim, feature_out_dim)
>>> self.reduce_mean = P.ReduceSum()
>>>
>>> def construct(self, x):
>>> return self.reduce_mean(x, 1)
"""
def __init__(self,
feature_in_dim,
feature_out_dim,
use_fc=True,
weight_init="normal",
bias_init="zeros",
has_bias=True,
dropout_ratio=None,
activation=None):
super(_BaseAggregator, self).__init__()
self.in_dim = feature_in_dim
self.out_dim = feature_out_dim
self.use_fc = use_fc
if self.use_fc:
self.weight_init = weight_init
self.bias_init = bias_init
self.has_bias = has_bias
self.fc = GNNFeatureTransform(self.in_dim,
self.out_dim,
weight_init=self.weight_init,
bias_init=self.bias_init,
has_bias=self.has_bias)
self.dropout_ratio = dropout_ratio
if self.dropout_ratio is not None:
self.dropout = nn.Dropout(keep_prob=self.dropout_ratio)
self.dropout_flag = self.dropout_ratio is not None
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
def construct(self, **kward):
"""Must be overridden by all subclasses."""
raise NotImplementedError
class MeanAggregator(_BaseAggregator):
"""
Mean Aggregator of GNN
Args:
feature_in_dim (int): Node or edge input feature dim.
feature_out_dim (int): Node or edge outpout feature dim.
use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
Examples:
>>> net = MeanAggregator(32, 64, activation="relu", dropout=0.5)
>>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtypy=np.float32))
>>> output = net(input_data)
"""
def __init__(self,
feature_in_dim,
feature_out_dim,
use_fc=True,
weight_init="normal",
bias_init="zeros",
has_bias=True,
dropout_ratio=None,
activation=None):
super(MeanAggregator, self).__init__(
feature_in_dim,
feature_out_dim,
use_fc,
weight_init,
bias_init,
has_bias,
dropout_ratio,
activation)
self.reduce_mean = P.ReduceMean(keep_dims=False)
def construct(self, input_feature):
if self.use_fc:
input_feature = self.fc(input_feature)
if self.dropout_flag:
input_feature = self.dropout(input_feature)
if self.activation_flag:
input_feature = self.activation(input_feature)
output_feature = self.reduce_mean(input_feature, 1)
return output_feature
class AttentionHead(nn.Cell):
"""
Attention Head for Graph Attention Networks.
Args:
in_channel (int): The number of input channel, input feature dim.
out_channel (int): The number of output channel, output feature dim.
in_drop_ratio (float): Input feature dropout ratio, default 0.0.
coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
residual (bool): Whether to use residual connection, default False.
coef_activation (Cell): The attention coefficient activation function,
default nn.LeakyReLU().
activation (Cell): The output activation function, default nn.ELU().
Inputs:
- **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
- **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes).
Examples:
>>> head = AttentionHead(1433,
8,
in_drop_ratio=0.6,
coef_drop_ratio=0.6,
residual=False)
>>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtypy=np.float32))
>>> output = net(input_data)
"""
def __init__(self,
in_channel,
out_channel,
in_drop_ratio=0.0,
coef_drop_ratio=0.0,
residual=False,
coef_activation=nn.LeakyReLU(),
activation=nn.ELU()):
super(AttentionHead, self).__init__()
self.in_channel = check_int_positive(in_channel)
self.out_channel = check_int_positive(out_channel)
self.in_drop_ratio = in_drop_ratio
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
self.feature_transform = GNNFeatureTransform(
in_channels=self.in_channel,
out_channels=self.out_channel,
has_bias=False,
weight_init='XavierUniform')
self.f_1_transform = GNNFeatureTransform(
in_channels=self.out_channel,
out_channels=1,
weight_init='XavierUniform')
self.f_2_transform = GNNFeatureTransform(
in_channels=self.out_channel,
out_channels=1,
weight_init='XavierUniform')
self.softmax = nn.Softmax()
self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
self.matmul = P.MatMul()
self.bias_add = P.BiasAdd()
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
self.residual = check_bool(residual)
if self.residual:
if in_channel != out_channel:
self.residual_transform_flag = True
self.residual_transform = GNNFeatureTransform(
in_channels=self.in_channel,
out_channels=self.out_channel)
else:
self.residual_transform = None
self.coef_activation = coef_activation
self.activation = activation
def construct(self, input_feature, bias_mat, training=True):
if training is True:
input_feature = self.in_drop(input_feature)
feature = self.feature_transform(input_feature)
# self attention
f_1 = self.f_1_transform(feature)
f_2 = self.f_2_transform(feature)
logits = f_1 + P.Transpose()(f_2, (0, 2, 1))
logits = self.coef_activation(logits) + bias_mat
coefs = self.softmax(logits)
if training is True:
coefs = self.coef_drop(coefs)
feature = self.in_drop_2(feature)
coefs = P.Squeeze(0)(coefs)
feature = P.Squeeze(0)(feature)
ret = self.matmul(coefs, feature)
ret = self.bias_add(ret, self.bias)
ret = P.ExpandDims()(ret, 0)
# residual connection
if self.residual:
if self.residual_transform_flag:
res = self.residual_transform(input_feature)
ret = ret + res
else:
ret = ret + input_feature
# activation
if self.activation is not None:
ret = self.activation(ret)
return ret
class AttentionAggregator(nn.Cell):
"""
Attention Head for Graph Attention Networkscan be regarded as one
GAT layer.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
num_heads (int): Number of attention heads for this layer, default 1.
in_drop_ratio (float): Input feature dropout ratio, default 0.0.
coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
activation (Cell): The output activation function, default nn.ELU().
residual (bool): Whether to use residual connection, default False.
output_transform (str['concat', 'sum']): output transform for a layer,
default 'concat'
Inputs:
- **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
- **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes).
Examples:
>>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
>>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
>>> net = AttentionAggregator(1433,
8,
8)
>>> net(input_data, biases)
"""
def __init__(self,
in_channels,
out_channels,
num_heads=1,
in_drop=0.0,
coef_drop=0.0,
activation=nn.ELU(),
residual=False,
output_transform='concat'):
super(AttentionAggregator, self).__init__()
self.num_heads = num_heads
self.attns = []
for _ in range(num_heads):
self.attns.append(AttentionHead(in_channels,
out_channels,
in_drop_ratio=in_drop,
coef_drop_ratio=coef_drop,
activation=activation,
residual=residual))
self.attns = nn.layer.CellList(self.attns)
if output_transform == 'concat':
self.out_trans = P.Concat(-1)
elif output_transform == 'sum':
self.out_trans = P.AddN()
else:
raise ValueError("output_transform must be either 'concat' or 'sum'")
def construct(self, input_data, bias_mat, training=True):
res = ()
for i in range(self.num_heads):
res += (self.attns[i](input_data, bias_mat, training),)
return self.out_trans(res)
class GAT(nn.Cell):
"""
Graph Attention Network
Args:
ftr_dims (int): Initial feature dimensions.
num_class (int): Num of class to identify.
num_nodes (int): Num of nodes in this graph.
hidden_units (list[int]): Num of hidden units at each layer.
num_heads (list[int]): Num of heads at each layer.
attn_drop (float): Drop out ratio of attention coefficient,
default 0.0.
ftr_drop (float): Drop out ratio of feature, default 0.0.
activation (Cell): Activation Function for output layer, default
nn.Elu().
residual (bool): Whether to use residual connection between
intermediate layers, default False.
Examples:
>>> ft_sizes = 1433
>>> num_class = 7
>>> num_nodes = 2708
>>> hid_units = [8]
>>> n_heads = [8, 1]
>>> activation = nn.ELU()
>>> residual = False
>>> input_data = np.array(np.random.rand(1, 2708, 1433))
>>> biases = np.array(np.random.rand(1, 2708, 2708))
>>> net = GAT(ft_sizes,
num_class,
num_nodes,
hidden_units=hid_units,
num_heads=n_heads,
attn_drop=0.6,
ftr_drop=0.6,
activation=activation,
residual=residual)
>>> output = net(input_data, biases)
"""
def __init__(self,
features,
biases,
ftr_dims,
num_class,
num_nodes,
hidden_units,
num_heads,
attn_drop=0.0,
ftr_drop=0.0,
activation=nn.ELU(),
residual=False):
super(GAT, self).__init__()
self.features = Tensor(features)
self.biases = Tensor(biases)
self.ftr_dims = check_int_positive(ftr_dims)
self.num_class = check_int_positive(num_class)
self.num_nodes = check_int_positive(num_nodes)
self.hidden_units = hidden_units
self.num_heads = num_heads
self.attn_drop = attn_drop
self.ftr_drop = ftr_drop
self.activation = activation
self.residual = check_bool(residual)
self.layers = []
# first layer
self.layers.append(AttentionAggregator(
self.ftr_dims,
self.hidden_units[0],
self.num_heads[0],
self.ftr_drop,
self.attn_drop,
self.activation,
residual=False))
# intermediate layer
for i in range(1, len(self.hidden_units)):
self.layers.append(AttentionAggregator(
self.hidden_units[i-1]*self.num_heads[i-1],
self.hidden_units[i],
self.num_heads[i],
self.ftr_drop,
self.attn_drop,
self.activation,
residual=self.residual))
# output layer
self.layers.append(AttentionAggregator(
self.hidden_units[-1]*self.num_heads[-2],
self.num_class,
self.num_heads[-1],
self.ftr_drop,
self.attn_drop,
activation=None,
residual=False,
output_transform='sum'))
self.layers = nn.layer.CellList(self.layers)
def construct(self, training=True):
input_data = self.features
bias_mat = self.biases
for cell in self.layers:
input_data = cell(input_data, bias_mat, training)
return input_data/self.num_heads[-1]

178
model_zoo/gat/src/utils.py Normal file
View File

@ -0,0 +1,178 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for training gat"""
from mindspore import nn
from mindspore.common.parameter import ParameterTuple
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class MaskedSoftMaxLoss(nn.Cell):
"""Calculate masked softmax loss with l2 loss"""
def __init__(self, num_class, label, mask, l2_coeff, params):
super(MaskedSoftMaxLoss, self).__init__()
self.num_class = num_class
self.label = label
self.mask = mask
self.softmax = P.SoftmaxCrossEntropyWithLogits()
self.reduce_mean = P.ReduceMean()
self.cast = P.Cast()
self.l2_coeff = l2_coeff
self.params = ParameterTuple(list(param for param in params if param.name[-4:] != 'bias'))
self.reduce_sum = P.ReduceSum()
self.num_params = len(self.params)
def construct(self, logits):
# calc l2 loss
l2_loss = 0
for i in range(self.num_params):
l2_loss = l2_loss + self.l2_coeff * P.L2Loss()(self.params[i])
logits = P.Reshape()(logits, (-1, self.num_class))
label = P.Reshape()(self.label, (-1, self.num_class))
mask = P.Reshape()(self.mask, (-1,))
logits = self.cast(logits, mstype.float32)
loss = self.softmax(logits, label)[0]
mask /= self.reduce_mean(mask)
loss *= mask
loss = self.reduce_mean(loss)
l2_loss = P.Cast()(l2_loss, mstype.float32)
return loss+l2_loss
class MaskedAccuracy(nn.Cell):
"""Calculate accuracy with mask"""
def __init__(self, num_class, label, mask):
super(MaskedAccuracy, self).__init__()
self.argmax = P.Argmax(axis=1)
self.cast = P.Cast()
self.reduce_mean = P.ReduceMean()
self.equal = P.Equal()
self.num_class = num_class
self.label = Tensor(label, dtype=mstype.float32)
self.mask = Tensor(mask, dtype=mstype.float32)
def construct(self, logits):
logits = P.Reshape()(logits, (-1, self.num_class))
labels = P.Reshape()(self.label, (-1, self.num_class))
mask = P.Reshape()(self.mask, (-1,))
labels = self.cast(labels, mstype.float32)
correct_prediction = self.equal(self.argmax(logits), self.argmax(labels))
accuracy_all = self.cast(correct_prediction, mstype.float32)
mask = self.cast(mask, mstype.float32)
mask /= self.reduce_mean(mask)
accuracy_all *= mask
return self.reduce_mean(accuracy_all)
class LossAccuracyWrapper(nn.Cell):
"""
Warp GAT model with loss calculation and accuracy calculation, loss is calculated with l2 loss.
Args:
network (Cell): GAT network with logits calculation as output.
num_class (int): num of class for classification.
label (numpy.ndarray): Train Dataset label.
mask (numpy.ndarray): Train Dataset mask.
l2_coeff (float): l2 loss discount rate.
"""
def __init__(self, network, num_class, label, mask, l2_coeff):
super(LossAccuracyWrapper, self).__init__()
self.network = network
label = Tensor(label, dtype=mstype.float32)
mask = Tensor(mask, dtype=mstype.float32)
self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params())
self.acc_func = MaskedAccuracy(num_class, label, mask)
def construct(self):
logits = self.network(training=False)
loss = self.loss_func(logits)
accuracy = self.acc_func(logits)
return loss, accuracy
class LossNetWrapper(nn.Cell):
"""Wrap GAT model with loss calculation"""
def __init__(self, network, num_class, label, mask, l2_coeff):
super(LossNetWrapper, self).__init__()
self.network = network
label = Tensor(label, dtype=mstype.float32)
mask = Tensor(mask, dtype=mstype.float32)
params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias')
self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params)
def construct(self):
logits = self.network()
loss = self.loss_func(logits)
return loss
class TrainOneStepCell(nn.Cell):
"""
For network training. Warp the loss net with optimizer.
Args:
network (Cell): GAT network with loss calculation as the output.
optimizer (Cell): Optimizer for minimize the loss.
sens (Float): Backpropagation input number, default 1.0.
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=True)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
def construct(self):
weights = self.weights
loss = self.network()
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(sens)
return F.depend(loss, self.optimizer(grads))
class TrainGAT(nn.Cell):
"""
Warp GAT model with everything needed for training, include loss, optimizer ,etc.
Args:
network (Cell): GAT network.
num_class (int): num of class for classification.
label (numpy.ndarray): Train Dataset label.
mask (numpy.ndarray): Train Dataset mask.
learning_rate (float): Learning rate.
l2_coeff (float): l2 loss discount rate.
"""
def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff):
super(TrainGAT, self).__init__(auto_prefix=False)
self.network = network
loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff)
optimizer = nn.Adam(loss_net.trainable_params(),
learning_rate=learning_rate)
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy_func = MaskedAccuracy(num_class, label, mask)
def construct(self):
loss = self.loss_train_net()
accuracy = self.accuracy_func(self.network())
return loss, accuracy

131
model_zoo/gat/train.py Normal file
View File

@ -0,0 +1,131 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test train gat"""
import argparse
import os
import numpy as np
import mindspore.context as context
from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint
from src.config import GatConfig
from src.dataset import load_and_process
from src.gat import GAT
from src.utils import LossAccuracyWrapper, TrainGAT
def train():
"""Train GAT model."""
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Data dir')
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
args = parser.parse_args()
if not os.path.exists("ckpts"):
os.mkdir("ckpts")
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False)
# train parameters
hid_units = GatConfig.hid_units
n_heads = GatConfig.n_heads
early_stopping = GatConfig.early_stopping
lr = GatConfig.lr
l2_coeff = GatConfig.l2_coeff
num_epochs = GatConfig.num_epochs
feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(args.data_dir,
args.train_nodes_num,
args.eval_nodes_num,
args.test_nodes_num)
feature_size = feature.shape[2]
num_nodes = feature.shape[1]
num_class = y_train.shape[2]
gat_net = GAT(feature,
biases,
feature_size,
num_class,
num_nodes,
hid_units,
n_heads,
attn_drop=GatConfig.attn_dropout,
ftr_drop=GatConfig.feature_dropout)
gat_net.add_flags_recursive(fp16=True)
eval_net = LossAccuracyWrapper(gat_net,
num_class,
y_val,
eval_mask,
l2_coeff)
train_net = TrainGAT(gat_net,
num_class,
y_train,
train_mask,
lr,
l2_coeff)
train_net.set_train(True)
val_acc_max = 0.0
val_loss_min = np.inf
for _epoch in range(num_epochs):
train_result = train_net()
train_loss = train_result[0].asnumpy()
train_acc = train_result[1].asnumpy()
eval_result = eval_net()
eval_loss = eval_result[0].asnumpy()
eval_acc = eval_result[1].asnumpy()
print("Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}".format(
_epoch, train_loss, train_acc, eval_loss, eval_acc))
if eval_acc >= val_acc_max or eval_loss < val_loss_min:
if eval_acc >= val_acc_max and eval_loss < val_loss_min:
val_acc_model = eval_acc
val_loss_model = eval_loss
_exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
val_acc_max = np.max((val_acc_max, eval_acc))
val_loss_min = np.min((val_loss_min, eval_loss))
curr_step = 0
else:
curr_step += 1
if curr_step == early_stopping:
print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
break
gat_net_test = GAT(feature,
biases,
feature_size,
num_class,
num_nodes,
hid_units,
n_heads,
attn_drop=0.0,
ftr_drop=0.0)
load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
gat_net_test.add_flags_recursive(fp16=True)
test_net = LossAccuracyWrapper(gat_net_test,
num_class,
y_test,
test_mask,
l2_coeff)
test_result = test_net()
print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))
if __name__ == "__main__":
train()

View File

@ -110,4 +110,4 @@ Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.
Optimization Finished!
Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083
...
```
```