From 825d9740fa35abac3fece867ecc3dea1af6e7004 Mon Sep 17 00:00:00 2001 From: zhangdengcheng Date: Thu, 30 Apr 2020 03:23:58 +0000 Subject: [PATCH] Fixed the bug that mean aggregator argument can not pass to base class and add attention head for GAT --- tests/st/gnn/aggregator.py | 170 ++++++++++++++++++++++++++-- tests/st/gnn/test_gnn_aggregator.py | 21 +++- 2 files changed, 182 insertions(+), 9 deletions(-) diff --git a/tests/st/gnn/aggregator.py b/tests/st/gnn/aggregator.py index 18f189d979f..5e208a2329a 100644 --- a/tests/st/gnn/aggregator.py +++ b/tests/st/gnn/aggregator.py @@ -64,7 +64,7 @@ class GNNFeatureTransform(nn.Cell): [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ] [ 1.0739875 4.0155234 0.94188046 -5.459526 ]] """ - @cell_attr_register(attrs=['has_bias', 'activation']) + @cell_attr_register def __init__(self, in_channels, out_channels, @@ -125,7 +125,7 @@ class _BaseAggregator(nn.Cell): same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None. - activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. Examples: >>> class MyAggregator(_BaseAggregator): @@ -203,12 +203,12 @@ class MeanAggregator(_BaseAggregator): super(MeanAggregator, self).__init__( feature_in_dim, feature_out_dim, - use_fc=True, - weight_init="normal", - bias_init="zeros", - has_bias=True, - dropout_ratio=None, - activation=None) + use_fc, + weight_init, + bias_init, + has_bias, + dropout_ratio, + activation) self.reduce_mean = P.ReduceMean(keep_dims=False) def construct(self, input_feature): @@ -220,3 +220,157 @@ class MeanAggregator(_BaseAggregator): input_feature = self.activation(input_feature) output_feature = self.reduce_mean(input_feature, 1) return output_feature + + +class AttentionHead(nn.Cell): + """ + Attention Head for Graph Attention Networks. + + Args: + in_channel (int): The number of input channel, input feature dim. + out_channel (int): The number of output channel, output feature dim. + in_drop_ratio (float): Input feature dropout ratio, default 0.0. + coef_drop_ratio (float): Coefficient dropout ratio, default 0.0. + residual (bool): Whether to use residual connection, default False. + coef_activation (Cell): The attention coefficient activation function, + default nn.LeakyReLU(). + activation (Cell): The output activation function, default nn.ELU(). + + Inputs: + - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim). + - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes). + + Examples: + >>> head = AttentionHead(1433, + 8, + in_drop_ratio=0.6, + coef_drop_ratio=0.6, + residual=False) + >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtypy=np.float32)) + >>> output = net(input_data) + """ + + def __init__(self, + in_channel, + out_channel, + in_drop_ratio=0.0, + coef_drop_ratio=0.0, + residual=False, + coef_activation=nn.LeakyReLU(), + activation=nn.ELU()): + super(AttentionHead, self).__init__() + self.in_channel = check_int_positive(in_channel) + self.out_channel = check_int_positive(out_channel) + self.in_drop_ratio = in_drop_ratio + self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio) + self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio) + self.feature_transform = GNNFeatureTransform( + in_channels=self.in_channel, + out_channels=self.out_channel, + has_bias=False) + + self.f_1_transform = GNNFeatureTransform( + in_channels=self.out_channel, + out_channels=1) + self.f_2_transform = GNNFeatureTransform( + in_channels=self.out_channel, + out_channels=1) + self.softmax = nn.Softmax() + + self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio) + self.batch_matmul = P.BatchMatMul() + self.bias_add = P.BiasAdd() + self.bias = Parameter(initializer('zeros', self.out_channel), name='bias') + self.residual = check_bool(residual) + if self.residual: + if in_channel != out_channel: + self.residual_transform_flag = True + self.residual_transform = GNNFeatureTransform( + in_channels=self.in_channel, + out_channels=self.out_channel) + else: + self.residual_transform = None + self.coef_activation = coef_activation + self.activation = activation + + def construct(self, input_feature, bias_mat): + input_feature = self.in_drop(input_feature) + + feature = self.feature_transform(input_feature) + # self attention following the author + f_1 = self.f_1_transform(feature) + f_2 = self.f_2_transform(feature) + logits = f_1 + P.Transpose()(f_2, (0, 2, 1)) + logits = self.coef_activation(logits) + bias_mat + coefs = self.softmax(logits) + + coefs = self.coef_drop(coefs) + feature = self.in_drop_2(feature) + + ret = self.batch_matmul(coefs, feature) + ret = P.Squeeze(0)(ret) + ret = self.bias_add(ret, self.bias) + ret = P.ExpandDims()(ret, 0) + # residual connection + if self.residual: + if self.residual_transform_flag: + res = self.residual_transform(input_feature) + ret = ret + res + else: + ret = ret + input_feature + # activation + ret = self.activation(ret) + return ret + + +class AttentionAggregator(nn.Cell): + """ + Attention Head for Graph Attention Networks,can be regarded as one + GAT layer. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + num_heads (int): Number of attention heads for this layer, default 1. + in_drop_ratio (float): Input feature dropout ratio, default 0.0. + coef_drop_ratio (float): Coefficient dropout ratio, default 0.0. + activation (Cell): The output activation function, default nn.ELU(). + residual (bool): Whether to use residual connection, default False. + + Inputs: + - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim). + - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes). + + Examples: + >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32)) + >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32)) + >>> net = AttentionAggregator(1433, + 8, + 8) + >>> net(input_data, biases) + """ + def __init__(self, + in_channels, + out_channels, + num_heads=1, + in_drop=0.0, + coef_drop=0.0, + activation=nn.ELU(), + residual=False): + super(AttentionAggregator, self).__init__() + self.num_heads = num_heads + self.attns = [] + for _ in range(num_heads): + self.attns.append(AttentionHead(in_channels, + out_channels, + in_drop_ratio=in_drop, + coef_drop_ratio=coef_drop, + activation=activation, + residual=residual)) + self.attns = nn.layer.CellList(self.attns) + + def construct(self, input_data, bias_mat): + res = () + for i in range(self.num_heads): + res += (self.attns[i](input_data, bias_mat),) + return P.Concat(-1)(res) diff --git a/tests/st/gnn/test_gnn_aggregator.py b/tests/st/gnn/test_gnn_aggregator.py index bba7c09c311..6335b4c8327 100644 --- a/tests/st/gnn/test_gnn_aggregator.py +++ b/tests/st/gnn/test_gnn_aggregator.py @@ -20,7 +20,7 @@ import mindspore.context as context from mindspore import Tensor from mindspore.common.api import _executor import mindspore.ops.composite as C -from aggregator import MeanAggregator +from aggregator import MeanAggregator, AttentionHead, AttentionAggregator context.set_context(mode=context.GRAPH_MODE) @@ -51,3 +51,22 @@ def test_MeanAggregator_grad(): sens = Tensor(np.ones([32, 64]).astype(np.float32)) grad_op = MeanAggregatorGrad(aggregator) _executor.compile(grad_op, input_data, sens) + + +def test_AttentionHead(): + """Compile AttentionHead forward graph""" + head = AttentionHead(1433, + 8, + in_drop_ratio=0.6, + coef_drop_ratio=0.6, + residual=False) + input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32)) + biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32)) + _executor.compile(head, input_data, biases) + + +def test_AttentionAggregator(): + input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32)) + biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32)) + net = AttentionAggregator(1433, 8, 8) + _executor.compile(net, input_data, biases)