Correct shuffle UT buffer_size > #dataset-row as valid

This commit is contained in:
Cathy Wong 2020-04-09 15:22:33 -04:00
parent 7f8c9ebf10
commit 59a714c654
2 changed files with 20 additions and 19 deletions

Binary file not shown.

View File

@ -98,6 +98,25 @@ def test_shuffle_04():
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_05():
"""
Test shuffle: buffer_size > number-of-rows-in-dataset
"""
logger.info("test_shuffle_05")
# define parameters
buffer_size = 13
seed = 1
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)
filename = "shuffle_05_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_exception_01(): def test_shuffle_exception_01():
""" """
Test shuffle exception: buffer_size<0 Test shuffle exception: buffer_size<0
@ -152,24 +171,6 @@ def test_shuffle_exception_03():
assert "buffer_size" in str(e) assert "buffer_size" in str(e)
def test_shuffle_exception_04():
"""
Test shuffle exception: buffer_size > number-of-rows-in-dataset
"""
logger.info("test_shuffle_exception_04")
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR)
ds.config.set_seed(1)
try:
data1 = data1.shuffle(buffer_size=13)
sum([1 for _ in data1])
except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e)
def test_shuffle_exception_05(): def test_shuffle_exception_05():
""" """
Test shuffle exception: Missing mandatory buffer_size input parameter Test shuffle exception: Missing mandatory buffer_size input parameter
@ -229,10 +230,10 @@ if __name__ == '__main__':
test_shuffle_02() test_shuffle_02()
test_shuffle_03() test_shuffle_03()
test_shuffle_04() test_shuffle_04()
test_shuffle_05()
test_shuffle_exception_01() test_shuffle_exception_01()
test_shuffle_exception_02() test_shuffle_exception_02()
test_shuffle_exception_03() test_shuffle_exception_03()
test_shuffle_exception_04()
test_shuffle_exception_05() test_shuffle_exception_05()
test_shuffle_exception_06() test_shuffle_exception_06()
test_shuffle_exception_07() test_shuffle_exception_07()