!48221 add dataset_sink_mode for network testcase

Merge pull request !48221 from anzhengqi/modify-testcase-resnet50
This commit is contained in:
i-robot 2023-02-02 02:50:05 +00:00 committed by Gitee
commit 050a5e7635
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 2 additions and 2 deletions

View File

@ -173,7 +173,7 @@ def test_resnet_imagenet_8p_mpi():
acc = 0.0 acc = 0.0
time_cost = 0.0 time_cost = 0.0
for _ in range(0, int(epoch_size / eval_interval)): for _ in range(0, int(epoch_size / eval_interval)):
model.train(1, dataset, callbacks=loss_cb) model.train(1, dataset, callbacks=loss_cb, dataset_sink_mode=True)
output = model.eval(eval_dataset) output = model.eval(eval_dataset)
acc = float(output.get('acc', 0.0)) acc = float(output.get('acc', 0.0))
time_cost = loss_cb.get_per_step_time() time_cost = loss_cb.get_per_step_time()

View File

@ -66,7 +66,7 @@ def train_and_eval(device_id, epoch_size, model, dataset, loss_cb, eval_dataset)
acc = 0.0 acc = 0.0
time_cost = 0.0 time_cost = 0.0
for epoch_idx in range(0, int(epoch_size / eval_interval)): for epoch_idx in range(0, int(epoch_size / eval_interval)):
model.train(1, dataset, callbacks=loss_cb) model.train(1, dataset, callbacks=loss_cb, dataset_sink_mode=True)
eval_start = time.time() eval_start = time.time()
output = model.eval(eval_dataset) output = model.eval(eval_dataset)
eval_cost = (time.time() - eval_start) * 1000 eval_cost = (time.time() - eval_start) * 1000