forked from mindspore-Ecosystem/mindspore
!2009 Add GCN to model zoo
Merge pull request !2009 from chentingting/gcn_modelzoo
This commit is contained in:
commit
c26cb9b15b
|
@ -0,0 +1,113 @@
|
|||
# GCN Example
|
||||
|
||||
## Description
|
||||
|
||||
This is an example of training GCN with Cora and Citeseer dataset in MindSpore.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
|
||||
- Download the dataset Cora or Citeseer provided by /kimiyoung/planetoid from github.
|
||||
|
||||
> Place the dataset to any path you want, the folder should include files as follows(we use Cora dataset as an example):
|
||||
|
||||
```
|
||||
.
|
||||
└─data
|
||||
├─ind.cora.allx
|
||||
├─ind.cora.ally
|
||||
├─ind.cora.graph
|
||||
├─ind.cora.test.index
|
||||
├─ind.cora.tx
|
||||
├─ind.cora.ty
|
||||
├─ind.cora.x
|
||||
└─ind.cora.y
|
||||
```
|
||||
|
||||
> Generate dataset in mindrecord format for cora or citeseer.
|
||||
>> Usage
|
||||
```buildoutcfg
|
||||
cd ./scripts
|
||||
# SRC_PATH is the dataset file path you downloaded, DATASET_NAME is cora or citeseer
|
||||
sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
|
||||
```
|
||||
|
||||
>> Launch
|
||||
```
|
||||
#Generate dataset in mindrecord format for cora
|
||||
sh run_process_data.sh cora
|
||||
#Generate dataset in mindrecord format for citeseer
|
||||
sh run_process_data.sh citeseer
|
||||
```
|
||||
|
||||
## Structure
|
||||
|
||||
```shell
|
||||
.
|
||||
└─gcn
|
||||
├─README.md
|
||||
├─scripts
|
||||
| ├─run_process_data.sh # Generate dataset in mindrecord format
|
||||
| └─run_train.sh # Launch training
|
||||
|
|
||||
├─src
|
||||
| ├─config.py # Parameter configuration
|
||||
| ├─dataset.py # Data preprocessin
|
||||
| ├─gcn.py # GCN backbone
|
||||
| └─metrics.py # Loss and accuracy
|
||||
|
|
||||
└─train.py # Train net
|
||||
```
|
||||
|
||||
## Parameter configuration
|
||||
|
||||
Parameters for training can be set in config.py.
|
||||
|
||||
```
|
||||
"learning_rate": 0.01, # Learning rate
|
||||
"epochs": 200, # Epoch sizes for training
|
||||
"hidden1": 16, # Hidden size for the first graph convolution layer
|
||||
"dropout": 0.5, # Dropout ratio for the first graph convolution layer
|
||||
"weight_decay": 5e-4, # Weight decay for the parameter of the first graph convolution layer
|
||||
"early_stopping": 10, # Tolerance for early stopping
|
||||
```
|
||||
|
||||
## Running the example
|
||||
|
||||
### Train
|
||||
|
||||
#### Usage
|
||||
|
||||
```
|
||||
# run train with cora or citeseer dataset, DATASET_NAME is cora or citeseer
|
||||
sh run_train.sh [DATASET_NAME]
|
||||
```
|
||||
|
||||
#### Launch
|
||||
|
||||
```bash
|
||||
sh run_train.sh cora
|
||||
```
|
||||
|
||||
#### Result
|
||||
|
||||
Training result will be stored in the scripts path, whose folder name begins with "train". You can find the result like the followings in log.
|
||||
|
||||
|
||||
```
|
||||
Epoch: 0001 train_loss= 1.95373 train_acc= 0.09286 val_loss= 1.95075 val_acc= 0.20200 time= 7.25737
|
||||
Epoch: 0002 train_loss= 1.94812 train_acc= 0.32857 val_loss= 1.94717 val_acc= 0.34000 time= 0.00438
|
||||
Epoch: 0003 train_loss= 1.94249 train_acc= 0.47857 val_loss= 1.94337 val_acc= 0.43000 time= 0.00428
|
||||
Epoch: 0004 train_loss= 1.93550 train_acc= 0.55000 val_loss= 1.93957 val_acc= 0.46400 time= 0.00421
|
||||
Epoch: 0005 train_loss= 1.92617 train_acc= 0.67143 val_loss= 1.93558 val_acc= 0.45400 time= 0.00430
|
||||
...
|
||||
Epoch: 0196 train_loss= 0.60326 train_acc= 0.97857 val_loss= 1.05155 val_acc= 0.78200 time= 0.00418
|
||||
Epoch: 0197 train_loss= 0.60377 train_acc= 0.97143 val_loss= 1.04940 val_acc= 0.78000 time= 0.00418
|
||||
Epoch: 0198 train_loss= 0.60680 train_acc= 0.95000 val_loss= 1.04847 val_acc= 0.78000 time= 0.00414
|
||||
Epoch: 0199 train_loss= 0.61920 train_acc= 0.96429 val_loss= 1.04797 val_acc= 0.78400 time= 0.00413
|
||||
Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.78600 time= 0.00415
|
||||
Optimization Finished!
|
||||
Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083
|
||||
...
|
||||
```
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [SRC_PATH] [DATASET_NAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
SRC_PATH=$(get_real_path $1)
|
||||
echo $SRC_PATH
|
||||
|
||||
DATASET_NAME=$2
|
||||
echo $DATASET_NAME
|
||||
|
||||
if [ ! -d data_mr ]; then
|
||||
mkdir data_mr
|
||||
else
|
||||
echo data_mr exist
|
||||
fi
|
||||
MINDRECORD_PATH=`pwd`/data_mr
|
||||
|
||||
rm -f $MINDRECORD_PATH/*
|
||||
|
||||
cd ../../../example/graph_to_mindrecord || exit
|
||||
|
||||
python writer.py --mindrecord_script $DATASET_NAME \
|
||||
--mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$SRC_PATH"
|
||||
|
||||
cd - || exit
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [DATASET_NAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET_NAME=$1
|
||||
echo $DATASET_NAME
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
env > env.log
|
||||
echo "start training for device $DEVICE_ID"
|
||||
|
||||
|
||||
if [ $DATASET_NAME == cora ]
|
||||
then
|
||||
python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=140 &> log &
|
||||
fi
|
||||
|
||||
if [ $DATASET_NAME == citeseer ]
|
||||
then
|
||||
python train.py --data_dir=../data_mr/$DATASET_NAME --train_nodes_num=120 &> log &
|
||||
fi
|
||||
cd ..
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
"""
|
||||
|
||||
|
||||
class ConfigGCN():
|
||||
learning_rate = 0.01
|
||||
epochs = 200
|
||||
hidden1 = 16
|
||||
dropout = 0.5
|
||||
weight_decay = 5e-4
|
||||
early_stopping = 10
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create adjacency matrix, node features, labels, and mask for training.
|
||||
"""
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
def normalize_adj(adj):
|
||||
"""Symmetrically normalize adjacency matrix."""
|
||||
rowsum = np.array(adj.sum(1))
|
||||
d_inv_sqrt = np.power(rowsum, -0.5).flatten()
|
||||
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
|
||||
d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
|
||||
return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
|
||||
|
||||
|
||||
def get_adj_features_labels(data_dir):
|
||||
"""Get adjacency matrix, node features and labels from dataset."""
|
||||
g = ds.GraphData(data_dir)
|
||||
nodes = g.get_all_nodes(0)
|
||||
nodes_list = nodes.tolist()
|
||||
row_tensor = g.get_node_feature(nodes_list, [1, 2])
|
||||
features = row_tensor[0]
|
||||
labels = row_tensor[1]
|
||||
|
||||
nodes_num = labels.shape[0]
|
||||
class_num = labels.max() + 1
|
||||
labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)
|
||||
|
||||
neighbor = g.get_all_neighbors(nodes_list, 0)
|
||||
node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
|
||||
adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
|
||||
for index, value in np.ndenumerate(neighbor):
|
||||
# The first column of neighbor is node_id, second column to last column are neighbors of the first column.
|
||||
# So we only care index[1] > 1.
|
||||
# If the node does not have that many neighbors, -1 is padded. So if value < 0, we will not deal with it.
|
||||
if value >= 0 and index[1] > 0:
|
||||
adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
|
||||
adj = sp.coo_matrix(adj)
|
||||
adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
|
||||
nor_adj = normalize_adj(adj)
|
||||
nor_adj = np.array(nor_adj.todense())
|
||||
return nor_adj, features, labels_onehot
|
||||
|
||||
|
||||
def get_mask(total, begin, end):
|
||||
"""Generate mask."""
|
||||
mask = np.zeros([total]).astype(np.float32)
|
||||
mask[begin:end] = 1
|
||||
return mask
|
|
@ -0,0 +1,220 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GCN."""
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from src.metrics import Loss, Accuracy
|
||||
|
||||
|
||||
def glorot(shape):
|
||||
init_range = np.sqrt(6.0/(shape[0]+shape[1]))
|
||||
initial = np.random.uniform(-init_range, init_range, shape).astype(np.float32)
|
||||
return Tensor(initial)
|
||||
|
||||
|
||||
class GraphConvolution(nn.Cell):
|
||||
"""
|
||||
GCN graph convolution layer.
|
||||
|
||||
Args:
|
||||
feature_in_dim (int): The input feature dimension.
|
||||
feature_out_dim (int): The output feature dimension.
|
||||
dropout_ratio (float): Dropout ratio for the dropout layer. Default: None.
|
||||
activation (str): Activation function applied to the output of the layer, eg. 'relu'. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **adj** (Tensor) - Tensor of shape :math:`(N, N)`.
|
||||
- **input_feature** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
feature_in_dim,
|
||||
feature_out_dim,
|
||||
dropout_ratio=None,
|
||||
activation=None):
|
||||
super(GraphConvolution, self).__init__()
|
||||
self.in_dim = feature_in_dim
|
||||
self.out_dim = feature_out_dim
|
||||
self.weight_init = glorot([self.out_dim, self.in_dim])
|
||||
self.fc = nn.Dense(self.in_dim,
|
||||
self.out_dim,
|
||||
weight_init=self.weight_init,
|
||||
has_bias=False)
|
||||
self.dropout_ratio = dropout_ratio
|
||||
if self.dropout_ratio is not None:
|
||||
self.dropout = nn.Dropout(keep_prob=1-self.dropout_ratio)
|
||||
self.dropout_flag = self.dropout_ratio is not None
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
self.matmul = P.MatMul()
|
||||
|
||||
def construct(self, adj, input_feature):
|
||||
dropout = input_feature
|
||||
if self.dropout_flag:
|
||||
dropout = self.dropout(dropout)
|
||||
|
||||
fc = self.fc(dropout)
|
||||
output_feature = self.matmul(adj, fc)
|
||||
|
||||
if self.activation_flag:
|
||||
output_feature = self.activation(output_feature)
|
||||
return output_feature
|
||||
|
||||
|
||||
class GCN(nn.Cell):
|
||||
"""
|
||||
GCN architecture.
|
||||
|
||||
Args:
|
||||
config (ConfigGCN): Configuration for GCN.
|
||||
adj (numpy.ndarray): Numbers of block in different layers.
|
||||
feature (numpy.ndarray): Input channel in each layer.
|
||||
output_dim (int): The number of output channels, equal to classes num.
|
||||
"""
|
||||
|
||||
def __init__(self, config, adj, feature, output_dim):
|
||||
super(GCN, self).__init__()
|
||||
self.adj = Tensor(adj)
|
||||
self.feature = Tensor(feature)
|
||||
input_dim = feature.shape[1]
|
||||
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
|
||||
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
|
||||
|
||||
def construct(self):
|
||||
output0 = self.layer0(self.adj, self.feature)
|
||||
output1 = self.layer1(self.adj, output0)
|
||||
return output1
|
||||
|
||||
|
||||
class LossAccuracyWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with loss and accuracy cell.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossAccuracyWrapper, self).__init__()
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self):
|
||||
preds = self.network()
|
||||
loss = self.loss(preds)
|
||||
accuracy = self.accuracy(preds)
|
||||
return loss, accuracy
|
||||
|
||||
|
||||
class LossWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with loss.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training.
|
||||
weight_decay (float): Weight decay parameter for weight of the first convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, weight_decay):
|
||||
super(LossWrapper, self).__init__()
|
||||
self.network = network
|
||||
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
|
||||
|
||||
def construct(self):
|
||||
preds = self.network()
|
||||
loss = self.loss(preds)
|
||||
return loss
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
r"""
|
||||
Network training package class.
|
||||
|
||||
Wraps the network with an optimizer. The resulting Cell be trained without inputs.
|
||||
Backward graph will be created in the construct function to do parameter updating. Different
|
||||
parallel modes are available to run the training.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
||||
|
||||
Outputs:
|
||||
Tensor, a scalar Tensor with shape :math:`()`.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> loss_net = nn.WithLossCell(net, loss_fn)
|
||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
|
||||
def construct(self):
|
||||
weights = self.weights
|
||||
loss = self.network()
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(sens)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class TrainNetWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the GCN model with optimizer.
|
||||
|
||||
Args:
|
||||
network (Cell): GCN network.
|
||||
label (numpy.ndarray): Dataset labels.
|
||||
mask (numpy.ndarray): Mask for training, evaluation or test.
|
||||
config (ConfigGCN): Configuration for GCN.
|
||||
"""
|
||||
|
||||
def __init__(self, network, label, mask, config):
|
||||
super(TrainNetWrapper, self).__init__(auto_prefix=True)
|
||||
self.network = network
|
||||
loss_net = LossWrapper(network, label, mask, config.weight_decay)
|
||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||
learning_rate=config.learning_rate)
|
||||
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
|
||||
self.accuracy = Accuracy(label, mask)
|
||||
|
||||
def construct(self):
|
||||
loss = self.loss_train_net()
|
||||
accuracy = self.accuracy(self.network())
|
||||
return loss, accuracy
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Loss and accuracy."""
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Loss(nn.Cell):
|
||||
"""Softmax cross-entropy loss with masking."""
|
||||
def __init__(self, label, mask, weight_decay, param):
|
||||
super(Loss, self).__init__()
|
||||
self.label = Tensor(label)
|
||||
self.mask = Tensor(mask)
|
||||
self.loss = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.one = Tensor(1.0, mstype.float32)
|
||||
self.zero = Tensor(0.0, mstype.float32)
|
||||
self.mean = P.ReduceMean()
|
||||
self.cast = P.Cast()
|
||||
self.l2_loss = P.L2Loss()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.weight_decay = weight_decay
|
||||
self.param = param
|
||||
|
||||
def construct(self, preds):
|
||||
param = self.l2_loss(self.param)
|
||||
loss = self.weight_decay * param
|
||||
preds = self.cast(preds, mstype.float32)
|
||||
loss = loss + self.loss(preds, self.label)[0]
|
||||
mask = self.cast(self.mask, mstype.float32)
|
||||
mask_reduce = self.mean(mask)
|
||||
mask = mask / mask_reduce
|
||||
loss = loss * mask
|
||||
loss = self.mean(loss)
|
||||
return loss
|
||||
|
||||
|
||||
class Accuracy(nn.Cell):
|
||||
"""Accuracy with masking."""
|
||||
def __init__(self, label, mask):
|
||||
super(Accuracy, self).__init__()
|
||||
self.label = Tensor(label)
|
||||
self.mask = Tensor(mask)
|
||||
self.equal = P.Equal()
|
||||
self.argmax = P.Argmax()
|
||||
self.cast = P.Cast()
|
||||
self.mean = P.ReduceMean()
|
||||
|
||||
def construct(self, preds):
|
||||
preds = self.cast(preds, mstype.float32)
|
||||
correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
|
||||
accuracy_all = self.cast(correct_prediction, mstype.float32)
|
||||
mask = self.cast(self.mask, mstype.float32)
|
||||
mask_reduce = self.mean(mask)
|
||||
mask = mask / mask_reduce
|
||||
accuracy_all *= mask
|
||||
return self.mean(accuracy_all)
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
GCN training script.
|
||||
"""
|
||||
|
||||
import time
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
|
||||
from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
|
||||
from src.config import ConfigGCN
|
||||
from src.dataset import get_adj_features_labels, get_mask
|
||||
|
||||
|
||||
def train():
|
||||
"""Train model."""
|
||||
parser = argparse.ArgumentParser(description='GCN')
|
||||
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
|
||||
parser.add_argument('--seed', type=int, default=123, help='Random seed')
|
||||
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
|
||||
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
|
||||
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
np.random.seed(args_opt.seed)
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
adj, feature, label = get_adj_features_labels(args_opt.data_dir)
|
||||
|
||||
nodes_num = label.shape[0]
|
||||
train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
|
||||
eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, args_opt.train_nodes_num + args_opt.eval_nodes_num)
|
||||
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
|
||||
|
||||
class_num = label.shape[1]
|
||||
gcn_net = GCN(config, adj, feature, class_num)
|
||||
gcn_net.add_flags_recursive(fp16=True)
|
||||
|
||||
eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay)
|
||||
test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay)
|
||||
train_net = TrainNetWrapper(gcn_net, label, train_mask, config)
|
||||
|
||||
loss_list = []
|
||||
for epoch in range(config.epochs):
|
||||
t = time.time()
|
||||
|
||||
train_net.set_train()
|
||||
train_result = train_net()
|
||||
train_loss = train_result[0].asnumpy()
|
||||
train_accuracy = train_result[1].asnumpy()
|
||||
|
||||
eval_net.set_train(False)
|
||||
eval_result = eval_net()
|
||||
eval_loss = eval_result[0].asnumpy()
|
||||
eval_accuracy = eval_result[1].asnumpy()
|
||||
|
||||
loss_list.append(eval_loss)
|
||||
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
|
||||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
||||
|
||||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
||||
print("Early stopping...")
|
||||
break
|
||||
|
||||
t_test = time.time()
|
||||
test_net.set_train(False)
|
||||
test_result = test_net()
|
||||
test_loss = test_result[0].asnumpy()
|
||||
test_accuracy = test_result[1].asnumpy()
|
||||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
||||
"accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ConfigGCN():
|
||||
learning_rate = 0.01
|
||||
epochs = 200
|
||||
|
|
|
@ -58,10 +58,12 @@ def test_gcn():
|
|||
for epoch in range(config.epochs):
|
||||
t = time.time()
|
||||
|
||||
train_net.set_train()
|
||||
train_result = train_net()
|
||||
train_loss = train_result[0].asnumpy()
|
||||
train_accuracy = train_result[1].asnumpy()
|
||||
|
||||
eval_net.set_train(False)
|
||||
eval_result = eval_net()
|
||||
eval_loss = eval_result[0].asnumpy()
|
||||
eval_accuracy = eval_result[1].asnumpy()
|
||||
|
@ -75,6 +77,7 @@ def test_gcn():
|
|||
print("Early stopping...")
|
||||
break
|
||||
|
||||
test_net.set_train(False)
|
||||
test_result = test_net()
|
||||
test_loss = test_result[0].asnumpy()
|
||||
test_accuracy = test_result[1].asnumpy()
|
||||
|
|
Loading…
Reference in New Issue