change gnn input position

This commit is contained in:
zhanke 2020-10-12 11:03:50 +08:00
parent b5723dcd81
commit 4384ca214f
7 changed files with 65 additions and 63 deletions

View File

@ -19,7 +19,6 @@ from mindspore.ops import functional as F
from mindspore._extends import cell_attr_register
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator
from mindspore.nn.layer.activation import get_activation
@ -72,9 +71,9 @@ class GNNFeatureTransform(nn.Cell):
bias_init='zeros',
has_bias=True):
super(GNNFeatureTransform, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = Validator.check_bool(has_bias)
self.in_channels = in_channels
self.out_channels = out_channels
self.has_bias = has_bias
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -259,8 +258,8 @@ class AttentionHead(nn.Cell):
coef_activation=nn.LeakyReLU(),
activation=nn.ELU()):
super(AttentionHead, self).__init__()
self.in_channel = Validator.check_positive_int(in_channel)
self.out_channel = Validator.check_positive_int(out_channel)
self.in_channel = in_channel
self.out_channel = out_channel
self.in_drop_ratio = in_drop_ratio
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
@ -284,7 +283,7 @@ class AttentionHead(nn.Cell):
self.matmul = P.MatMul()
self.bias_add = P.BiasAdd()
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
self.residual = Validator.check_bool(residual)
self.residual = residual
if self.residual:
if in_channel != out_channel:
self.residual_transform_flag = True
@ -436,8 +435,6 @@ class GAT(nn.Cell):
"""
def __init__(self,
features,
biases,
ftr_dims,
num_class,
num_nodes,
@ -448,17 +445,15 @@ class GAT(nn.Cell):
activation=nn.ELU(),
residual=False):
super(GAT, self).__init__()
self.features = Tensor(features)
self.biases = Tensor(biases)
self.ftr_dims = Validator.check_positive_int(ftr_dims)
self.num_class = Validator.check_positive_int(num_class)
self.num_nodes = Validator.check_positive_int(num_nodes)
self.ftr_dims = ftr_dims
self.num_class = num_class
self.num_nodes = num_nodes
self.hidden_units = hidden_units
self.num_heads = num_heads
self.attn_drop = attn_drop
self.ftr_drop = ftr_drop
self.activation = activation
self.residual = Validator.check_bool(residual)
self.residual = residual
self.layers = []
# first layer
self.layers.append(AttentionAggregator(
@ -491,9 +486,9 @@ class GAT(nn.Cell):
output_transform='sum'))
self.layers = nn.layer.CellList(self.layers)
def construct(self, training=True):
input_data = self.features
bias_mat = self.biases
def construct(self, feature, biases, training=True):
input_data = feature
bias_mat = biases
for cell in self.layers:
input_data = cell(input_data, bias_mat, training)
return input_data/self.num_heads[-1]

View File

@ -103,8 +103,8 @@ class LossAccuracyWrapper(nn.Cell):
self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params())
self.acc_func = MaskedAccuracy(num_class, label, mask)
def construct(self):
logits = self.network(training=False)
def construct(self, feature, biases):
logits = self.network(feature, biases, training=False)
loss = self.loss_func(logits)
accuracy = self.acc_func(logits)
return loss, accuracy
@ -120,8 +120,8 @@ class LossNetWrapper(nn.Cell):
params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias')
self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params)
def construct(self):
logits = self.network()
def construct(self, feature, biases):
logits = self.network(feature, biases)
loss = self.loss_func(logits)
return loss
@ -145,11 +145,11 @@ class TrainOneStepCell(nn.Cell):
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
def construct(self):
def construct(self, feature, biases):
weights = self.weights
loss = self.network()
loss = self.network(feature, biases)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(sens)
grads = self.grad(self.network, weights)(feature, biases, sens)
return F.depend(loss, self.optimizer(grads))
@ -174,7 +174,7 @@ class TrainGAT(nn.Cell):
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy_func = MaskedAccuracy(num_class, label, mask)
def construct(self):
loss = self.loss_train_net()
accuracy = self.accuracy_func(self.network())
def construct(self, feature, biases):
loss = self.loss_train_net(feature, biases)
accuracy = self.accuracy_func(self.network(feature, biases))
return loss, accuracy

View File

@ -20,6 +20,7 @@ import numpy as np
import mindspore.context as context
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from mindspore.common import set_seed
from mindspore import Tensor
from src.config import GatConfig
from src.dataset import load_and_process
@ -56,9 +57,7 @@ def train():
num_nodes = feature.shape[1]
num_class = y_train.shape[2]
gat_net = GAT(feature,
biases,
feature_size,
gat_net = GAT(feature_size,
num_class,
num_nodes,
hid_units,
@ -67,6 +66,9 @@ def train():
ftr_drop=GatConfig.feature_dropout)
gat_net.add_flags_recursive(fp16=True)
feature = Tensor(feature)
biases = Tensor(biases)
eval_net = LossAccuracyWrapper(gat_net,
num_class,
y_val,
@ -84,11 +86,11 @@ def train():
val_acc_max = 0.0
val_loss_min = np.inf
for _epoch in range(num_epochs):
train_result = train_net()
train_result = train_net(feature, biases)
train_loss = train_result[0].asnumpy()
train_acc = train_result[1].asnumpy()
eval_result = eval_net()
eval_result = eval_net(feature, biases)
eval_loss = eval_result[0].asnumpy()
eval_acc = eval_result[1].asnumpy()
@ -110,9 +112,7 @@ def train():
print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
break
gat_net_test = GAT(feature,
biases,
feature_size,
gat_net_test = GAT(feature_size,
num_class,
num_nodes,
hid_units,
@ -127,7 +127,7 @@ def train():
y_test,
test_mask,
l2_coeff)
test_result = test_net()
test_result = test_net(feature, biases)
print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))

View File

@ -92,15 +92,12 @@ class GCN(nn.Cell):
output_dim (int): The number of output channels, equal to classes num.
"""
def __init__(self, config, adj, feature, output_dim):
def __init__(self, config, input_dim, output_dim):
super(GCN, self).__init__()
self.adj = Tensor(adj)
self.feature = Tensor(feature)
input_dim = feature.shape[1]
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
def construct(self):
output0 = self.layer0(self.adj, self.feature)
output1 = self.layer1(self.adj, output0)
def construct(self, adj, feature):
output0 = self.layer0(adj, feature)
output1 = self.layer1(adj, output0)
return output1

