fix test_gcn

This commit is contained in:
sophie2020 2022-05-19 20:29:37 +08:00
parent f47f97d24b
commit bfe97a2a33
1 changed files with 9 additions and 7 deletions

View File

@ -61,6 +61,7 @@ def test_gcn():
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
loss_list = []
best_acc = 0.
for epoch in range(config.epochs):
t = time.time()
@ -78,15 +79,16 @@ def test_gcn():
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
if epoch % 5 == 0:
test_net.set_train(False)
test_result = test_net(adj, feature)
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss), "accuracy=", "{:.5f}".format(test_accuracy))
best_acc = test_accuracy if test_accuracy > best_acc else best_acc
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
print("Early stopping...")
break
test_net.set_train(False)
test_result = test_net(adj, feature)
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
"accuracy=", "{:.5f}".format(test_accuracy))
assert test_accuracy > 0.812
assert best_acc > 0.812