User submitted
This commit is contained in:
parent
49e10b16cc
commit
a937f6dc6c
|
@ -170,7 +170,12 @@ def test_net(args, network, model, mnist_path):
|
|||
# 请在此添加代码完成本关任务
|
||||
#********** Begin *********#
|
||||
## 提示:补全验证函数的代码
|
||||
|
||||
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
|
||||
# load parameter to the network
|
||||
load_param_into_net(network, param_dict)
|
||||
# load testing dataset
|
||||
ds_eval = create_dataset(os.path.join(mnist_path, "test"))
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||
#********** End **********#
|
||||
print("============== Accuracy:{} ==============".format(acc))
|
||||
|
||||
|
|
Loading…
Reference in New Issue