forked from mindspore-Ecosystem/mindspore

Commit 4384ca214f: change gnn input position
Parent: b5723dcd81
@@ -19,7 +19,6 @@ from mindspore.ops import functional as F
 from mindspore._extends import cell_attr_register
 from mindspore import Tensor, Parameter
 from mindspore.common.initializer import initializer
-from mindspore._checkparam import Validator
 from mindspore.nn.layer.activation import get_activation
 
 
@@ -72,9 +71,9 @@ class GNNFeatureTransform(nn.Cell):
                  bias_init='zeros',
                  has_bias=True):
         super(GNNFeatureTransform, self).__init__()
-        self.in_channels = Validator.check_positive_int(in_channels)
-        self.out_channels = Validator.check_positive_int(out_channels)
-        self.has_bias = Validator.check_bool(has_bias)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.has_bias = has_bias
 
         if isinstance(weight_init, Tensor):
             if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
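Note: this hunk and the ones below drop the _checkparam Validator calls, so in_channels and friends are no longer range-checked at construction time. If a caller still wants that guard, it is simple to reproduce inline; a minimal sketch mirroring the removed Validator.check_positive_int semantics (not part of this commit):

    def check_positive_int(value, name):
        # mirrors Validator.check_positive_int: a positive int, bools excluded
        if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
            raise ValueError("{} must be a positive int, got {!r}".format(name, value))
        return value

    in_channels = check_positive_int(in_channels, "in_channels")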
@@ -259,8 +258,8 @@ class AttentionHead(nn.Cell):
                  coef_activation=nn.LeakyReLU(),
                  activation=nn.ELU()):
         super(AttentionHead, self).__init__()
-        self.in_channel = Validator.check_positive_int(in_channel)
-        self.out_channel = Validator.check_positive_int(out_channel)
+        self.in_channel = in_channel
+        self.out_channel = out_channel
         self.in_drop_ratio = in_drop_ratio
         self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
         self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
@@ -284,7 +283,7 @@ class AttentionHead(nn.Cell):
         self.matmul = P.MatMul()
         self.bias_add = P.BiasAdd()
         self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
-        self.residual = Validator.check_bool(residual)
+        self.residual = residual
         if self.residual:
             if in_channel != out_channel:
                 self.residual_transform_flag = True
@@ -436,8 +435,6 @@ class GAT(nn.Cell):
     """
 
     def __init__(self,
-                 features,
-                 biases,
                  ftr_dims,
                  num_class,
                  num_nodes,
@@ -448,17 +445,15 @@ class GAT(nn.Cell):
                  activation=nn.ELU(),
                  residual=False):
         super(GAT, self).__init__()
-        self.features = Tensor(features)
-        self.biases = Tensor(biases)
-        self.ftr_dims = Validator.check_positive_int(ftr_dims)
-        self.num_class = Validator.check_positive_int(num_class)
-        self.num_nodes = Validator.check_positive_int(num_nodes)
+        self.ftr_dims = ftr_dims
+        self.num_class = num_class
+        self.num_nodes = num_nodes
         self.hidden_units = hidden_units
         self.num_heads = num_heads
         self.attn_drop = attn_drop
         self.ftr_drop = ftr_drop
         self.activation = activation
-        self.residual = Validator.check_bool(residual)
+        self.residual = residual
         self.layers = []
         # first layer
         self.layers.append(AttentionAggregator(
@@ -491,9 +486,9 @@ class GAT(nn.Cell):
                 output_transform='sum'))
         self.layers = nn.layer.CellList(self.layers)
 
-    def construct(self, training=True):
-        input_data = self.features
-        bias_mat = self.biases
+    def construct(self, feature, biases, training=True):
+        input_data = feature
+        bias_mat = biases
         for cell in self.layers:
             input_data = cell(input_data, bias_mat, training)
         return input_data/self.num_heads[-1]
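With this hunk the graph tensors stop being constructor state: GAT is built from dimensions alone, and feature/biases flow in through construct on every call. A minimal usage sketch; the shapes and layer sizes here are illustrative assumptions, not values from this commit:

    import numpy as np
    from mindspore import Tensor

    # hypothetical Cora-like dimensions, for illustration only
    feature = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
    biases = Tensor(np.random.rand(1, 2708, 2708).astype(np.float32))

    net = GAT(ftr_dims=1433, num_class=7, num_nodes=2708,
              hidden_units=[8], num_heads=[8, 1])
    logits = net(feature, biases, training=False)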
@@ -103,8 +103,8 @@ class LossAccuracyWrapper(nn.Cell):
         self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params())
         self.acc_func = MaskedAccuracy(num_class, label, mask)
 
-    def construct(self):
-        logits = self.network(training=False)
+    def construct(self, feature, biases):
+        logits = self.network(feature, biases, training=False)
         loss = self.loss_func(logits)
         accuracy = self.acc_func(logits)
         return loss, accuracy
@@ -120,8 +120,8 @@ class LossNetWrapper(nn.Cell):
         params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias')
         self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params)
 
-    def construct(self):
-        logits = self.network()
+    def construct(self, feature, biases):
+        logits = self.network(feature, biases)
         loss = self.loss_func(logits)
         return loss
 
@@ -145,11 +145,11 @@ class TrainOneStepCell(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
 
-    def construct(self):
+    def construct(self, feature, biases):
         weights = self.weights
-        loss = self.network()
+        loss = self.network(feature, biases)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(sens)
+        grads = self.grad(self.network, weights)(feature, biases, sens)
         return F.depend(loss, self.optimizer(grads))
 
 
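The new inputs also have to be threaded through the backward pass: because the grad operation was created with sens_param=True, the callable returned by GradOperation expects the sensitivity tensor appended after the network's own positional inputs, hence (feature, biases, sens). A self-contained sketch of that pattern on a toy cell, using the same MindSpore 1.x API calls as this file:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore import Tensor, Parameter, ParameterTuple
    from mindspore.ops import composite as C

    class TinyNet(nn.Cell):
        # stand-in for the wrapped loss network: y = x @ w
        def __init__(self):
            super(TinyNet, self).__init__()
            self.w = Parameter(Tensor(np.ones((2, 1), np.float32)), name='w')
            self.matmul = P.MatMul()

        def construct(self, x):
            return self.matmul(x, self.w)

    net = TinyNet()
    weights = ParameterTuple(net.trainable_params())
    grad_op = C.GradOperation(get_by_list=True, sens_param=True)

    x = Tensor(np.ones((3, 2), np.float32))
    sens = Tensor(np.ones((3, 1), np.float32))  # gradient seed, shaped like the output
    grads = grad_op(net, weights)(x, sens)      # inputs first, sensitivity last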
@@ -174,7 +174,7 @@ class TrainGAT(nn.Cell):
         self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
         self.accuracy_func = MaskedAccuracy(num_class, label, mask)
 
-    def construct(self):
-        loss = self.loss_train_net()
-        accuracy = self.accuracy_func(self.network())
+    def construct(self, feature, biases):
+        loss = self.loss_train_net(feature, biases)
+        accuracy = self.accuracy_func(self.network(feature, biases))
         return loss, accuracy
@@ -20,6 +20,7 @@ import numpy as np
 import mindspore.context as context
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.common import set_seed
+from mindspore import Tensor
 
 from src.config import GatConfig
 from src.dataset import load_and_process
@@ -56,9 +57,7 @@ def train():
     num_nodes = feature.shape[1]
     num_class = y_train.shape[2]
 
-    gat_net = GAT(feature,
-                  biases,
-                  feature_size,
+    gat_net = GAT(feature_size,
                   num_class,
                   num_nodes,
                   hid_units,
@@ -67,6 +66,9 @@ def train():
                   ftr_drop=GatConfig.feature_dropout)
     gat_net.add_flags_recursive(fp16=True)
 
+    feature = Tensor(feature)
+    biases = Tensor(biases)
+
     eval_net = LossAccuracyWrapper(gat_net,
                                    num_class,
                                    y_val,
@@ -84,11 +86,11 @@ def train():
     val_acc_max = 0.0
     val_loss_min = np.inf
     for _epoch in range(num_epochs):
-        train_result = train_net()
+        train_result = train_net(feature, biases)
         train_loss = train_result[0].asnumpy()
         train_acc = train_result[1].asnumpy()
 
-        eval_result = eval_net()
+        eval_result = eval_net(feature, biases)
         eval_loss = eval_result[0].asnumpy()
         eval_acc = eval_result[1].asnumpy()
 
@@ -110,9 +112,7 @@ def train():
             print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
             print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
             break
-    gat_net_test = GAT(feature,
-                       biases,
-                       feature_size,
+    gat_net_test = GAT(feature_size,
                        num_class,
                        num_nodes,
                        hid_units,
@@ -127,7 +127,7 @@ def train():
                                     y_test,
                                     test_mask,
                                     l2_coeff)
-    test_result = test_net()
+    test_result = test_net(feature, biases)
     print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))
 
 
@@ -92,15 +92,12 @@ class GCN(nn.Cell):
         output_dim (int): The number of output channels, equal to classes num.
     """
 
-    def __init__(self, config, adj, feature, output_dim):
+    def __init__(self, config, input_dim, output_dim):
         super(GCN, self).__init__()
-        self.adj = Tensor(adj)
-        self.feature = Tensor(feature)
-        input_dim = feature.shape[1]
         self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
         self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
 
-    def construct(self):
-        output0 = self.layer0(self.adj, self.feature)
-        output1 = self.layer1(self.adj, output0)
+    def construct(self, adj, feature):
+        output0 = self.layer0(adj, feature)
+        output1 = self.layer1(adj, output0)
         return output1
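GCN follows the same pattern after this hunk: the cell is built from input_dim/output_dim only, and adj/feature are supplied per call. A quick sketch with assumed toy dimensions and a config object as used elsewhere in this repo:

    import numpy as np
    from mindspore import Tensor

    adj = Tensor(np.eye(2708).astype(np.float32))       # toy normalized adjacency
    feature = Tensor(np.random.rand(2708, 1433).astype(np.float32))

    gcn_net = GCN(config, input_dim=1433, output_dim=7)
    logits = gcn_net(adj, feature)                      # graph inputs via construct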
@@ -91,8 +91,8 @@ class LossAccuracyWrapper(nn.Cell):
         self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
         self.accuracy = Accuracy(label, mask)
 
-    def construct(self):
-        preds = self.network()
+    def construct(self, adj, feature):
+        preds = self.network(adj, feature)
         loss = self.loss(preds)
         accuracy = self.accuracy(preds)
         return loss, accuracy
@@ -114,8 +114,8 @@ class LossWrapper(nn.Cell):
         self.network = network
         self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
 
-    def construct(self):
-        preds = self.network()
+    def construct(self, adj, feature):
+        preds = self.network(adj, feature)
         loss = self.loss(preds)
         return loss
 
@@ -154,11 +154,11 @@ class TrainOneStepCell(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
 
-    def construct(self):
+    def construct(self, adj, feature):
         weights = self.weights
-        loss = self.network()
+        loss = self.network(adj, feature)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(sens)
+        grads = self.grad(self.network, weights)(adj, feature, sens)
         return F.depend(loss, self.optimizer(grads))
 
 
@@ -182,7 +182,7 @@ class TrainNetWrapper(nn.Cell):
         self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
         self.accuracy = Accuracy(label, mask)
 
-    def construct(self):
-        loss = self.loss_train_net()
-        accuracy = self.accuracy(self.network())
+    def construct(self, adj, feature):
+        loss = self.loss_train_net(adj, feature)
+        accuracy = self.accuracy(self.network(adj, feature))
         return loss, accuracy
@@ -26,6 +26,7 @@ from matplotlib import pyplot as plt
 from matplotlib import animation
 from sklearn import manifold
 from mindspore import context
+from mindspore import Tensor
 from mindspore.common import set_seed
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 
@@ -71,9 +72,13 @@ def train():
     test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
 
     class_num = label_onehot.shape[1]
-    gcn_net = GCN(config, adj, feature, class_num)
+    input_dim = feature.shape[1]
+    gcn_net = GCN(config, input_dim, class_num)
     gcn_net.add_flags_recursive(fp16=True)
 
+    adj = Tensor(adj)
+    feature = Tensor(feature)
+
     eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
     train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
 
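One ordering detail in this hunk: input_dim is read while feature is still the raw numpy array, and the network is built and flagged fp16 before adj and feature are rebound as Tensors. Condensed, the new wiring reads (same names as this script):

    class_num = label_onehot.shape[1]
    input_dim = feature.shape[1]                 # taken from the numpy array
    gcn_net = GCN(config, input_dim, class_num)
    gcn_net.add_flags_recursive(fp16=True)

    adj = Tensor(adj)                            # convert once, up front
    feature = Tensor(feature)

    train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
    train_result = train_net(adj, feature)       # graph inputs passed per step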
@@ -92,12 +97,12 @@ def train():
         t = time.time()
 
         train_net.set_train()
-        train_result = train_net()
+        train_result = train_net(adj, feature)
         train_loss = train_result[0].asnumpy()
         train_accuracy = train_result[1].asnumpy()
 
         eval_net.set_train(False)
-        eval_result = eval_net()
+        eval_result = eval_net(adj, feature)
         eval_loss = eval_result[0].asnumpy()
         eval_accuracy = eval_result[1].asnumpy()
 
@@ -115,14 +120,14 @@ def train():
             print("Early stopping...")
             break
         save_checkpoint(gcn_net, "ckpts/gcn.ckpt")
-    gcn_net_test = GCN(config, adj, feature, class_num)
+    gcn_net_test = GCN(config, input_dim, class_num)
     load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
     gcn_net_test.add_flags_recursive(fp16=True)
 
     test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
     t_test = time.time()
     test_net.set_train(False)
-    test_result = test_net()
+    test_result = test_net(adj, feature)
     test_loss = test_result[0].asnumpy()
     test_accuracy = test_result[1].asnumpy()
     print("Test set results:", "loss=", "{:.5f}".format(test_loss),
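A side effect visible here: since the graph data no longer lives on the cell, the checkpoint holds parameters only, and the test network can be rebuilt from dimensions alone before the trained weights are restored by name:

    save_checkpoint(gcn_net, "ckpts/gcn.ckpt")            # weights only
    gcn_net_test = GCN(config, input_dim, class_num)      # fresh cell, same dims
    load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)   # parameters match by name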
@@ -17,6 +17,7 @@ import time
 import pytest
 import numpy as np
 from mindspore import context
+from mindspore import Tensor
 from model_zoo.official.gnn.gcn.src.gcn import GCN
 from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
 from model_zoo.official.gnn.gcn.src.config import ConfigGCN
@@ -49,9 +50,13 @@ def test_gcn():
     test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
 
     class_num = label_onehot.shape[1]
-    gcn_net = GCN(config, adj, feature, class_num)
+    input_dim = feature.shape[1]
+    gcn_net = GCN(config, input_dim, class_num)
     gcn_net.add_flags_recursive(fp16=True)
 
+    adj = Tensor(adj)
+    feature = Tensor(feature)
+
     eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
     test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
     train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
@@ -61,12 +66,12 @@ def test_gcn():
         t = time.time()
 
         train_net.set_train()
-        train_result = train_net()
+        train_result = train_net(adj, feature)
         train_loss = train_result[0].asnumpy()
         train_accuracy = train_result[1].asnumpy()
 
         eval_net.set_train(False)
-        eval_result = eval_net()
+        eval_result = eval_net(adj, feature)
         eval_loss = eval_result[0].asnumpy()
         eval_accuracy = eval_result[1].asnumpy()
 
@@ -80,7 +85,7 @@ def test_gcn():
             break
 
     test_net.set_train(False)
-    test_result = test_net()
+    test_result = test_net(adj, feature)
    test_loss = test_result[0].asnumpy()
    test_accuracy = test_result[1].asnumpy()
    print("Test set results:", "loss=", "{:.5f}".format(test_loss),