forked from mindspore-Ecosystem/mindspore
Add gnn aggregator and its ut
This commit is contained in:
parent
1f2ca74cd1
commit
64f824e4fc
|
@ -0,0 +1,222 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Aggregator."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import check_int_positive, check_bool
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
|
||||
|
||||
class GNNFeatureTransform(nn.Cell):
|
||||
r"""
|
||||
The GNN featuren transform layer for input.
|
||||
|
||||
Applies linear transformation for the input feature. This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{outputs} = \text{inputs} * \text{kernel} + \text{bias},
|
||||
|
||||
where :math:`\text{activation}` is the activation function passed as the activation
|
||||
argument (if passed in),:math:`\text{activation}` is a weight matrix with the same
|
||||
data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
|
||||
with the same data type as the inputs created by the layer (only if has_bias is True).
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input space.
|
||||
out_channels (int): The number of channels in the output space.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
|
||||
Raises:
|
||||
ValueError: If weight_init or bias_init shape is incorrect.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*B, N, C)`,
|
||||
where :math:`*B` represents the batch size which can be multidimensional, :math:`N` and :math:`C` are the
|
||||
size of the last two dimensions. If `transpose_a` is True, its shape should be :math:`(*B, C, N)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Dense(3, 4)
|
||||
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
||||
>>> net(input)
|
||||
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
|
||||
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
|
||||
"""
|
||||
@cell_attr_register(attrs=['has_bias', 'activation'])
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
has_bias=True):
|
||||
super(GNNFeatureTransform, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
||||
weight_init.shape()[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
def construct(self, x):
|
||||
tensor_shape = F.shape(x)
|
||||
input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2]))
|
||||
output = self.matmul(input_feature, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels))
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
|
||||
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
|
||||
if self.has_bias:
|
||||
str_info = str_info + ', bias={}'.format(self.bias)
|
||||
|
||||
return str_info
|
||||
|
||||
|
||||
class _BaseAggregator(nn.Cell):
|
||||
"""
|
||||
Base Aggregator of GNN
|
||||
|
||||
Args:
|
||||
feature_in_dim (int): Node or edge input feature dim.
|
||||
feature_out_dim (int): Node or edge outpout feature dim.
|
||||
use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None.
|
||||
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
|
||||
|
||||
Examples:
|
||||
>>> class MyAggregator(_BaseAggregator):
|
||||
>>> def __init__(self):
|
||||
>>> super(MyAggregator, self).__init__(self, feature_in_dim, feature_out_dim)
|
||||
>>> self.reduce_mean = P.ReduceSum()
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> return self.reduce_mean(x, 1)
|
||||
"""
|
||||
def __init__(self,
|
||||
feature_in_dim,
|
||||
feature_out_dim,
|
||||
use_fc=True,
|
||||
weight_init="normal",
|
||||
bias_init="zeros",
|
||||
has_bias=True,
|
||||
dropout_ratio=None,
|
||||
activation=None):
|
||||
super(_BaseAggregator, self).__init__()
|
||||
self.in_dim = feature_in_dim
|
||||
self.out_dim = feature_out_dim
|
||||
self.use_fc = use_fc
|
||||
if self.use_fc:
|
||||
self.weight_init = weight_init
|
||||
self.bias_init = bias_init
|
||||
self.has_bias = has_bias
|
||||
self.fc = GNNFeatureTransform(self.in_dim,
|
||||
self.out_dim,
|
||||
weight_init=self.weight_init,
|
||||
bias_init=self.bias_init,
|
||||
has_bias=self.has_bias)
|
||||
self.dropout_ratio = dropout_ratio
|
||||
if self.dropout_ratio is not None:
|
||||
self.dropout = nn.Dropout(keep_prob=self.dropout_ratio)
|
||||
self.dropout_flag = self.dropout_ratio is not None
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
|
||||
def construct(self, **kward):
|
||||
"""Must be overridden by all subclasses."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MeanAggregator(_BaseAggregator):
|
||||
"""
|
||||
Mean Aggregator of GNN
|
||||
|
||||
Args:
|
||||
feature_in_dim (int): Node or edge input feature dim.
|
||||
feature_out_dim (int): Node or edge outpout feature dim.
|
||||
use_fc (bool): Specifies whether a linear transformation before message is aggregated. Default: True
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
dropout_ratio (float): The keep rate of dropout layer, greater than 0 and less equal than 1. Default: None.
|
||||
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
|
||||
|
||||
Examples:
|
||||
>>> net = MeanAggregator(32, 64, activation="relu", dropout=0.5)
|
||||
>>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtypy=np.float32))
|
||||
>>> output = net(input_data)
|
||||
"""
|
||||
def __init__(self,
|
||||
feature_in_dim,
|
||||
feature_out_dim,
|
||||
use_fc=True,
|
||||
weight_init="normal",
|
||||
bias_init="zeros",
|
||||
has_bias=True,
|
||||
dropout_ratio=None,
|
||||
activation=None):
|
||||
super(MeanAggregator, self).__init__(
|
||||
feature_in_dim,
|
||||
feature_out_dim,
|
||||
use_fc=True,
|
||||
weight_init="normal",
|
||||
bias_init="zeros",
|
||||
has_bias=True,
|
||||
dropout_ratio=None,
|
||||
activation=None)
|
||||
self.reduce_mean = P.ReduceMean(keep_dims=False)
|
||||
|
||||
def construct(self, input_feature):
|
||||
if self.use_fc:
|
||||
input_feature = self.fc(input_feature)
|
||||
if self.dropout_flag:
|
||||
input_feature = self.dropout(input_feature)
|
||||
if self.activation_flag:
|
||||
input_feature = self.activation(input_feature)
|
||||
output_feature = self.reduce_mean(input_feature, 1)
|
||||
return output_feature
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test gnn aggregator."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import _executor
|
||||
import mindspore.ops.composite as C
|
||||
from aggregator import MeanAggregator
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class MeanAggregatorGrad(nn.Cell):
|
||||
"""Backward of MeanAggregator"""
|
||||
def __init__(self, network):
|
||||
super(MeanAggregatorGrad, self).__init__()
|
||||
self.grad_op = C.grad_all_with_sens
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, sens):
|
||||
grad_op = self.grad_op(self.network)(x, sens)
|
||||
return grad_op
|
||||
|
||||
|
||||
def test_MeanAggregator():
|
||||
"""Compile MeanAggregator forward graph"""
|
||||
aggregator = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
|
||||
input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtype=np.float32))
|
||||
_executor.compile(aggregator, input_data)
|
||||
|
||||
|
||||
def test_MeanAggregator_grad():
|
||||
"""Compile MeanAggregator backward graph"""
|
||||
aggregator = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
|
||||
input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtype=np.float32))
|
||||
sens = Tensor(np.ones([32, 64]).astype(np.float32))
|
||||
grad_op = MeanAggregatorGrad(aggregator)
|
||||
_executor.compile(grad_op, input_data, sens)
|
Loading…
Reference in New Issue