fix textrcnn eval net amp construct problem

This commit is contained in:
陈劢 2021-06-08 15:08:33 +08:00
parent 635c2b0adb
commit 87c9e71826
1 changed files with 2 additions and 1 deletions

View File

@ -48,12 +48,13 @@ if __name__ == '__main__':
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
cell=cfg.cell, batch_size=cfg.batch_size) cell=cfg.cell, batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
eval_net = nn.WithEvalCell(network, loss, True)
loss_cb = LossMonitor() loss_cb = LossMonitor()
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False) ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
network.set_train(False) network.set_train(False)
model = Model(network, loss, metrics={'acc': Accuracy()}, amp_level='O3') model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2])
acc = model.eval(ds_eval, dataset_sink_mode=False) acc = model.eval(ds_eval, dataset_sink_mode=False)
print("============== Accuracy:{} ==============".format(acc)) print("============== Accuracy:{} ==============".format(acc))