From 9fbc519ebb3e5a9245381506081cd7c3bd9a7698 Mon Sep 17 00:00:00 2001
From: zhangdengcheng
Date: Thu, 14 May 2020 08:12:54 +0000
Subject: [PATCH] Add graph attention networks model and test file

---
 tests/st/gnn/aggregator.py     |  16 ++++-
 tests/st/gnn/gat.py            | 118 +++++++++++++++++++++++++++++++++
 tests/st/gnn/test_gat_model.py |  47 +++++++++++++
 3 files changed, 178 insertions(+), 3 deletions(-)
 create mode 100644 tests/st/gnn/gat.py
 create mode 100644 tests/st/gnn/test_gat_model.py

diff --git a/tests/st/gnn/aggregator.py b/tests/st/gnn/aggregator.py
index 5e208a2329a..d04cf1b57b6 100644
--- a/tests/st/gnn/aggregator.py
+++ b/tests/st/gnn/aggregator.py
@@ -319,7 +319,8 @@ class AttentionHead(nn.Cell):
         else:
             ret = ret + input_feature
         # activation
-        ret = self.activation(ret)
+        if self.activation is not None:
+            ret = self.activation(ret)
         return ret
 
 
@@ -336,6 +337,8 @@ class AttentionAggregator(nn.Cell):
         coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
         activation (Cell): The output activation function, default nn.ELU().
         residual (bool): Whether to use residual connection, default False.
+        output_transform (str): Transform used to combine the attention head
+            outputs, either 'concat' or 'sum', default 'concat'.
 
     Inputs:
         - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
@@ -356,7 +359,8 @@
                  in_drop=0.0,
                  coef_drop=0.0,
                  activation=nn.ELU(),
-                 residual=False):
+                 residual=False,
+                 output_transform='concat'):
         super(AttentionAggregator, self).__init__()
         self.num_heads = num_heads
         self.attns = []
@@ -368,9 +372,15 @@
                                            activation=activation,
                                            residual=residual))
         self.attns = nn.layer.CellList(self.attns)
+        if output_transform == 'concat':
+            self.out_trans = P.Concat(-1)
+        elif output_transform == 'sum':
+            self.out_trans = P.AddN()
+        else:
+            raise ValueError("output_transform must be 'concat' or 'sum'.")
 
     def construct(self, input_data, bias_mat):
         res = ()
         for i in range(self.num_heads):
             res += (self.attns[i](input_data, bias_mat),)
-        return P.Concat(-1)(res)
+        return self.out_trans(res)
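For review context only (not part of the patch): a minimal sketch of what the new output_transform option does to the aggregator's output. The dimensions are made up for illustration; the aggregator is called positionally as (input_dim, output_dim, num_heads), matching its use in gat.py below.

    import numpy as np
    from mindspore import Tensor
    from aggregator import AttentionAggregator

    feature = Tensor(np.random.rand(1, 16, 64).astype(np.float32))
    bias_mat = Tensor(np.random.rand(1, 16, 16).astype(np.float32))

    # 'concat' (the default) joins the 4 head outputs along the last axis,
    # so the output shape is (1, 16, 8 * 4) = (1, 16, 32).
    concat_agg = AttentionAggregator(64, 8, 4)
    out_concat = concat_agg(feature, bias_mat)

    # 'sum' adds the head outputs element-wise via P.AddN() instead, keeping
    # the feature dimension at 8: output shape (1, 16, 8). The GAT output
    # layer below uses this mode and then averages over the heads.
    sum_agg = AttentionAggregator(64, 8, 4, output_transform='sum')
    out_sum = sum_agg(feature, bias_mat)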
diff --git a/tests/st/gnn/gat.py b/tests/st/gnn/gat.py
new file mode 100644
index 00000000000..a386f562c78
--- /dev/null
+++ b/tests/st/gnn/gat.py
@@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph Attention Networks."""
import mindspore.nn as nn
from mindspore._checkparam import check_bool, check_int_positive

from aggregator import AttentionAggregator


class GAT(nn.Cell):
    """
    Graph Attention Network.

    Args:
        ftr_dims (int): Initial feature dimensions.
        num_class (int): Number of classes to identify.
        num_nodes (int): Number of nodes in this graph.
        hidden_units (list[int]): Number of hidden units in each layer.
        num_heads (list[int]): Number of attention heads in each layer.
        attn_drop (float): Dropout ratio of the attention coefficients,
            default 0.0.
        ftr_drop (float): Dropout ratio of the input features, default 0.0.
        activation (Cell): Activation function for the output layer, default
            nn.ELU().
        residual (bool): Whether to use residual connections between
            intermediate layers, default False.

    Examples:
        >>> ft_sizes = 1433
        >>> num_class = 7
        >>> num_nodes = 2708
        >>> hid_units = [8]
        >>> n_heads = [8, 1]
        >>> activation = nn.ELU()
        >>> residual = False
        >>> input_data = Tensor(
                np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> net = GAT(ft_sizes,
                      num_class,
                      num_nodes,
                      hidden_units=hid_units,
                      num_heads=n_heads,
                      attn_drop=0.6,
                      ftr_drop=0.6,
                      activation=activation,
                      residual=residual)
        >>> output = net(input_data, biases)
    """

    def __init__(self,
                 ftr_dims,
                 num_class,
                 num_nodes,
                 hidden_units,
                 num_heads,
                 attn_drop=0.0,
                 ftr_drop=0.0,
                 activation=nn.ELU(),
                 residual=False):
        super(GAT, self).__init__()
        self.ftr_dims = check_int_positive(ftr_dims)
        self.num_class = check_int_positive(num_class)
        self.num_nodes = check_int_positive(num_nodes)
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.attn_drop = attn_drop
        self.ftr_drop = ftr_drop
        self.activation = activation
        self.residual = check_bool(residual)
        self.layers = []
        # first layer
        self.layers.append(AttentionAggregator(
            self.ftr_dims,
            self.hidden_units[0],
            self.num_heads[0],
            self.ftr_drop,
            self.attn_drop,
            self.activation,
            residual=False))
        # intermediate layers; each consumes the concatenated heads of the previous layer
        for i in range(1, len(self.hidden_units)):
            self.layers.append(AttentionAggregator(
                self.hidden_units[i-1]*self.num_heads[i-1],
                self.hidden_units[i],
                self.num_heads[i],
                self.ftr_drop,
                self.attn_drop,
                self.activation,
                residual=self.residual))
        # output layer: heads are summed here and averaged in construct
        self.layers.append(AttentionAggregator(
            self.hidden_units[-1]*self.num_heads[-2],
            self.num_class,
            self.num_heads[-1],
            self.ftr_drop,
            self.attn_drop,
            activation=None,
            residual=False,
            output_transform='sum'))
        self.layers = nn.layer.CellList(self.layers)

    def construct(self, input_data, bias_mat):
        for cell in self.layers:
            input_data = cell(input_data, bias_mat)
        return input_data / self.num_heads[-1]  # average over the output heads
diff --git a/tests/st/gnn/test_gat_model.py b/tests/st/gnn/test_gat_model.py
new file mode 100644
index 00000000000..ed511481ccc
--- /dev/null
+++ b/tests/st/gnn/test_gat_model.py
@@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# ============================================================================ +"""test gat model.""" +import numpy as np + +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +from mindspore.common.api import _executor +from gat import GAT + +context.set_context(mode=context.GRAPH_MODE) + + +def test_GAT(): + ft_sizes = 1433 + num_class = 7 + num_nodes = 2708 + hid_units = [8] + n_heads = [8, 1] + activation = nn.ELU() + residual = False + input_data = Tensor( + np.array(np.random.rand(1, 2708, 1433), dtype=np.float32)) + biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32)) + net = GAT(ft_sizes, + num_class, + num_nodes, + hidden_units=hid_units, + num_heads=n_heads, + attn_drop=0.6, + ftr_drop=0.6, + activation=activation, + residual=residual) + _executor.compile(net, input_data, biases)