forked from mindspore-Ecosystem/mindspore
!48221 add dataset_sink_mode for network testcase
Merge pull request !48221 from anzhengqi/modify-testcase-resnet50
This commit is contained in:
commit
050a5e7635
|
@ -173,7 +173,7 @@ def test_resnet_imagenet_8p_mpi():
|
|||
acc = 0.0
|
||||
time_cost = 0.0
|
||||
for _ in range(0, int(epoch_size / eval_interval)):
|
||||
model.train(1, dataset, callbacks=loss_cb)
|
||||
model.train(1, dataset, callbacks=loss_cb, dataset_sink_mode=True)
|
||||
output = model.eval(eval_dataset)
|
||||
acc = float(output.get('acc', 0.0))
|
||||
time_cost = loss_cb.get_per_step_time()
|
||||
|
|
|
@ -66,7 +66,7 @@ def train_and_eval(device_id, epoch_size, model, dataset, loss_cb, eval_dataset)
|
|||
acc = 0.0
|
||||
time_cost = 0.0
|
||||
for epoch_idx in range(0, int(epoch_size / eval_interval)):
|
||||
model.train(1, dataset, callbacks=loss_cb)
|
||||
model.train(1, dataset, callbacks=loss_cb, dataset_sink_mode=True)
|
||||
eval_start = time.time()
|
||||
output = model.eval(eval_dataset)
|
||||
eval_cost = (time.time() - eval_start) * 1000
|
||||
|
|
Loading…
Reference in New Issue