mindspore/tests/st/gnn/gat.py

119 lines
4.3 KiB
Python

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph Attention Networks."""
import mindspore.nn as nn
from mindspore._checkparam import Validator
from aggregator import AttentionAggregator
class GAT(nn.Cell):
    """
    Graph Attention Network.

    Stacks ``AttentionAggregator`` cells: one input layer, one layer per
    additional entry of ``hidden_units``, and an output layer whose per-head
    results are combined with ``output_transform='sum'`` and then divided by
    the number of output heads in ``construct``.

    Args:
        ftr_dims (int): Initial feature dimensions.
        num_class (int): Num of classes to identify.
        num_nodes (int): Num of nodes in this graph.
        hidden_units (list[int]): Num of hidden units at each hidden layer.
            Must contain at least one entry.
        num_heads (list[int]): Num of attention heads at each layer: one entry
            per hidden layer followed by one entry for the output layer, i.e.
            ``len(num_heads) == len(hidden_units) + 1``.
        attn_drop (float): Dropout ratio of attention coefficients,
            default 0.0.
        ftr_drop (float): Dropout ratio of features, default 0.0.
        activation (Cell): Activation function passed to the input and
            intermediate (hidden) layers; the output layer is built with
            ``activation=None``. Default nn.ELU(). Note the default instance
            is created once, when the class is defined, and is shared by every
            GAT constructed without an explicit activation.
        residual (bool): Whether to use residual connections between
            intermediate layers, default False.

    Raises:
        ValueError: If ``hidden_units`` is empty, or if ``num_heads`` does not
            have exactly ``len(hidden_units) + 1`` entries.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> ft_sizes = 1433
        >>> num_class = 7
        >>> num_nodes = 2708
        >>> hid_units = [8]
        >>> n_heads = [8, 1]
        >>> activation = nn.ELU()
        >>> residual = False
        >>> input_data = Tensor(
        ...     np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> biases = Tensor(
        ...     np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> net = GAT(ft_sizes,
        ...           num_class,
        ...           num_nodes,
        ...           hidden_units=hid_units,
        ...           num_heads=n_heads,
        ...           attn_drop=0.6,
        ...           ftr_drop=0.6,
        ...           activation=activation,
        ...           residual=residual)
        >>> output = net(input_data, biases)
    """

    def __init__(self,
                 ftr_dims,
                 num_class,
                 num_nodes,
                 hidden_units,
                 num_heads,
                 attn_drop=0.0,
                 ftr_drop=0.0,
                 activation=nn.ELU(),
                 residual=False):
        super(GAT, self).__init__()
        self.ftr_dims = Validator.check_positive_int(ftr_dims)
        self.num_class = Validator.check_positive_int(num_class)
        self.num_nodes = Validator.check_positive_int(num_nodes)
        # Validate the layer layout up front: a mismatch here would otherwise
        # silently build aggregators with incompatible input widths.
        if len(hidden_units) == 0:
            raise ValueError("hidden_units must contain at least one entry.")
        if len(num_heads) != len(hidden_units) + 1:
            raise ValueError(
                "num_heads must have len(hidden_units) + 1 = {} entries (one "
                "per hidden layer plus one for the output layer), but got "
                "{}.".format(len(hidden_units) + 1, len(num_heads)))
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.attn_drop = attn_drop
        self.ftr_drop = ftr_drop
        self.activation = activation
        self.residual = Validator.check_bool(residual)

        # First layer: raw features in; no residual connection since the
        # input width generally differs from the layer's output width.
        layers = [AttentionAggregator(
            self.ftr_dims,
            hidden_units[0],
            num_heads[0],
            self.ftr_drop,
            self.attn_drop,
            self.activation,
            residual=False)]

        # Intermediate layers: each consumes the concatenated heads of the
        # previous layer, hence an input width of units * heads.
        for prev_units, prev_heads, units, heads in zip(
                hidden_units, num_heads, hidden_units[1:], num_heads[1:]):
            layers.append(AttentionAggregator(
                prev_units * prev_heads,
                units,
                heads,
                self.ftr_drop,
                self.attn_drop,
                self.activation,
                residual=self.residual))

        # Output layer: its input width comes from the LAST hidden layer,
        # i.e. index len(hidden_units) - 1 of num_heads; its own head count is
        # the final entry of num_heads.
        last_hidden_heads = num_heads[len(hidden_units) - 1]
        layers.append(AttentionAggregator(
            hidden_units[-1] * last_hidden_heads,
            self.num_class,
            num_heads[-1],
            self.ftr_drop,
            self.attn_drop,
            activation=None,
            residual=False,
            output_transform='sum'))

        self.layers = nn.layer.CellList(layers)

    def construct(self, input_data, bias_mat):
        # Feed the features through every aggregator with the same bias matrix.
        for cell in self.layers:
            input_data = cell(input_data, bias_mat)
        # The output layer is built with output_transform='sum'; dividing by
        # the output head count presumably turns that sum into a mean over
        # heads -- confirm against AttentionAggregator.
        return input_data / self.num_heads[-1]