change gnn input position

zhanke 2020-10-12 11:03:50 +08:00
parent b5723dcd81
commit 4384ca214f
7 changed files with 65 additions and 63 deletions
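In short: node features and adjacency/bias matrices are no longer handed to the network constructors and stored on the cell; they are passed to construct() on every call, so the same graph network can be reused with different input Tensors. A minimal before/after sketch of the GAT calling convention, using the variable names from the diffs below (data loading as in train.py is assumed):

    from mindspore import Tensor
    from src.gat import GAT

    # Before this commit: graph data was baked into the cell at construction time.
    #   gat_net = GAT(feature, biases, feature_size, num_class, num_nodes, hid_units, n_heads)
    #   logits = gat_net(training=True)

    # After: the cell is built from sizes only; the Tensors are fed at call time.
    gat_net = GAT(feature_size, num_class, num_nodes, hid_units, n_heads)
    logits = gat_net(Tensor(feature), Tensor(biases), training=True)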

View File

@@ -19,7 +19,6 @@ from mindspore.ops import functional as F
 from mindspore._extends import cell_attr_register
 from mindspore import Tensor, Parameter
 from mindspore.common.initializer import initializer
-from mindspore._checkparam import Validator
 from mindspore.nn.layer.activation import get_activation
@@ -72,9 +71,9 @@ class GNNFeatureTransform(nn.Cell):
                  bias_init='zeros',
                  has_bias=True):
         super(GNNFeatureTransform, self).__init__()
-        self.in_channels = Validator.check_positive_int(in_channels)
-        self.out_channels = Validator.check_positive_int(out_channels)
-        self.has_bias = Validator.check_bool(has_bias)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.has_bias = has_bias
         if isinstance(weight_init, Tensor):
             if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@@ -259,8 +258,8 @@ class AttentionHead(nn.Cell):
                  coef_activation=nn.LeakyReLU(),
                  activation=nn.ELU()):
         super(AttentionHead, self).__init__()
-        self.in_channel = Validator.check_positive_int(in_channel)
-        self.out_channel = Validator.check_positive_int(out_channel)
+        self.in_channel = in_channel
+        self.out_channel = out_channel
         self.in_drop_ratio = in_drop_ratio
         self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
         self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
@@ -284,7 +283,7 @@ class AttentionHead(nn.Cell):
         self.matmul = P.MatMul()
         self.bias_add = P.BiasAdd()
         self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
-        self.residual = Validator.check_bool(residual)
+        self.residual = residual
         if self.residual:
             if in_channel != out_channel:
                 self.residual_transform_flag = True
@@ -436,8 +435,6 @@ class GAT(nn.Cell):
     """
     def __init__(self,
-                 features,
-                 biases,
                  ftr_dims,
                  num_class,
                  num_nodes,
@@ -448,17 +445,15 @@ class GAT(nn.Cell):
                  activation=nn.ELU(),
                  residual=False):
         super(GAT, self).__init__()
-        self.features = Tensor(features)
-        self.biases = Tensor(biases)
-        self.ftr_dims = Validator.check_positive_int(ftr_dims)
-        self.num_class = Validator.check_positive_int(num_class)
-        self.num_nodes = Validator.check_positive_int(num_nodes)
+        self.ftr_dims = ftr_dims
+        self.num_class = num_class
+        self.num_nodes = num_nodes
         self.hidden_units = hidden_units
         self.num_heads = num_heads
         self.attn_drop = attn_drop
         self.ftr_drop = ftr_drop
         self.activation = activation
-        self.residual = Validator.check_bool(residual)
+        self.residual = residual
         self.layers = []
         # first layer
         self.layers.append(AttentionAggregator(
@@ -491,9 +486,9 @@ class GAT(nn.Cell):
             output_transform='sum'))
         self.layers = nn.layer.CellList(self.layers)

-    def construct(self, training=True):
-        input_data = self.features
-        bias_mat = self.biases
+    def construct(self, feature, biases, training=True):
+        input_data = feature
+        bias_mat = biases
         for cell in self.layers:
             input_data = cell(input_data, bias_mat, training)
         return input_data/self.num_heads[-1]
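Note that the Validator.check_positive_int and Validator.check_bool calls above are dropped rather than replaced, so invalid arguments will now surface later as shape or type errors inside the graph. If eager checking is still wanted, a plain-Python stand-in (a hypothetical helper, not part of this commit) could look like:

    def check_positive_int(value, name):
        # Hypothetical replacement for the removed Validator.check_positive_int.
        if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
            raise ValueError("'{}' must be a positive int, got {!r}".format(name, value))
        return value

    # e.g. in GNNFeatureTransform.__init__:
    #   self.in_channels = check_positive_int(in_channels, "in_channels")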

View File

@@ -103,8 +103,8 @@ class LossAccuracyWrapper(nn.Cell):
         self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params())
         self.acc_func = MaskedAccuracy(num_class, label, mask)

-    def construct(self):
-        logits = self.network(training=False)
+    def construct(self, feature, biases):
+        logits = self.network(feature, biases, training=False)
         loss = self.loss_func(logits)
         accuracy = self.acc_func(logits)
         return loss, accuracy
@@ -120,8 +120,8 @@ class LossNetWrapper(nn.Cell):
         params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias')
         self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params)

-    def construct(self):
-        logits = self.network()
+    def construct(self, feature, biases):
+        logits = self.network(feature, biases)
         loss = self.loss_func(logits)
         return loss
@@ -145,11 +145,11 @@ class TrainOneStepCell(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens

-    def construct(self):
+    def construct(self, feature, biases):
         weights = self.weights
-        loss = self.network()
+        loss = self.network(feature, biases)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(sens)
+        grads = self.grad(self.network, weights)(feature, biases, sens)
         return F.depend(loss, self.optimizer(grads))
@@ -174,7 +174,7 @@ class TrainGAT(nn.Cell):
         self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
         self.accuracy_func = MaskedAccuracy(num_class, label, mask)

-    def construct(self):
-        loss = self.loss_train_net()
-        accuracy = self.accuracy_func(self.network())
+    def construct(self, feature, biases):
+        loss = self.loss_train_net(feature, biases)
+        accuracy = self.accuracy_func(self.network(feature, biases))
         return loss, accuracy
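Because the network now takes its inputs in construct(), every wrapper has to forward the same arguments, and TrainOneStepCell must also pass them to the gradient function; with sens_param=True the sensitivity tensor goes last, after the network inputs. A condensed sketch of that pattern (GAT argument names; the constructor lines other than grad/sens follow the usual MindSpore wrapper layout and may differ slightly from the repository code):

    import mindspore.nn as nn
    from mindspore import ParameterTuple
    from mindspore.ops import composite as C
    from mindspore.ops import functional as F
    from mindspore.ops import operations as P

    class TrainOneStepCell(nn.Cell):
        def __init__(self, network, optimizer, sens=1.0):
            super(TrainOneStepCell, self).__init__(auto_prefix=False)
            self.network = network
            self.weights = ParameterTuple(network.trainable_params())
            self.optimizer = optimizer
            self.grad = C.GradOperation(get_by_list=True, sens_param=True)
            self.sens = sens

        def construct(self, feature, biases):
            weights = self.weights
            loss = self.network(feature, biases)
            # Initial gradient scale; with sens_param=True it is appended as the
            # last positional argument after the network inputs.
            sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
            grads = self.grad(self.network, weights)(feature, biases, sens)
            return F.depend(loss, self.optimizer(grads))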

View File

@@ -20,6 +20,7 @@ import numpy as np
 import mindspore.context as context
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.common import set_seed
+from mindspore import Tensor

 from src.config import GatConfig
 from src.dataset import load_and_process
@@ -56,9 +57,7 @@ def train():
     num_nodes = feature.shape[1]
     num_class = y_train.shape[2]

-    gat_net = GAT(feature,
-                  biases,
-                  feature_size,
+    gat_net = GAT(feature_size,
                   num_class,
                   num_nodes,
                   hid_units,
@@ -67,6 +66,9 @@ def train():
                   ftr_drop=GatConfig.feature_dropout)
     gat_net.add_flags_recursive(fp16=True)

+    feature = Tensor(feature)
+    biases = Tensor(biases)
+
     eval_net = LossAccuracyWrapper(gat_net,
                                    num_class,
                                    y_val,
@@ -84,11 +86,11 @@ def train():
     val_acc_max = 0.0
     val_loss_min = np.inf
     for _epoch in range(num_epochs):
-        train_result = train_net()
+        train_result = train_net(feature, biases)
         train_loss = train_result[0].asnumpy()
         train_acc = train_result[1].asnumpy()

-        eval_result = eval_net()
+        eval_result = eval_net(feature, biases)
         eval_loss = eval_result[0].asnumpy()
         eval_acc = eval_result[1].asnumpy()
@@ -110,9 +112,7 @@ def train():
                 print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
                 print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
                 break
-    gat_net_test = GAT(feature,
-                       biases,
-                       feature_size,
+    gat_net_test = GAT(feature_size,
                        num_class,
                        num_nodes,
                        hid_units,
@@ -127,7 +127,7 @@ def train():
                                    y_test,
                                    test_mask,
                                    l2_coeff)

-    test_result = test_net()
+    test_result = test_net(feature, biases)
     print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))

View File

@@ -92,15 +92,12 @@ class GCN(nn.Cell):
         output_dim (int): The number of output channels, equal to classes num.
     """

-    def __init__(self, config, adj, feature, output_dim):
+    def __init__(self, config, input_dim, output_dim):
         super(GCN, self).__init__()
-        self.adj = Tensor(adj)
-        self.feature = Tensor(feature)
-        input_dim = feature.shape[1]
         self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
         self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)

-    def construct(self):
-        output0 = self.layer0(self.adj, self.feature)
-        output1 = self.layer1(self.adj, output0)
+    def construct(self, adj, feature):
+        output0 = self.layer0(adj, feature)
+        output1 = self.layer1(adj, output0)
         return output1
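GCN follows the same pattern: the cell is constructed from the input feature dimension and class count only, and the normalized adjacency plus the feature matrix are supplied on each call. A minimal usage sketch (adj, feature and class_num are the preprocessed arrays from the dataset pipeline, as in train.py):

    from mindspore import Tensor
    from src.config import ConfigGCN
    from src.gcn import GCN

    config = ConfigGCN()
    gcn_net = GCN(config, input_dim=feature.shape[1], output_dim=class_num)
    logits = gcn_net(Tensor(adj), Tensor(feature))   # one forward pass over the whole graph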

View File

@@ -91,8 +91,8 @@ class LossAccuracyWrapper(nn.Cell):
         self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
         self.accuracy = Accuracy(label, mask)

-    def construct(self):
-        preds = self.network()
+    def construct(self, adj, feature):
+        preds = self.network(adj, feature)
         loss = self.loss(preds)
         accuracy = self.accuracy(preds)
         return loss, accuracy
@@ -114,8 +114,8 @@ class LossWrapper(nn.Cell):
         self.network = network
         self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])

-    def construct(self):
-        preds = self.network()
+    def construct(self, adj, feature):
+        preds = self.network(adj, feature)
         loss = self.loss(preds)
         return loss
@@ -154,11 +154,11 @@ class TrainOneStepCell(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens

-    def construct(self):
+    def construct(self, adj, feature):
         weights = self.weights
-        loss = self.network()
+        loss = self.network(adj, feature)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(sens)
+        grads = self.grad(self.network, weights)(adj, feature, sens)
         return F.depend(loss, self.optimizer(grads))
@@ -182,7 +182,7 @@ class TrainNetWrapper(nn.Cell):
         self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
         self.accuracy = Accuracy(label, mask)

-    def construct(self):
-        loss = self.loss_train_net()
-        accuracy = self.accuracy(self.network())
+    def construct(self, adj, feature):
+        loss = self.loss_train_net(adj, feature)
+        accuracy = self.accuracy(self.network(adj, feature))
         return loss, accuracy

View File

@@ -26,6 +26,7 @@ from matplotlib import pyplot as plt
 from matplotlib import animation
 from sklearn import manifold
 from mindspore import context
+from mindspore import Tensor
 from mindspore.common import set_seed
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
@@ -71,9 +72,13 @@ def train():
     test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
     class_num = label_onehot.shape[1]

-    gcn_net = GCN(config, adj, feature, class_num)
+    input_dim = feature.shape[1]
+    gcn_net = GCN(config, input_dim, class_num)
     gcn_net.add_flags_recursive(fp16=True)

+    adj = Tensor(adj)
+    feature = Tensor(feature)
+
     eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
     train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
@@ -92,12 +97,12 @@ def train():
         t = time.time()
         train_net.set_train()
-        train_result = train_net()
+        train_result = train_net(adj, feature)
         train_loss = train_result[0].asnumpy()
         train_accuracy = train_result[1].asnumpy()

         eval_net.set_train(False)
-        eval_result = eval_net()
+        eval_result = eval_net(adj, feature)
         eval_loss = eval_result[0].asnumpy()
         eval_accuracy = eval_result[1].asnumpy()
@@ -115,14 +120,14 @@ def train():
             print("Early stopping...")
             break
     save_checkpoint(gcn_net, "ckpts/gcn.ckpt")

-    gcn_net_test = GCN(config, adj, feature, class_num)
+    gcn_net_test = GCN(config, input_dim, class_num)
     load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
     gcn_net_test.add_flags_recursive(fp16=True)
     test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
     t_test = time.time()
     test_net.set_train(False)
-    test_result = test_net()
+    test_result = test_net(adj, feature)
     test_loss = test_result[0].asnumpy()
     test_accuracy = test_result[1].asnumpy()
     print("Test set results:", "loss=", "{:.5f}".format(test_loss),

View File

@@ -17,6 +17,7 @@ import time
 import pytest
 import numpy as np
 from mindspore import context
+from mindspore import Tensor
 from model_zoo.official.gnn.gcn.src.gcn import GCN
 from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
 from model_zoo.official.gnn.gcn.src.config import ConfigGCN
@@ -49,9 +50,13 @@ def test_gcn():
     test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
     class_num = label_onehot.shape[1]

-    gcn_net = GCN(config, adj, feature, class_num)
+    input_dim = feature.shape[1]
+    gcn_net = GCN(config, input_dim, class_num)
     gcn_net.add_flags_recursive(fp16=True)

+    adj = Tensor(adj)
+    feature = Tensor(feature)
+
     eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
     test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
     train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
@@ -61,12 +66,12 @@ def test_gcn():
         t = time.time()
         train_net.set_train()
-        train_result = train_net()
+        train_result = train_net(adj, feature)
         train_loss = train_result[0].asnumpy()
         train_accuracy = train_result[1].asnumpy()

         eval_net.set_train(False)
-        eval_result = eval_net()
+        eval_result = eval_net(adj, feature)
         eval_loss = eval_result[0].asnumpy()
         eval_accuracy = eval_result[1].asnumpy()
@@ -80,7 +85,7 @@ def test_gcn():
             break

     test_net.set_train(False)
-    test_result = test_net()
+    test_result = test_net(adj, feature)
     test_loss = test_result[0].asnumpy()
     test_accuracy = test_result[1].asnumpy()
     print("Test set results:", "loss=", "{:.5f}".format(test_loss),