From 87c9e71826af0a9c0ecb8209ba33e867b196cf2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=8A=A2?= Date: Tue, 8 Jun 2021 15:08:33 +0800 Subject: [PATCH] fix textrcnn eval net amp construct problem --- model_zoo/research/nlp/textrcnn/eval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model_zoo/research/nlp/textrcnn/eval.py b/model_zoo/research/nlp/textrcnn/eval.py index f36a5473bb0..cb36f70891a 100644 --- a/model_zoo/research/nlp/textrcnn/eval.py +++ b/model_zoo/research/nlp/textrcnn/eval.py @@ -48,12 +48,13 @@ if __name__ == '__main__': network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], cell=cfg.cell, batch_size=cfg.batch_size) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) + eval_net = nn.WithEvalCell(network, loss, True) loss_cb = LossMonitor() print("============== Starting Testing ==============") ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False) param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) network.set_train(False) - model = Model(network, loss, metrics={'acc': Accuracy()}, amp_level='O3') + model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2]) acc = model.eval(ds_eval, dataset_sink_mode=False) print("============== Accuracy:{} ==============".format(acc))