diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index 1111386f516..92c17c06940 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -173,7 +173,7 @@ def test_resnet_imagenet_8p_mpi(): acc = 0.0 time_cost = 0.0 for _ in range(0, int(epoch_size / eval_interval)): - model.train(1, dataset, callbacks=loss_cb) + model.train(1, dataset, callbacks=loss_cb, dataset_sink_mode=True) output = model.eval(eval_dataset) acc = float(output.get('acc', 0.0)) time_cost = loss_cb.get_per_step_time() diff --git a/tests/st/networks/models/resnet50/train_resnet50.py b/tests/st/networks/models/resnet50/train_resnet50.py index 27c0614f9c4..7d22dd39358 100644 --- a/tests/st/networks/models/resnet50/train_resnet50.py +++ b/tests/st/networks/models/resnet50/train_resnet50.py @@ -66,7 +66,7 @@ def train_and_eval(device_id, epoch_size, model, dataset, loss_cb, eval_dataset) acc = 0.0 time_cost = 0.0 for epoch_idx in range(0, int(epoch_size / eval_interval)): - model.train(1, dataset, callbacks=loss_cb) + model.train(1, dataset, callbacks=loss_cb, dataset_sink_mode=True) eval_start = time.time() output = model.eval(eval_dataset) eval_cost = (time.time() - eval_start) * 1000