forked from mindspore-Ecosystem/mindspore
fix textrcnn eval net amp construct problem
This commit is contained in:
parent
635c2b0adb
commit
87c9e71826
|
@ -48,12 +48,13 @@ if __name__ == '__main__':
|
||||||
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
|
network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
|
||||||
cell=cfg.cell, batch_size=cfg.batch_size)
|
cell=cfg.cell, batch_size=cfg.batch_size)
|
||||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||||
|
eval_net = nn.WithEvalCell(network, loss, True)
|
||||||
loss_cb = LossMonitor()
|
loss_cb = LossMonitor()
|
||||||
print("============== Starting Testing ==============")
|
print("============== Starting Testing ==============")
|
||||||
ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
|
ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
|
||||||
param_dict = load_checkpoint(args.ckpt_path)
|
param_dict = load_checkpoint(args.ckpt_path)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
network.set_train(False)
|
network.set_train(False)
|
||||||
model = Model(network, loss, metrics={'acc': Accuracy()}, amp_level='O3')
|
model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2])
|
||||||
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
acc = model.eval(ds_eval, dataset_sink_mode=False)
|
||||||
print("============== Accuracy:{} ==============".format(acc))
|
print("============== Accuracy:{} ==============".format(acc))
|
||||||
|
|
Loading…
Reference in New Issue