From e3e78204136c1895b34d7deae9df97a37a2d89d1 Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Fri, 21 Aug 2020 16:55:58 +0800 Subject: [PATCH] fix cifar stuck problem --- .../dataset/engine/datasetops/source/cifar_op.cc | 3 +++ tests/ut/python/dataset/test_datasets_cifarop.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc index fceec890b2f..edace86ee8e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -336,6 +336,9 @@ Status CifarOp::GetCifarFiles() { std::string err_msg = "Unable to open directory " + dataset_directory.toString(); RETURN_STATUS_UNEXPECTED(err_msg); } + if (cifar_files_.size() == 0) { + RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_); + } std::sort(cifar_files_.begin(), cifar_files_.end()); return Status::OK(); } diff --git a/tests/ut/python/dataset/test_datasets_cifarop.py b/tests/ut/python/dataset/test_datasets_cifarop.py index 2b66f326657..9b485cac8c2 100644 --- a/tests/ut/python/dataset/test_datasets_cifarop.py +++ b/tests/ut/python/dataset/test_datasets_cifarop.py @@ -24,6 +24,7 @@ from mindspore import log as logger DATA_DIR_10 = "../data/dataset/testCifar10Data" DATA_DIR_100 = "../data/dataset/testCifar100Data" +NO_BIN_DIR = "../data/dataset/testMnistData" def load_cifar(path, kind="cifar10"): @@ -208,6 +209,12 @@ def test_cifar10_exception(): with pytest.raises(ValueError, match=error_msg_6): ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88) + error_msg_7 = "No .bin files found" + with pytest.raises(RuntimeError, match=error_msg_7): + ds1 = ds.Cifar10Dataset(NO_BIN_DIR) + for _ in ds1.__iter__(): + pass + def test_cifar10_visualize(plot=False): """ @@ -352,6 +359,12 @@ def test_cifar100_exception(): with pytest.raises(ValueError, match=error_msg_6): ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88) + error_msg_7 = "No .bin files found" + with pytest.raises(RuntimeError, match=error_msg_7): + ds1 = ds.Cifar100Dataset(NO_BIN_DIR) + for _ in ds1.__iter__(): + pass + def test_cifar100_visualize(plot=False): """