forked from mindspore-Ecosystem/mindspore
fix test_gcn
This commit is contained in:
parent
f47f97d24b
commit
bfe97a2a33
|
@ -61,6 +61,7 @@ def test_gcn():
|
|||
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
|
||||
|
||||
loss_list = []
|
||||
best_acc = 0.
|
||||
for epoch in range(config.epochs):
|
||||
t = time.time()
|
||||
|
||||
|
@ -78,15 +79,16 @@ def test_gcn():
|
|||
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
|
||||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
||||
if epoch % 5 == 0:
|
||||
test_net.set_train(False)
|
||||
test_result = test_net(adj, feature)
|
||||
test_loss = test_result[0].asnumpy()
|
||||
test_accuracy = test_result[1].asnumpy()
|
||||
print("Test set results:", "loss=", "{:.5f}".format(test_loss), "accuracy=", "{:.5f}".format(test_accuracy))
|
||||
best_acc = test_accuracy if test_accuracy > best_acc else best_acc
|
||||
|
||||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
||||
print("Early stopping...")
|
||||
break
|
||||
|
||||
test_net.set_train(False)
|
||||
test_result = test_net(adj, feature)
|
||||
test_loss = test_result[0].asnumpy()
|
||||
test_accuracy = test_result[1].asnumpy()
|
||||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
||||
"accuracy=", "{:.5f}".format(test_accuracy))
|
||||
assert test_accuracy > 0.812
|
||||
assert best_acc > 0.812
|
||||
|
|
Loading…
Reference in New Issue