forked from mindspore-Ecosystem/mindspore
Correct shuffle UT buffer_size > #dataset-row as valid
This commit is contained in:
parent
7f8c9ebf10
commit
59a714c654
Binary file not shown.
|
@ -98,6 +98,25 @@ def test_shuffle_04():
|
||||||
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shuffle_05():
|
||||||
|
"""
|
||||||
|
Test shuffle: buffer_size > number-of-rows-in-dataset
|
||||||
|
"""
|
||||||
|
logger.info("test_shuffle_05")
|
||||||
|
# define parameters
|
||||||
|
buffer_size = 13
|
||||||
|
seed = 1
|
||||||
|
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
|
||||||
|
|
||||||
|
# apply dataset operations
|
||||||
|
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
|
||||||
|
ds.config.set_seed(seed)
|
||||||
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
||||||
|
|
||||||
|
filename = "shuffle_05_result.npz"
|
||||||
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
||||||
|
|
||||||
|
|
||||||
def test_shuffle_exception_01():
|
def test_shuffle_exception_01():
|
||||||
"""
|
"""
|
||||||
Test shuffle exception: buffer_size<0
|
Test shuffle exception: buffer_size<0
|
||||||
|
@ -152,24 +171,6 @@ def test_shuffle_exception_03():
|
||||||
assert "buffer_size" in str(e)
|
assert "buffer_size" in str(e)
|
||||||
|
|
||||||
|
|
||||||
def test_shuffle_exception_04():
|
|
||||||
"""
|
|
||||||
Test shuffle exception: buffer_size > number-of-rows-in-dataset
|
|
||||||
"""
|
|
||||||
logger.info("test_shuffle_exception_04")
|
|
||||||
|
|
||||||
# apply dataset operations
|
|
||||||
data1 = ds.TFRecordDataset(DATA_DIR)
|
|
||||||
ds.config.set_seed(1)
|
|
||||||
try:
|
|
||||||
data1 = data1.shuffle(buffer_size=13)
|
|
||||||
sum([1 for _ in data1])
|
|
||||||
|
|
||||||
except BaseException as e:
|
|
||||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
|
||||||
assert "buffer_size" in str(e)
|
|
||||||
|
|
||||||
|
|
||||||
def test_shuffle_exception_05():
|
def test_shuffle_exception_05():
|
||||||
"""
|
"""
|
||||||
Test shuffle exception: Missing mandatory buffer_size input parameter
|
Test shuffle exception: Missing mandatory buffer_size input parameter
|
||||||
|
@ -229,10 +230,10 @@ if __name__ == '__main__':
|
||||||
test_shuffle_02()
|
test_shuffle_02()
|
||||||
test_shuffle_03()
|
test_shuffle_03()
|
||||||
test_shuffle_04()
|
test_shuffle_04()
|
||||||
|
test_shuffle_05()
|
||||||
test_shuffle_exception_01()
|
test_shuffle_exception_01()
|
||||||
test_shuffle_exception_02()
|
test_shuffle_exception_02()
|
||||||
test_shuffle_exception_03()
|
test_shuffle_exception_03()
|
||||||
test_shuffle_exception_04()
|
|
||||||
test_shuffle_exception_05()
|
test_shuffle_exception_05()
|
||||||
test_shuffle_exception_06()
|
test_shuffle_exception_06()
|
||||||
test_shuffle_exception_07()
|
test_shuffle_exception_07()
|
||||||
|
|
Loading…
Reference in New Issue