fix textrcnn eval net amp construct problem

This commit is contained in:
陈劢 2021-06-08 15:08:33 +08:00
parent 635c2b0adb
commit 87c9e71826
1 changed files with 2 additions and 1 deletions

View File

@ -48,12 +48,13 @@ if __name__ == '__main__':
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
cell=cfg.cell, batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
eval_net = nn.WithEvalCell(network, loss, True)
loss_cb = LossMonitor()
print("============== Starting Testing ==============")
ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
network.set_train(False)
model = Model(network, loss, metrics={'acc': Accuracy()}, amp_level='O3')
model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2])
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("============== Accuracy:{} ==============".format(acc))