View File

@ -91,8 +91,8 @@ class LossAccuracyWrapper(nn.Cell):
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
self.accuracy = Accuracy(label, mask)
def construct(self):
preds = self.network()
def construct(self, adj, feature):
preds = self.network(adj, feature)
loss = self.loss(preds)
accuracy = self.accuracy(preds)
return loss, accuracy
@ -114,8 +114,8 @@ class LossWrapper(nn.Cell):
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
def construct(self):
preds = self.network()
def construct(self, adj, feature):
preds = self.network(adj, feature)
loss = self.loss(preds)
return loss
@ -154,11 +154,11 @@ class TrainOneStepCell(nn.Cell):
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
def construct(self):
def construct(self, adj, feature):
weights = self.weights
loss = self.network()
loss = self.network(adj, feature)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(sens)
grads = self.grad(self.network, weights)(adj, feature, sens)
return F.depend(loss, self.optimizer(grads))
@ -182,7 +182,7 @@ class TrainNetWrapper(nn.Cell):
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy = Accuracy(label, mask)
def construct(self):
loss = self.loss_train_net()
accuracy = self.accuracy(self.network())
def construct(self, adj, feature):
loss = self.loss_train_net(adj, feature)
accuracy = self.accuracy(self.network(adj, feature))
return loss, accuracy

View File

@ -26,6 +26,7 @@ from matplotlib import pyplot as plt
from matplotlib import animation
from sklearn import manifold
from mindspore import context
from mindspore import Tensor
from mindspore.common import set_seed
from mindspore.train.serialization import save_checkpoint, load_checkpoint
@ -71,9 +72,13 @@ def train():
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
class_num = label_onehot.shape[1]
gcn_net = GCN(config, adj, feature, class_num)
input_dim = feature.shape[1]
gcn_net = GCN(config, input_dim, class_num)
gcn_net.add_flags_recursive(fp16=True)
adj = Tensor(adj)
feature = Tensor(feature)
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
@ -92,12 +97,12 @@ def train():
t = time.time()
train_net.set_train()
train_result = train_net()
train_result = train_net(adj, feature)
train_loss = train_result[0].asnumpy()
train_accuracy = train_result[1].asnumpy()
eval_net.set_train(False)
eval_result = eval_net()
eval_result = eval_net(adj, feature)
eval_loss = eval_result[0].asnumpy()
eval_accuracy = eval_result[1].asnumpy()
@ -115,14 +120,14 @@ def train():
print("Early stopping...")
break
save_checkpoint(gcn_net, "ckpts/gcn.ckpt")
gcn_net_test = GCN(config, adj, feature, class_num)
gcn_net_test = GCN(config, input_dim, class_num)
load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
gcn_net_test.add_flags_recursive(fp16=True)
test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
t_test = time.time()
test_net.set_train(False)
test_result = test_net()
test_result = test_net(adj, feature)
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss),

View File

@ -17,6 +17,7 @@ import time
import pytest
import numpy as np
from mindspore import context
from mindspore import Tensor
from model_zoo.official.gnn.gcn.src.gcn import GCN
from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
from model_zoo.official.gnn.gcn.src.config import ConfigGCN
@ -49,9 +50,13 @@ def test_gcn():
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
class_num = label_onehot.shape[1]
gcn_net = GCN(config, adj, feature, class_num)
input_dim = feature.shape[1]
gcn_net = GCN(config, input_dim, class_num)
gcn_net.add_flags_recursive(fp16=True)
adj = Tensor(adj)
feature = Tensor(feature)
eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
@ -61,12 +66,12 @@ def test_gcn():
t = time.time()
train_net.set_train()
train_result = train_net()
train_result = train_net(adj, feature)
train_loss = train_result[0].asnumpy()
train_accuracy = train_result[1].asnumpy()
eval_net.set_train(False)
eval_result = eval_net()
eval_result = eval_net(adj, feature)
eval_loss = eval_result[0].asnumpy()
eval_accuracy = eval_result[1].asnumpy()
@ -80,7 +85,7 @@ def test_gcn():
break
test_net.set_train(False)
test_result = test_net()
test_result = test_net(adj, feature)
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss),