forked from mindspore-Ecosystem/mindspore
fix textrcnn eval net amp construct problem
This commit is contained in:
parent
635c2b0adb
commit
87c9e71826
|
@ -48,12 +48,13 @@ if __name__ == '__main__':
|
|||
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
|
||||
cell=cfg.cell, batch_size=cfg.batch_size)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
eval_net = nn.WithEvalCell(network, loss, True)
|
||||
loss_cb = LossMonitor()
|
||||
print("============== Starting Testing ==============")
|
||||
ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
network.set_train(False)
|
||||
model = Model(network, loss, metrics={'acc': Accuracy()}, amp_level='O3')
|
||||
model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2])
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||
print("============== Accuracy:{} ==============".format(acc))
|
||||
|
|
Loading…
Reference in New Issue