diff --git a/model_zoo/research/nlp/textrcnn/eval.py b/model_zoo/research/nlp/textrcnn/eval.py index f36a5473bb0..cb36f70891a 100644 --- a/model_zoo/research/nlp/textrcnn/eval.py +++ b/model_zoo/research/nlp/textrcnn/eval.py @@ -48,12 +48,13 @@ if __name__ == '__main__': network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], cell=cfg.cell, batch_size=cfg.batch_size) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) + eval_net = nn.WithEvalCell(network, loss, True) loss_cb = LossMonitor() print("============== Starting Testing ==============") ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False) param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) network.set_train(False) - model = Model(network, loss, metrics={'acc': Accuracy()}, amp_level='O3') + model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2]) acc = model.eval(ds_eval, dataset_sink_mode=False) print("============== Accuracy:{} ==============".format(acc))