diff --git a/tests/ut/data/dataset/golden/shuffle_05_result.npz b/tests/ut/data/dataset/golden/shuffle_05_result.npz new file mode 100644 index 00000000000..27eb0a470d3 Binary files /dev/null and b/tests/ut/data/dataset/golden/shuffle_05_result.npz differ diff --git a/tests/ut/python/dataset/test_shuffle.py b/tests/ut/python/dataset/test_shuffle.py index 2b7a251d2cb..4a823c5fb7f 100644 --- a/tests/ut/python/dataset/test_shuffle.py +++ b/tests/ut/python/dataset/test_shuffle.py @@ -98,6 +98,25 @@ def test_shuffle_04(): save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) +def test_shuffle_05(): + """ + Test shuffle: buffer_size > number-of-rows-in-dataset + """ + logger.info("test_shuffle_05") + # define parameters + buffer_size = 13 + seed = 1 + parameters = {"params": {'buffer_size': buffer_size, "seed": seed}} + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) + ds.config.set_seed(seed) + data1 = data1.shuffle(buffer_size=buffer_size) + + filename = "shuffle_05_result.npz" + save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + + def test_shuffle_exception_01(): """ Test shuffle exception: buffer_size<0 @@ -152,24 +171,6 @@ def test_shuffle_exception_03(): assert "buffer_size" in str(e) -def test_shuffle_exception_04(): - """ - Test shuffle exception: buffer_size > number-of-rows-in-dataset - """ - logger.info("test_shuffle_exception_04") - - # apply dataset operations - data1 = ds.TFRecordDataset(DATA_DIR) - ds.config.set_seed(1) - try: - data1 = data1.shuffle(buffer_size=13) - sum([1 for _ in data1]) - - except BaseException as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "buffer_size" in str(e) - - def test_shuffle_exception_05(): """ Test shuffle exception: Missing mandatory buffer_size input parameter @@ -229,10 +230,10 @@ if __name__ == '__main__': test_shuffle_02() test_shuffle_03() test_shuffle_04() + test_shuffle_05() test_shuffle_exception_01() test_shuffle_exception_02() test_shuffle_exception_03() - test_shuffle_exception_04() test_shuffle_exception_05() test_shuffle_exception_06() test_shuffle_exception_07()