forked from mindspore-Ecosystem/mindspore

change gnn input position

parent b5723dcd81, commit 4384ca214f
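
This commit moves the graph inputs out of the network constructors and into construct(). GAT now takes (feature, biases, training=True) and GCN takes (adj, feature) at call time instead of holding the data as Tensor attributes, so the constructors keep only the dimensions (feature_size / input_dim, num_class, num_nodes). The now-unneeded Validator checks and the Validator import are dropped from the GAT source, the loss/accuracy/train-step wrapper cells forward the new arguments through to the wrapped network (including the GradOperation call, which now receives the inputs before the sensitivity), and the GAT/GCN training scripts plus the GCN system test are updated to build the networks from dimensions only and to convert the numpy inputs to Tensor once before the training loop.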

GAT model definition (src/gat.py, path inferred):

@@ -19,7 +19,6 @@ from mindspore.ops import functional as F
 from mindspore._extends import cell_attr_register
 from mindspore import Tensor, Parameter
 from mindspore.common.initializer import initializer
-from mindspore._checkparam import Validator
 from mindspore.nn.layer.activation import get_activation
 
 
@@ -72,9 +71,9 @@ class GNNFeatureTransform(nn.Cell):
                  bias_init='zeros',
                  has_bias=True):
         super(GNNFeatureTransform, self).__init__()
-        self.in_channels = Validator.check_positive_int(in_channels)
-        self.out_channels = Validator.check_positive_int(out_channels)
-        self.has_bias = Validator.check_bool(has_bias)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.has_bias = has_bias
 
         if isinstance(weight_init, Tensor):
             if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@@ -259,8 +258,8 @@ class AttentionHead(nn.Cell):
                  coef_activation=nn.LeakyReLU(),
                  activation=nn.ELU()):
         super(AttentionHead, self).__init__()
-        self.in_channel = Validator.check_positive_int(in_channel)
-        self.out_channel = Validator.check_positive_int(out_channel)
+        self.in_channel = in_channel
+        self.out_channel = out_channel
         self.in_drop_ratio = in_drop_ratio
         self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
         self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
@@ -284,7 +283,7 @@ class AttentionHead(nn.Cell):
         self.matmul = P.MatMul()
         self.bias_add = P.BiasAdd()
         self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
-        self.residual = Validator.check_bool(residual)
+        self.residual = residual
         if self.residual:
             if in_channel != out_channel:
                 self.residual_transform_flag = True
@@ -436,8 +435,6 @@ class GAT(nn.Cell):
     """
 
     def __init__(self,
-                 features,
-                 biases,
                  ftr_dims,
                  num_class,
                  num_nodes,
@@ -448,17 +445,15 @@ class GAT(nn.Cell):
                  activation=nn.ELU(),
                  residual=False):
         super(GAT, self).__init__()
-        self.features = Tensor(features)
-        self.biases = Tensor(biases)
-        self.ftr_dims = Validator.check_positive_int(ftr_dims)
-        self.num_class = Validator.check_positive_int(num_class)
-        self.num_nodes = Validator.check_positive_int(num_nodes)
+        self.ftr_dims = ftr_dims
+        self.num_class = num_class
+        self.num_nodes = num_nodes
         self.hidden_units = hidden_units
         self.num_heads = num_heads
         self.attn_drop = attn_drop
         self.ftr_drop = ftr_drop
         self.activation = activation
-        self.residual = Validator.check_bool(residual)
+        self.residual = residual
         self.layers = []
         # first layer
         self.layers.append(AttentionAggregator(
@@ -491,9 +486,9 @@ class GAT(nn.Cell):
             output_transform='sum'))
         self.layers = nn.layer.CellList(self.layers)
 
-    def construct(self, training=True):
-        input_data = self.features
-        bias_mat = self.biases
+    def construct(self, feature, biases, training=True):
+        input_data = feature
+        bias_mat = biases
         for cell in self.layers:
             input_data = cell(input_data, bias_mat, training)
         return input_data/self.num_heads[-1]
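
For reference, a minimal sketch of the updated GAT call pattern. Only the signature change is taken from the diff above; the import path, the Cora-like shapes, and the hidden_units/num_heads values are illustrative assumptions.

    import numpy as np
    from mindspore import Tensor
    from src.gat import GAT  # assumed path, mirroring train.py's src.* imports

    # Placeholder shapes: (batch, nodes, feature_size) and (batch, nodes, nodes).
    feature = np.random.rand(1, 2708, 1433).astype(np.float32)
    biases = np.random.rand(1, 2708, 2708).astype(np.float32)

    # The constructor now takes dimensions and hyperparameters only.
    gat_net = GAT(ftr_dims=1433,
                  num_class=7,
                  num_nodes=2708,
                  hidden_units=[8],
                  num_heads=[8, 1])

    # The graph data is supplied on every forward call instead.
    logits = gat_net(Tensor(feature), Tensor(biases), training=True)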

GAT loss and train-step wrappers (src/utils.py, path inferred):

@@ -103,8 +103,8 @@ class LossAccuracyWrapper(nn.Cell):
         self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params())
         self.acc_func = MaskedAccuracy(num_class, label, mask)
 
-    def construct(self):
-        logits = self.network(training=False)
+    def construct(self, feature, biases):
+        logits = self.network(feature, biases, training=False)
         loss = self.loss_func(logits)
         accuracy = self.acc_func(logits)
         return loss, accuracy
@@ -120,8 +120,8 @@ class LossNetWrapper(nn.Cell):
         params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias')
         self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params)
 
-    def construct(self):
-        logits = self.network()
+    def construct(self, feature, biases):
+        logits = self.network(feature, biases)
         loss = self.loss_func(logits)
         return loss
 
@@ -145,11 +145,11 @@ class TrainOneStepCell(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
 
-    def construct(self):
+    def construct(self, feature, biases):
         weights = self.weights
-        loss = self.network()
+        loss = self.network(feature, biases)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(sens)
+        grads = self.grad(self.network, weights)(feature, biases, sens)
         return F.depend(loss, self.optimizer(grads))
 
 
@@ -174,7 +174,7 @@ class TrainGAT(nn.Cell):
         self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
         self.accuracy_func = MaskedAccuracy(num_class, label, mask)
 
-    def construct(self):
-        loss = self.loss_train_net()
-        accuracy = self.accuracy_func(self.network())
+    def construct(self, feature, biases):
+        loss = self.loss_train_net(feature, biases)
+        accuracy = self.accuracy_func(self.network(feature, biases))
         return loss, accuracy
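
The one non-mechanical change in these wrappers is the gradient call: with sens_param=True, the callable returned by C.GradOperation expects the network's inputs first and the output sensitivity last, so grad(net, weights)(sens) becomes grad(net, weights)(feature, biases, sens). A self-contained sketch of that pattern follows; TinyNet and its shapes are invented for illustration, and the same pattern applies to the GCN TrainOneStepCell further down.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, ParameterTuple
    from mindspore.ops import composite as C
    from mindspore.ops import operations as P

    class TinyNet(nn.Cell):
        """Stand-in for the wrapped loss network: takes inputs, returns a scalar."""
        def __init__(self):
            super(TinyNet, self).__init__()
            self.dense = nn.Dense(4, 1)
            self.reduce_sum = P.ReduceSum()

        def construct(self, x):
            return self.reduce_sum(self.dense(x))

    net = TinyNet()
    weights = ParameterTuple(net.trainable_params())
    grad = C.GradOperation(get_by_list=True, sens_param=True)

    x = Tensor(np.ones((2, 4), np.float32))
    sens = Tensor(np.array(1.0, np.float32))  # scalar, matching the loss (cf. P.Fill() above)
    grads = grad(net, weights)(x, sens)       # inputs first, sensitivity last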

GAT training script (train.py):

@@ -20,6 +20,7 @@ import numpy as np
 import mindspore.context as context
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.common import set_seed
+from mindspore import Tensor
 
 from src.config import GatConfig
 from src.dataset import load_and_process
@@ -56,9 +57,7 @@ def train():
     num_nodes = feature.shape[1]
     num_class = y_train.shape[2]
 
-    gat_net = GAT(feature,
-                  biases,
-                  feature_size,
+    gat_net = GAT(feature_size,
                   num_class,
                   num_nodes,
                   hid_units,
@@ -67,6 +66,9 @@ def train():
                   ftr_drop=GatConfig.feature_dropout)
     gat_net.add_flags_recursive(fp16=True)
 
+    feature = Tensor(feature)
+    biases = Tensor(biases)
+
     eval_net = LossAccuracyWrapper(gat_net,
                                    num_class,
                                    y_val,
@@ -84,11 +86,11 @@ def train():
     val_acc_max = 0.0
     val_loss_min = np.inf
     for _epoch in range(num_epochs):
-        train_result = train_net()
+        train_result = train_net(feature, biases)
         train_loss = train_result[0].asnumpy()
         train_acc = train_result[1].asnumpy()
 
-        eval_result = eval_net()
+        eval_result = eval_net(feature, biases)
         eval_loss = eval_result[0].asnumpy()
         eval_acc = eval_result[1].asnumpy()
 
@@ -110,9 +112,7 @@ def train():
             print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
             print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
             break
-    gat_net_test = GAT(feature,
-                       biases,
-                       feature_size,
+    gat_net_test = GAT(feature_size,
                        num_class,
                        num_nodes,
                        hid_units,
@@ -127,7 +127,7 @@ def train():
                                    y_test,
                                    test_mask,
                                    l2_coeff)
-    test_result = test_net()
+    test_result = test_net(feature, biases)
     print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))

GCN model definition (model_zoo/official/gnn/gcn/src/gcn.py):

@@ -92,15 +92,12 @@ class GCN(nn.Cell):
         output_dim (int): The number of output channels, equal to classes num.
     """
 
-    def __init__(self, config, adj, feature, output_dim):
+    def __init__(self, config, input_dim, output_dim):
         super(GCN, self).__init__()
-        self.adj = Tensor(adj)
-        self.feature = Tensor(feature)
-        input_dim = feature.shape[1]
         self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
         self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
 
-    def construct(self):
-        output0 = self.layer0(self.adj, self.feature)
-        output1 = self.layer1(self.adj, output0)
+    def construct(self, adj, feature):
+        output0 = self.layer0(adj, feature)
+        output1 = self.layer1(adj, output0)
         return output1
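
Correspondingly, a minimal sketch of the new GCN call pattern, using the import paths shown in the system test below; the adjacency/feature contents and the Cora-like sizes are dummies, and instantiating ConfigGCN directly is an assumption.

    import numpy as np
    from mindspore import Tensor
    from model_zoo.official.gnn.gcn.src.gcn import GCN
    from model_zoo.official.gnn.gcn.src.config import ConfigGCN

    # Placeholder normalized adjacency and node features (shape only).
    adj = np.eye(2708, dtype=np.float32)
    feature = np.ones((2708, 1433), dtype=np.float32)
    class_num = 7

    # The constructor takes the input dimension instead of the data itself.
    gcn_net = GCN(ConfigGCN(), feature.shape[1], class_num)

    # adj and feature are now per-call inputs.
    output = gcn_net(Tensor(adj), Tensor(feature))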

GCN loss and train-step wrappers (model_zoo/official/gnn/gcn/src/metrics.py):

@@ -91,8 +91,8 @@ class LossAccuracyWrapper(nn.Cell):
         self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
         self.accuracy = Accuracy(label, mask)
 
-    def construct(self):
-        preds = self.network()
+    def construct(self, adj, feature):
+        preds = self.network(adj, feature)
         loss = self.loss(preds)
         accuracy = self.accuracy(preds)
         return loss, accuracy
@@ -114,8 +114,8 @@ class LossWrapper(nn.Cell):
         self.network = network
         self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
 
-    def construct(self):
-        preds = self.network()
+    def construct(self, adj, feature):
+        preds = self.network(adj, feature)
         loss = self.loss(preds)
         return loss
 
@@ -154,11 +154,11 @@ class TrainOneStepCell(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
 
-    def construct(self):
+    def construct(self, adj, feature):
         weights = self.weights
-        loss = self.network()
+        loss = self.network(adj, feature)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(sens)
+        grads = self.grad(self.network, weights)(adj, feature, sens)
         return F.depend(loss, self.optimizer(grads))
 
 
@@ -182,7 +182,7 @@ class TrainNetWrapper(nn.Cell):
         self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
         self.accuracy = Accuracy(label, mask)
 
-    def construct(self):
-        loss = self.loss_train_net()
-        accuracy = self.accuracy(self.network())
+    def construct(self, adj, feature):
+        loss = self.loss_train_net(adj, feature)
+        accuracy = self.accuracy(self.network(adj, feature))
         return loss, accuracy

GCN training script (train.py):

@@ -26,6 +26,7 @@ from matplotlib import pyplot as plt
 from matplotlib import animation
 from sklearn import manifold
 from mindspore import context
+from mindspore import Tensor
 from mindspore.common import set_seed
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 
@@ -71,9 +72,13 @@ def train():
     test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
 
     class_num = label_onehot.shape[1]
-    gcn_net = GCN(config, adj, feature, class_num)
+    input_dim = feature.shape[1]
+    gcn_net = GCN(config, input_dim, class_num)
     gcn_net.add_flags_recursive(fp16=True)
 
+    adj = Tensor(adj)
+    feature = Tensor(feature)
+
     eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
     train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
 
@@ -92,12 +97,12 @@ def train():
         t = time.time()
 
         train_net.set_train()
-        train_result = train_net()
+        train_result = train_net(adj, feature)
         train_loss = train_result[0].asnumpy()
         train_accuracy = train_result[1].asnumpy()
 
         eval_net.set_train(False)
-        eval_result = eval_net()
+        eval_result = eval_net(adj, feature)
         eval_loss = eval_result[0].asnumpy()
         eval_accuracy = eval_result[1].asnumpy()
 
@@ -115,14 +120,14 @@ def train():
             print("Early stopping...")
             break
     save_checkpoint(gcn_net, "ckpts/gcn.ckpt")
-    gcn_net_test = GCN(config, adj, feature, class_num)
+    gcn_net_test = GCN(config, input_dim, class_num)
     load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
     gcn_net_test.add_flags_recursive(fp16=True)
 
     test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
     t_test = time.time()
     test_net.set_train(False)
-    test_result = test_net()
+    test_result = test_net(adj, feature)
     test_loss = test_result[0].asnumpy()
     test_accuracy = test_result[1].asnumpy()
     print("Test set results:", "loss=", "{:.5f}".format(test_loss),

GCN system test (under tests/st, filename not shown in this view):

@@ -17,6 +17,7 @@ import time
 import pytest
 import numpy as np
 from mindspore import context
+from mindspore import Tensor
 from model_zoo.official.gnn.gcn.src.gcn import GCN
 from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
 from model_zoo.official.gnn.gcn.src.config import ConfigGCN
@@ -49,9 +50,13 @@ def test_gcn():
     test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
 
     class_num = label_onehot.shape[1]
-    gcn_net = GCN(config, adj, feature, class_num)
+    input_dim = feature.shape[1]
+    gcn_net = GCN(config, input_dim, class_num)
     gcn_net.add_flags_recursive(fp16=True)
 
+    adj = Tensor(adj)
+    feature = Tensor(feature)
+
     eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
     test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
     train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
 
@@ -61,12 +66,12 @@ def test_gcn():
         t = time.time()
 
         train_net.set_train()
-        train_result = train_net()
+        train_result = train_net(adj, feature)
         train_loss = train_result[0].asnumpy()
         train_accuracy = train_result[1].asnumpy()
 
         eval_net.set_train(False)
-        eval_result = eval_net()
+        eval_result = eval_net(adj, feature)
         eval_loss = eval_result[0].asnumpy()
         eval_accuracy = eval_result[1].asnumpy()
 
@@ -80,7 +85,7 @@ def test_gcn():
             break
 
     test_net.set_train(False)
-    test_result = test_net()
+    test_result = test_net(adj, feature)
     test_loss = test_result[0].asnumpy()
     test_accuracy = test_result[1].asnumpy()
     print("Test set results:", "loss=", "{:.5f}".format(test_loss),