forked from mindspore-Ecosystem/mindspore
!48221 add dataset_sink_mode for network testcase
Merge pull request !48221 from anzhengqi/modify-testcase-resnet50
This commit is contained in:
commit
050a5e7635
|
@ -173,7 +173,7 @@ def test_resnet_imagenet_8p_mpi():
|
||||||
acc = 0.0
|
acc = 0.0
|
||||||
time_cost = 0.0
|
time_cost = 0.0
|
||||||
for _ in range(0, int(epoch_size / eval_interval)):
|
for _ in range(0, int(epoch_size / eval_interval)):
|
||||||
model.train(1, dataset, callbacks=loss_cb)
|
model.train(1, dataset, callbacks=loss_cb, dataset_sink_mode=True)
|
||||||
output = model.eval(eval_dataset)
|
output = model.eval(eval_dataset)
|
||||||
acc = float(output.get('acc', 0.0))
|
acc = float(output.get('acc', 0.0))
|
||||||
time_cost = loss_cb.get_per_step_time()
|
time_cost = loss_cb.get_per_step_time()
|
||||||
|
|
|
@ -66,7 +66,7 @@ def train_and_eval(device_id, epoch_size, model, dataset, loss_cb, eval_dataset)
|
||||||
acc = 0.0
|
acc = 0.0
|
||||||
time_cost = 0.0
|
time_cost = 0.0
|
||||||
for epoch_idx in range(0, int(epoch_size / eval_interval)):
|
for epoch_idx in range(0, int(epoch_size / eval_interval)):
|
||||||
model.train(1, dataset, callbacks=loss_cb)
|
model.train(1, dataset, callbacks=loss_cb, dataset_sink_mode=True)
|
||||||
eval_start = time.time()
|
eval_start = time.time()
|
||||||
output = model.eval(eval_dataset)
|
output = model.eval(eval_dataset)
|
||||||
eval_cost = (time.time() - eval_start) * 1000
|
eval_cost = (time.time() - eval_start) * 1000
|
||||||
|
|
Loading…
Reference in New Issue