forked from mindspore-Ecosystem/mindspore
!23827 MD Python UT: Update num_epochs to appropriate value for iterators instead of using default value
Merge pull request !23827 from hetshah/hs_python_ut
commit 44b1dbf77f
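The change applied throughout this pull request is mechanical: every test iterator that only reads its dataset once now passes num_epochs=1 (or the planned epoch count) to create_dict_iterator / create_tuple_iterator instead of relying on the default (-1 in this MindSpore version), which assumes more epochs will follow. A minimal sketch of the pattern; the NumpySlicesDataset below is illustrative and not taken from the diff:

import numpy as np
import mindspore.dataset as ds

data = ds.NumpySlicesDataset(np.arange(6).reshape(3, 2), column_names=["col"], shuffle=False)

# Old style: the default num_epochs (-1) never tells the iterator it can shut down.
# for item in data.create_dict_iterator(output_numpy=True):
#     ...

# New style: state the intended number of passes explicitly.
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(item["col"])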
@@ -56,7 +56,7 @@ def test_func_allpass_biquad_pipeline():
# Filtered waveform by allpassbiquad
dataset = dataset.map(input_columns=["channel"], operations=allpass_biquad_op, num_parallel_workers=8)
i = 0
-for item in dataset.create_dict_iterator(output_numpy=True):
+for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :], item['channel'], 0.0001, 0.0001)
i += 1

@@ -57,7 +57,7 @@ def test_func_angle_002():
dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
angle_op = a_c_trans.Angle()
dataset = dataset.map(operations=angle_op, input_columns=["col1"])
-for item1, item2 in zip(dataset.create_dict_iterator(output_numpy=True), expected):
+for item1, item2 in zip(dataset.create_dict_iterator(num_epochs=1, output_numpy=True), expected):
count_unequal_element(item2, item1['col1'], 0.0001, 0.0001)

@@ -74,7 +74,7 @@ def test_func_angle_003():
dataset = dataset.map(operations=angle_op, input_columns=["col1"])
num_itr = 0
with pytest.raises(RuntimeError, match="input tensor type should be int, float or double"):
-for _ in dataset.create_dict_iterator(output_numpy=True):
+for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
num_itr += 1

@@ -59,7 +59,7 @@ def test_func_band_biquad_pipeline():
dataset = dataset.map(
input_columns=["channel"], operations=band_biquad_op, num_parallel_workers=8)
i = 0
-for item in dataset.create_dict_iterator(output_numpy=True):
+for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item['channel'], 0.0001, 0.0001)
i += 1

@@ -58,7 +58,7 @@ def test_func_bandpass_biquad_pipeline():
# Filtered waveform by bandpassbiquad
dataset = dataset.map(input_columns=["channel"], operations=bandpass_biquad_op, num_parallel_workers=8)
i = 0
-for item in dataset.create_dict_iterator(output_numpy=True):
+for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :], item['channel'], 0.0001, 0.0001)
i += 1

@@ -61,7 +61,7 @@ def test_func_bandreject_biquad_pipeline():
dataset = dataset.map(
input_columns=["channel"], operations=bandreject_biquad_op, num_parallel_workers=8)
i = 0
-for item in dataset.create_dict_iterator(output_numpy=True):
+for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item['channel'], 0.0001, 0.0001)
i += 1

@@ -59,7 +59,7 @@ def test_func_bass_biquad_pipeline():
dataset = dataset.map(
input_columns=["channel"], operations=bass_biquad_op, num_parallel_workers=8)
i = 0
-for item in dataset.create_dict_iterator(output_numpy=True):
+for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item['channel'], 0.0001, 0.0001)
i += 1

@@ -56,7 +56,7 @@ def test_func_biquad_pipeline():
# Filtered waveform by biquad
dataset = dataset.map(input_columns=["audio"], operations=biquad_op, num_parallel_workers=8)
i = 0
-for item in dataset.create_dict_iterator(output_numpy=True):
+for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item['audio'], 0.0001, 0.0001)
i += 1
@@ -168,7 +168,7 @@ def test_cache_map_basic4():
data = data.repeat(4)

num_iter = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -283,7 +283,7 @@ def test_cache_map_failure2():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in dsz.create_dict_iterator():
+for _ in dsz.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)

@@ -322,7 +322,7 @@ def test_cache_map_failure3():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)

@@ -364,7 +364,7 @@ def test_cache_map_failure4():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)

@@ -405,7 +405,7 @@ def test_cache_map_failure5():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

@@ -446,7 +446,7 @@ def test_cache_map_failure7():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "There is currently no support for GeneratorOp under cache" in str(e.value)

@@ -524,7 +524,7 @@ def test_cache_map_failure9():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)

@@ -566,7 +566,7 @@ def test_cache_map_failure10():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value)

@@ -597,7 +597,7 @@ def test_cache_map_failure11():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "Unexpected error. Server is not set up with spill support" in str(e.value)

@@ -646,13 +646,13 @@ def test_cache_map_split1():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds2.create_dict_iterator():
+for _ in ds2.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
logger.info('test_cache_split1 Ended.\n')

@@ -694,12 +694,12 @@ def test_cache_map_split2():
ds2 = ds2.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 12

num_iter = 0
-for _ in ds2.create_dict_iterator():
+for _ in ds2.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 24
logger.info('test_cache_split2 Ended.\n')

@@ -805,13 +805,13 @@ def test_cache_map_running_twice1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8

@@ -848,7 +848,7 @@ def test_cache_map_running_twice2():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -885,7 +885,7 @@ def test_cache_map_extra_small_size1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -922,7 +922,7 @@ def test_cache_map_extra_small_size2():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -960,7 +960,7 @@ def test_cache_map_no_image():

with pytest.raises(RuntimeError):
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

assert num_iter == 0

@@ -996,7 +996,7 @@ def test_cache_map_parallel_pipeline1(shard):
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1033,7 +1033,7 @@ def test_cache_map_parallel_pipeline2(shard):
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1070,7 +1070,7 @@ def test_cache_map_parallel_workers():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1107,7 +1107,7 @@ def test_cache_map_server_workers_1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1144,7 +1144,7 @@ def test_cache_map_server_workers_100():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1181,7 +1181,7 @@ def test_cache_map_num_connections_1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1218,7 +1218,7 @@ def test_cache_map_num_connections_100():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1255,7 +1255,7 @@ def test_cache_map_prefetch_size_1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1292,7 +1292,7 @@ def test_cache_map_prefetch_size_100():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1404,7 +1404,7 @@ def test_cache_map_epoch_ctrl2():

num_epoch = 5
# iter1 will always assume there is a next epoch and never shutdown
-iter1 = ds1.create_dict_iterator()
+iter1 = ds1.create_dict_iterator(num_epochs=-1)

epoch_count = 0
for _ in range(num_epoch):

@@ -1451,7 +1451,7 @@ def test_cache_map_epoch_ctrl3():

num_epoch = 5
# iter1 will always assume there is a next epoch and never shutdown
-iter1 = ds1.create_dict_iterator()
+iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

epoch_count = 0
for _ in range(num_epoch):
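The epoch-control hunks above make the intent explicit in both directions: an iterator that is deliberately kept open gets num_epochs=-1, while one that will be read a known number of times gets num_epochs=num_epoch. A hedged sketch of the multi-epoch pattern those tests exercise (dataset contents are illustrative, not taken from the diff):

import numpy as np
import mindspore.dataset as ds

ds1 = ds.NumpySlicesDataset(np.arange(4), column_names=["col"], shuffle=False)
num_epoch = 3
# The iterator is told up front how many passes to expect, so it can
# release pipeline resources after the last one.
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
for _ in range(num_epoch):
    rows = sum(1 for _ in iter1)
    assert rows == 4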
@@ -2097,7 +2097,7 @@ def test_cache_map_python_sampler1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8

@@ -2133,7 +2133,7 @@ def test_cache_map_python_sampler2():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8

@@ -2200,7 +2200,7 @@ def test_cache_map_interrupt_and_rerun():
some_cache = ds.DatasetCache(session_id=session_id, size=0)

ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
-iter1 = ds1.create_dict_iterator()
+iter1 = ds1.create_dict_iterator(num_epochs=-1)

num_iter = 0
with pytest.raises(AttributeError) as e:

@@ -2252,7 +2252,7 @@ def test_cache_map_dataset_size1():
assert dataset_size == 2

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -2289,7 +2289,7 @@ def test_cache_map_dataset_size2():
assert dataset_size == 2

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))
@@ -629,13 +629,13 @@ def test_cache_nomap_running_twice1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12

@@ -672,7 +672,7 @@ def test_cache_nomap_running_twice2():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -708,7 +708,7 @@ def test_cache_nomap_extra_small_size1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -886,7 +886,7 @@ def test_cache_nomap_server_workers_1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -923,7 +923,7 @@ def test_cache_nomap_server_workers_100():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -960,7 +960,7 @@ def test_cache_nomap_num_connections_1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -997,7 +997,7 @@ def test_cache_nomap_num_connections_100():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1034,7 +1034,7 @@ def test_cache_nomap_prefetch_size_1():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1071,7 +1071,7 @@ def test_cache_nomap_prefetch_size_100():
ds1 = ds1.repeat(4)

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -1147,7 +1147,7 @@ def test_cache_nomap_session_destroy():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "Unexpected error" in str(e.value)

@@ -1185,7 +1185,7 @@ def test_cache_nomap_server_stop():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "Network error. Cache server with port 50052 is unreachable. Make sure the server is running." in \
str(e.value)

@@ -1218,7 +1218,7 @@ def test_cache_nomap_interrupt_and_rerun():

# User-created sampler here
ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache)
-iter1 = ds1.create_dict_iterator()
+iter1 = ds1.create_dict_iterator(num_epochs=-1)

num_iter = 0
with pytest.raises(AttributeError) as e:
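In the interrupt-and-rerun tests, -1 rather than 1 appears to be the intended value: the first pass is abandoned partway through, so the iterator should not assume a single clean epoch. A rough sketch of that usage (illustrative data, not from the diff):

import numpy as np
import mindspore.dataset as ds

data = ds.NumpySlicesDataset(np.arange(4), column_names=["col"], shuffle=False)
# num_epochs=-1 leaves the iterator open-ended, so abandoning an epoch
# midway and coming back later is allowed.
it = data.create_dict_iterator(num_epochs=-1, output_numpy=True)
for row in it:
    break  # read part of an epoch, then walk away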
@@ -1312,7 +1312,7 @@ def test_cache_nomap_epoch_ctrl2():

num_epoch = 5
# iter1 will always assume there is a next epoch and never shutdown
-iter1 = ds1.create_dict_iterator()
+iter1 = ds1.create_dict_iterator(num_epochs=-1)

epoch_count = 0
for _ in range(num_epoch):

@@ -1359,7 +1359,7 @@ def test_cache_nomap_epoch_ctrl3():

num_epoch = 5
# iter1 will always assume there is a next epoch and never shutdown
-iter1 = ds1.create_dict_iterator()
+iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

epoch_count = 0
for _ in range(num_epoch):

@@ -2052,7 +2052,7 @@ def test_cache_nomap_failure2():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in dsz.create_dict_iterator():
+for _ in dsz.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)

@@ -2091,7 +2091,7 @@ def test_cache_nomap_failure3():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)

@@ -2133,7 +2133,7 @@ def test_cache_nomap_failure4():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)

@@ -2173,7 +2173,7 @@ def test_cache_nomap_failure5():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

@@ -2217,7 +2217,7 @@ def test_cache_nomap_pyfunc_lambda():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds2.create_dict_iterator():
+for _ in ds2.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
logger.info("test_cache_nomap_pyfunc_lambda Ended.\n")

@@ -2259,7 +2259,7 @@ def test_cache_nomap_pyfunc_builtin():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds2.create_dict_iterator():
+for _ in ds2.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
logger.info("test_cache_nomap_pyfunc_builtin Ended.\n")

@@ -2308,7 +2308,7 @@ def test_cache_nomap_pyfunc_function():

with pytest.raises(RuntimeError) as e:
num_iter = 0
-for _ in ds2.create_dict_iterator():
+for _ in ds2.create_dict_iterator(num_epochs=1):
num_iter += 1
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
logger.info("test_cache_nomap_pyfunc_function Ended.\n")

@@ -2341,7 +2341,7 @@ def test_cache_nomap_all_rows_cached():
num_total_rows = 271
# User-created sampler here
ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache)
-iter1 = ds1.create_dict_iterator()
+iter1 = ds1.create_dict_iterator(num_epochs=1)

num_iter = 0
for _ in iter1:

@@ -2380,7 +2380,7 @@ def test_cache_nomap_dataset_size1():
assert dataset_size == 2

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))

@@ -2417,7 +2417,7 @@ def test_cache_nomap_dataset_size2():
assert dataset_size == 2

num_iter = 0
-for _ in ds1.create_dict_iterator():
+for _ in ds1.create_dict_iterator(num_epochs=1):
num_iter += 1

logger.info("Number of data in ds1: {} ".format(num_iter))
@@ -39,7 +39,7 @@ def test_compose():
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
data = data.map(input_columns=["col"], operations=op_list)
res = []
-for i in data.create_dict_iterator(output_numpy=True):
+for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
res.append(i["col"].tolist())
return res
except (TypeError, ValueError) as e:

@@ -114,7 +114,7 @@ def test_lambdas():
data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
column_order=output_cols)
res = []
-for i in data.create_dict_iterator(output_numpy=True):
+for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for col_name in output_cols:
res.append(i[col_name].tolist())
return res

@@ -141,7 +141,7 @@ def test_c_py_compose_transforms_module():
data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
column_order=output_cols)
res = []
-for i in data.create_dict_iterator(output_numpy=True):
+for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for col_name in output_cols:
res.append(i[col_name].tolist())
return res

@@ -229,7 +229,7 @@ def test_py_transforms_with_c_vision():
data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
data = data.map(operations=op_list)
res = []
-for i in data.create_dict_iterator(output_numpy=True):
+for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for col_name in output_cols:
res.append(i[col_name].tolist())
return res

@@ -323,7 +323,7 @@ def test_compose_with_custom_function():
#

res = []
-for i in data.create_dict_iterator(output_numpy=True):
+for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
res.append(i["col0"].tolist())
assert res == [[[3, 6], [9, 36]]]
@@ -57,7 +57,7 @@ def test_concat_01():
data3 = data1 + data2

# Here i refers to index, d refers to data element
-for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data3.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert i == t[0][0]

@@ -76,7 +76,7 @@ def test_concat_02():
data3 = data1.concat(data2)

# Here i refers to index, d refers to data element
-for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data3.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert i == t[0][0]

@@ -154,7 +154,7 @@ def test_concat_06():
dataset = data1 + data2 + data3

# Here i refers to index, d refers to data element
-for i, d in enumerate(dataset.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(dataset.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert i == t[0][0]

@@ -175,7 +175,7 @@ def test_concat_07():
data4 = data1 + dataset

# Here i refers to index, d refers to data element
-for i, d in enumerate(data4.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data4.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert i == t[0][0]

@@ -195,7 +195,7 @@ def test_concat_08():
data3 = data3.repeat(2)

# Here i refers to index, d refers to data element
-for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data3.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert i % 10 == t[0][0]

@@ -217,7 +217,7 @@ def test_concat_09():

res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]
# Here i refers to index, d refers to data element
-for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data3.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert res[i] == t[0][0]

@@ -238,7 +238,7 @@ def test_concat_10():

res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# Here i refers to index, d refers to data element
-for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data3.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert res[i] == t[0][0]

@@ -261,7 +261,7 @@ def test_concat_11():
res = [0, 10, 15, 20]

# Here i refers to index, d refers to data element
-for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data3.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert res[i] == t[0][0]

@@ -285,7 +285,7 @@ def test_concat_12():
data3 = data3.shuffle(buffer_size=10)

# Here i refers to index, d refers to data element
-for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data3.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert res[i] == t[0][0]

@@ -313,7 +313,7 @@ def test_concat_13():
data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))

# Here i refers to index, d refers to data element
-for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
+for i, d in enumerate(data3.create_tuple_iterator(num_epochs=1, output_numpy=True)):
t = d
logger.info("data: %i", t[0][0])
assert res[i] == t[0][0]

@@ -341,11 +341,11 @@ def test_concat_14():
data3 = data1 + data2

expected, output = [], []
-for d in data1.create_tuple_iterator(output_numpy=True):
+for d in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
expected.append(d[0])
-for d in data2.create_tuple_iterator(output_numpy=True):
+for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True):
expected.append(d[0])
-for d in data3.create_tuple_iterator(output_numpy=True):
+for d in data3.create_tuple_iterator(num_epochs=1, output_numpy=True):
output.append(d[0])

assert len(expected) == len(output)
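The concat hunks show that the same argument applies to tuple iterators: create_tuple_iterator accepts num_epochs just as create_dict_iterator does. A small sketch (the array below is illustrative):

import numpy as np
import mindspore.dataset as ds

data = ds.NumpySlicesDataset(np.array([[1], [2], [3]]), column_names=["col"], shuffle=False)
for i, d in enumerate(data.create_tuple_iterator(num_epochs=1, output_numpy=True)):
    # each row comes back as a tuple of column arrays
    assert d[0][0] == i + 1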
@@ -34,7 +34,7 @@ def test_concatenate_op_all():
data = data.map(operations=concatenate_op, input_columns=["col"])
expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
11., 12.])
-for data_row in data.create_tuple_iterator(output_numpy=True):
+for data_row in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(data_row[0], expected)

@@ -46,7 +46,7 @@ def test_concatenate_op_none():
concatenate_op = data_trans.Concatenate()

data = data.map(operations=concatenate_op, input_columns=["col"])
-for data_row in data.create_tuple_iterator(output_numpy=True):
+for data_row in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(data_row[0], np.array([5., 6., 7., 8.], dtype=np.float))

@@ -61,7 +61,7 @@ def test_concatenate_op_string():

data = data.map(operations=concatenate_op, input_columns=["col"])
expected = np.array(["dw", "df", "ss", "ad", "dwsdf", "df"], dtype='S')
-for data_row in data.create_tuple_iterator(output_numpy=True):
+for data_row in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(data_row[0], expected)

@@ -77,7 +77,7 @@ def test_concatenate_op_multi_input_string():
data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"],
output_columns=["out1"])
expected = np.array(["dw", "df", "1", "2", "d", "3", "4", "e", "dwsdf", "df"], dtype='S')
-for data_row in data.create_tuple_iterator(output_numpy=True):
+for data_row in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(data_row[0], expected)

@@ -92,7 +92,7 @@ def test_concatenate_op_multi_input_numeric():
data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"],
output_columns=["out1"])
expected = np.array([3, 5, 1, 2, 3, 4])
-for data_row in data.create_tuple_iterator(output_numpy=True):
+for data_row in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(data_row[0], expected)

@@ -158,7 +158,7 @@ def test_concatenate_op_negative_axis():
data = data.map(operations=concatenate_op, input_columns=["col"])
expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
11., 12.])
-for data_row in data.create_tuple_iterator(output_numpy=True):
+for data_row in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(data_row[0], expected)
@@ -60,7 +60,7 @@ def test_func_contrast_pipeline():
# Filtered waveform by contrast
dataset = dataset.map(input_columns=["audio"], operations=contrast_op, num_parallel_workers=8)
i = 0
-for item in dataset.create_dict_iterator(output_numpy=True):
+for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :], item['audio'], 0.0001, 0.0001)
i += 1
@@ -77,7 +77,7 @@ def test_numpy_slices_list_append():

ds = de.NumpySlicesDataset(res, column_names=["col1"], shuffle=False)

-for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
+for i, data in enumerate(ds.create_tuple_iterator(num_epochs=1, output_numpy=True)):
assert np.equal(data, res[i]).all()

@@ -99,7 +99,7 @@ def test_numpy_slices_tuple_1():
np_data = [([1, 2], [3, 4]), ([11, 12], [13, 14]), ([21, 22], [23, 24])]
ds = de.NumpySlicesDataset(np_data, shuffle=False)

-for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
+for i, data in enumerate(ds.create_tuple_iterator(num_epochs=1, output_numpy=True)):
assert np.equal(data, np_data[i]).all()

assert sum([1 for _ in ds]) == 3

@@ -112,7 +112,7 @@ def test_numpy_slices_tuple_2():
expected = [[1, 3, 5], [2, 4, 6]]
ds = de.NumpySlicesDataset(np_data, shuffle=False)

-for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
+for i, data in enumerate(ds.create_tuple_iterator(num_epochs=1, output_numpy=True)):
assert np.equal(data, expected[i]).all()

assert sum([1 for _ in ds]) == 2

@@ -156,7 +156,7 @@ def test_numpy_slices_csv_dict():

ds = de.NumpySlicesDataset(dict(df), shuffle=False)

-for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
+for i, data in enumerate(ds.create_tuple_iterator(num_epochs=1, output_numpy=True)):
assert np.equal(data, res[i]).all()
@@ -142,7 +142,7 @@ def test_celeba_dataset_exception_file_path():
try:
data = ds.CelebADataset(DATA_DIR, shuffle=False)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
pass
assert False
except RuntimeError as e:

@@ -152,7 +152,7 @@ def test_celeba_dataset_exception_file_path():
data = ds.CelebADataset(DATA_DIR, shuffle=False)
data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
pass
assert False
except RuntimeError as e:

@@ -161,7 +161,7 @@ def test_celeba_dataset_exception_file_path():
try:
data = ds.CelebADataset(DATA_DIR, shuffle=False)
data = data.map(operations=exception_func, input_columns=["attr"], num_parallel_workers=1)
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
pass
assert False
except RuntimeError as e:

@@ -175,7 +175,7 @@ def test_celeba_sampler_exception():
logger.info("Test CelebA with bad sampler input")
try:
data = ds.CelebADataset(DATA_DIR, sampler="")
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
pass
assert False
except TypeError as e:
@@ -454,7 +454,7 @@ def test_cifar_exception_file_path():
data = ds.Cifar10Dataset(DATA_DIR_10)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:

@@ -464,7 +464,7 @@ def test_cifar_exception_file_path():
data = ds.Cifar10Dataset(DATA_DIR_10)
data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:

@@ -474,7 +474,7 @@ def test_cifar_exception_file_path():
data = ds.Cifar100Dataset(DATA_DIR_100)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:

@@ -484,7 +484,7 @@ def test_cifar_exception_file_path():
data = ds.Cifar100Dataset(DATA_DIR_100)
data = data.map(operations=exception_func, input_columns=["coarse_label"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:

@@ -494,7 +494,7 @@ def test_cifar_exception_file_path():
data = ds.Cifar100Dataset(DATA_DIR_100)
data = data.map(operations=exception_func, input_columns=["fine_label"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:
@@ -222,7 +222,7 @@ def test_cityscapes_exception():
data = ds.CityscapesDataset(DATASET_DIR, usage=usage, quality_mode=quality_mode, task=task)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:

@@ -232,7 +232,7 @@ def test_cityscapes_exception():
data = ds.CityscapesDataset(DATASET_DIR, usage=usage, quality_mode=quality_mode, task=task)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:
@@ -378,7 +378,7 @@ def test_clue_exception_file_path():
try:
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
pass
assert False
except RuntimeError as e:

@@ -387,7 +387,7 @@ def test_clue_exception_file_path():
try:
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
data = data.map(operations=exception_func, input_columns=["sentence1"], num_parallel_workers=1)
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
pass
assert False
except RuntimeError as e:

@@ -396,7 +396,7 @@ def test_clue_exception_file_path():
try:
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
data = data.map(operations=exception_func, input_columns=["sentence2"], num_parallel_workers=1)
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
pass
assert False
except RuntimeError as e:
@@ -179,22 +179,22 @@ def test_coco_panoptic():
def test_coco_meta_column():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
decode=True, shuffle=False, extra_metadata=True)
-for item in data1.create_tuple_iterator():
+for item in data1.create_tuple_iterator(num_epochs=1):
assert len(item) == 4

data2 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff",
decode=True, shuffle=False, extra_metadata=True)
-for item in data2.create_tuple_iterator():
+for item in data2.create_tuple_iterator(num_epochs=1):
assert len(item) == 3

data3 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint",
decode=True, shuffle=False, extra_metadata=True)
-for item in data3.create_tuple_iterator():
+for item in data3.create_tuple_iterator(num_epochs=1):
assert len(item) == 3

data4 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic",
decode=True, shuffle=False, extra_metadata=True)
-for item in data4.create_tuple_iterator():
+for item in data4.create_tuple_iterator(num_epochs=1):
assert len(item) == 5

@@ -204,7 +204,7 @@ def test_coco_detection_classindex():
assert class_index == {'person': [1], 'bicycle': [2], 'car': [3], 'cat': [4], 'dog': [5], 'monkey': [6],
'bag': [7], 'orange': [8]}
num_iter = 0
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 6

@@ -214,7 +214,7 @@ def test_coco_panootic_classindex():
class_index = data1.get_class_indexing()
assert class_index == {'person': [1, 1], 'bicycle': [2, 1], 'car': [3, 1]}
num_iter = 0
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 2

@@ -252,7 +252,7 @@ def test_coco_case_2():
data1 = data1.map(operations=resize_op, input_columns=["image"])
data1 = data1.repeat(4)
num_iter = 0
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 24

@@ -264,7 +264,7 @@ def test_coco_case_3():
data1 = data1.map(operations=resize_op, input_columns=["image"])
data1 = data1.repeat(4)
num_iter = 0
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert num_iter == 24

@@ -272,7 +272,7 @@ def test_coco_case_3():
def test_coco_case_exception():
try:
data1 = ds.CocoDataset("path_not_exist/", annotation_file=ANNOTATION_FILE, task="Detection")
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except ValueError as e:

@@ -280,7 +280,7 @@ def test_coco_case_exception():

try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file="./file_not_exist", task="Detection")
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except ValueError as e:

@@ -288,7 +288,7 @@ def test_coco_case_exception():

try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Invalid task")
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except ValueError as e:

@@ -296,7 +296,7 @@ def test_coco_case_exception():

try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=LACKOFIMAGE_FILE, task="Detection")
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -304,7 +304,7 @@ def test_coco_case_exception():

try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_CATEGORY_ID_FILE, task="Detection")
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -312,7 +312,7 @@ def test_coco_case_exception():

try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection")
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -321,7 +321,7 @@ def test_coco_case_exception():
try:
sampler = ds.PKSampler(3)
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection", sampler=sampler)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except ValueError as e:

@@ -333,7 +333,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -343,7 +343,7 @@ def test_coco_case_exception():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -352,7 +352,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
data1 = data1.map(operations=exception_func, input_columns=["bbox"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -361,7 +361,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
data1 = data1.map(operations=exception_func, input_columns=["category_id"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -370,7 +370,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -380,7 +380,7 @@ def test_coco_case_exception():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -389,7 +389,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
data1 = data1.map(operations=exception_func, input_columns=["segmentation"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -398,7 +398,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
data1 = data1.map(operations=exception_func, input_columns=["iscrowd"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -407,7 +407,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -417,7 +417,7 @@ def test_coco_case_exception():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -426,7 +426,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
data1 = data1.map(operations=exception_func, input_columns=["keypoints"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -435,7 +435,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
data1 = data1.map(operations=exception_func, input_columns=["num_keypoints"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -444,7 +444,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -454,7 +454,7 @@ def test_coco_case_exception():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -463,7 +463,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=exception_func, input_columns=["bbox"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -472,7 +472,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=exception_func, input_columns=["category_id"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -481,7 +481,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=exception_func, input_columns=["area"], num_parallel_workers=1)
-for _ in data1.create_dict_iterator(output_numpy=True):
+for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -211,7 +211,7 @@ def test_div2k_exception():
data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
data = data.map(operations=exception_func, input_columns=["hr_image"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:

@@ -221,7 +221,7 @@ def test_div2k_exception():
data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
data = data.map(operations=exception_func, input_columns=["hr_image"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:
@@ -130,7 +130,7 @@ def test_flickr30k_dataset_exception():
data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:

@@ -140,7 +140,7 @@ def test_flickr30k_dataset_exception():
data = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
data = data.map(operations=exception_func, input_columns=["annotation"], num_parallel_workers=1)
num_rows = 0
-for _ in data.create_dict_iterator():
+for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert False
except RuntimeError as e:
@ -898,7 +898,7 @@ def test_func_generator_dataset_005():
|
|||
column_names = ["col1", "col2"]
|
||||
dataset = ds.GeneratorDataset(MyData(result), column_names)
|
||||
i = 0
|
||||
for data in dataset.create_dict_iterator(output_numpy=True):
|
||||
for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
assert "col1" in str(data.keys())
|
||||
assert (data["col1"] == result[0]).all()
|
||||
assert (data["col2"] == result[1]).all()
|
||||
|
|
|
@ -64,14 +64,14 @@ def test_imagenet_tf_file_dataset_size():
|
|||
assert ds_shard_3_0.get_dataset_size() == 4
|
||||
|
||||
count = 0
|
||||
for _ in ds_shard_3_0.create_dict_iterator():
|
||||
for _ in ds_shard_3_0.create_dict_iterator(num_epochs=1):
|
||||
count += 1
|
||||
assert ds_shard_3_0.get_dataset_size() == count
|
||||
|
||||
# shard_equal_rows is set to False therefore, get_dataset_size must return count
|
||||
ds_shard_4_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=4, shard_id=0)
|
||||
count = 0
|
||||
for _ in ds_shard_4_0.create_dict_iterator():
|
||||
for _ in ds_shard_4_0.create_dict_iterator(num_epochs=1):
|
||||
count += 1
|
||||
assert ds_shard_4_0.get_dataset_size() == count
|
||||
|
||||
|
@ -254,7 +254,7 @@ def test_distributed_get_dataset_size():
|
|||
assert dataset1.get_dataset_size() == 2000
|
||||
|
||||
count1 = 0
|
||||
for _ in dataset1.create_dict_iterator():
|
||||
for _ in dataset1.create_dict_iterator(num_epochs=1):
|
||||
count1 += 1
|
||||
assert count1 == 2000
|
||||
|
||||
|
@ -263,7 +263,7 @@ def test_distributed_get_dataset_size():
|
|||
assert dataset2.get_dataset_size() == 2500
|
||||
|
||||
count2 = 0
|
||||
for _ in dataset2.create_dict_iterator():
|
||||
for _ in dataset2.create_dict_iterator(num_epochs=1):
|
||||
count2 += 1
|
||||
assert count2 == 2500
|
||||
|
||||
|
|
|
@@ -90,7 +90,7 @@ def test_voc_meta_column():
# scenario one: output 2 columns if without rename meta column
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True, shuffle=False, extra_metadata=True)
num = 0
for item in data1.create_tuple_iterator():
for item in data1.create_tuple_iterator(num_epochs=1):
assert len(item) == 2
num += 1

@@ -102,7 +102,7 @@ def test_voc_meta_column():
data2 = data2.map(operations=pyfunc1, input_columns=["image", "target"])
data2 = data2.rename("_meta-filename", "filename")
num = 0
for item in data2.create_tuple_iterator(output_numpy=True):
for item in data2.create_tuple_iterator(num_epochs=1, output_numpy=True):
assert text.to_str(item[2]) == IMAGE_ID[num]
num += 1

@@ -115,7 +115,7 @@ def test_voc_meta_column():
column_order=["_meta-filename", "img1", "img2", "label"])
data3 = data3.rename("_meta-filename", "filename")
num = 0
for item in data3.create_tuple_iterator(output_numpy=True):
for item in data3.create_tuple_iterator(num_epochs=1, output_numpy=True):
assert text.to_str(item[0]) == IMAGE_ID[num]
num += 1

@@ -128,7 +128,7 @@ def test_voc_meta_column():
column_order=["_meta-filename", "img1"])
data4 = data4.rename("_meta-filename", "filename")
num = 0
for item in data4.create_tuple_iterator(output_numpy=True):
for item in data4.create_tuple_iterator(num_epochs=1, output_numpy=True):
assert text.to_str(item[0]) == IMAGE_ID[num]
num += 1

@@ -248,7 +248,7 @@ def test_voc_exception():
try:
data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -258,7 +258,7 @@ def test_voc_exception():
data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False)
data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -267,7 +267,7 @@ def test_voc_exception():
try:
data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False)
data = data.map(operations=exception_func, input_columns=["bbox"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -276,7 +276,7 @@ def test_voc_exception():
try:
data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False)
data = data.map(operations=exception_func, input_columns=["difficult"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -285,7 +285,7 @@ def test_voc_exception():
try:
data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False)
data = data.map(operations=exception_func, input_columns=["truncate"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -294,7 +294,7 @@ def test_voc_exception():
try:
data = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -304,7 +304,7 @@ def test_voc_exception():
data = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False)
data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -313,7 +313,7 @@ def test_voc_exception():
try:
data = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False)
data = data.map(operations=exception_func, input_columns=["target"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:
@@ -323,7 +323,7 @@ def test_voc_exception():
data = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False)
data = data.map(operations=vision.Decode(), input_columns=["target"], num_parallel_workers=1)
data = data.map(operations=exception_func, input_columns=["target"], num_parallel_workers=1)
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert False
except RuntimeError as e:

@@ -50,7 +50,7 @@ def test_func_dc_shift_pipeline():
dataset = ds.NumpySlicesDataset(arr, column_names=["col1"], shuffle=False)
dcshift_op = a_c_trans.DCShift(0.8, 0.03)
dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
for item1, item2 in zip(dataset.create_dict_iterator(output_numpy=True), expected):
for item1, item2 in zip(dataset.create_dict_iterator(num_epochs=1, output_numpy=True), expected):
count_unequal_element(item2, item1['col1'], 0.0001, 0.0001)

@@ -66,7 +66,7 @@ def test_func_dc_shift_pipeline_error():
with pytest.raises(ValueError, match=r"Input shift is not within the required interval of \[-2.0, 2.0\]."):
dcshift_op = a_c_trans.DCShift(2.5, 0.03)
dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
for _ in dataset.create_dict_iterator(output_numpy=True):
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
num_itr += 1

@@ -56,7 +56,7 @@ def test_func_deemph_biquad_pipeline():
# Filtered waveform by deemphbiquad
dataset = dataset.map(input_columns=["audio"], operations=deemph_biquad_op, num_parallel_workers=8)
i = 0
for data in dataset.create_dict_iterator(output_numpy=True):
for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :], data['audio'], 0.0001, 0.0001)
i += 1

@@ -50,8 +50,7 @@ def test_cifar10():
data1 = data1.repeat(num_repeat)
data1 = data1.batch(batch_size, True)
num_epoch = 5
# iter1 will always assume there is a next epoch and never shutdown.
iter1 = data1.create_tuple_iterator()
iter1 = data1.create_tuple_iterator(num_epochs=num_epoch)
epoch_count = 0
sample_count = 0
for _ in range(num_epoch):
@@ -86,7 +85,7 @@ def test_decode_op():

num_epoch = 5
# iter1 will always assume there is a next epoch and never shutdown.
iter1 = data1.create_dict_iterator(output_numpy=True)
iter1 = data1.create_dict_iterator(num_epochs=-1, output_numpy=True)
# iter 2 will stop and shutdown pipeline after num_epoch
iter2 = data2.create_dict_iterator(num_epoch, output_numpy=True)
for _ in range(num_epoch):
@@ -169,7 +168,9 @@ def test_generator_dict_2():

# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_dict_iterator()

# iter1 will always assume there is a next epoch and never shutdown
iter1 = data1.create_dict_iterator(num_epochs=-1)
for _ in range(10):
i = 0
for item in iter1:  # each data is a dictionary
@@ -192,7 +193,9 @@ def test_generator_dict_3():

# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_dict_iterator()

# iter1 will always assume there is a next epoch and never shutdown
iter1 = data1.create_dict_iterator(num_epochs=-1)
for _ in range(10):
i = 0
for item in iter1:  # each data is a dictionary
@@ -200,7 +203,7 @@ def test_generator_dict_3():
np.testing.assert_array_equal(item["data"].asnumpy(), golden)
i = i + 1
assert i == 64
# optional

iter1.stop()
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
@@ -360,7 +363,8 @@ def test_generator_tuple_2():

# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_tuple_iterator(output_numpy=True)
# iter1 will always assume there is a next epoch and never shutdown
iter1 = data1.create_tuple_iterator(num_epochs=-1, output_numpy=True)
for _ in range(10):
i = 0
for item in iter1:  # each data is a dictionary
@@ -383,7 +387,8 @@ def test_generator_tuple_3():

# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
iter1 = data1.create_tuple_iterator(output_numpy=True)
# iter1 will always assume there is a next epoch and never shutdown
iter1 = data1.create_tuple_iterator(num_epochs=-1, output_numpy=True)
for _ in range(10):
i = 0
for item in iter1:  # each data is a dictionary
@@ -391,7 +396,7 @@ def test_generator_tuple_3():
np.testing.assert_array_equal(item[0], golden)
i = i + 1
assert i == 64
# optional

iter1.stop()
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
@@ -533,7 +538,8 @@ def test_generator_tuple_repeat_repeat_2():
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
data1 = data1.repeat(3)
iter1 = data1.create_tuple_iterator(output_numpy=True)
# iter1 will always assume there is a next epoch and never shutdown
iter1 = data1.create_tuple_iterator(num_epochs=-1, output_numpy=True)
for _ in range(10):
i = 0
for item in iter1:  # each data is a dictionary
@@ -541,7 +547,7 @@ def test_generator_tuple_repeat_repeat_2():
np.testing.assert_array_equal(item[0], golden)
i = i + 1
assert i == 64 * 2 * 3
# optional

iter1.stop()
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
@@ -559,7 +565,7 @@ def test_generator_tuple_repeat_repeat_3():
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
data1 = data1.repeat(3)
iter1 = data1.create_tuple_iterator(output_numpy=True)
iter1 = data1.create_tuple_iterator(num_epochs=15, output_numpy=True)
for _ in range(10):
i = 0
for item in iter1:  # each data is a dictionary
@@ -658,7 +664,7 @@ def test_generator_tuple_infinite_repeat_repeat_4():
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat()
data1 = data1.repeat()
iter1 = data1.create_tuple_iterator(output_numpy=True)
iter1 = data1.create_tuple_iterator(num_epochs=1, output_numpy=True)

i = 0
for item in iter1:  # each data is a dictionary
@@ -680,7 +686,7 @@ def test_generator_reusedataset():
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
iter1 = data1.create_tuple_iterator(output_numpy=True)
iter1 = data1.create_tuple_iterator(num_epochs=10, output_numpy=True)
for _ in range(10):
i = 0
for item in iter1:  # each data is a dictionary
@@ -690,7 +696,7 @@ def test_generator_reusedataset():
assert i == 64 * 2

data1 = data1.repeat(3)
iter1 = data1.create_tuple_iterator(output_numpy=True)
iter1 = data1.create_tuple_iterator(num_epochs=5, output_numpy=True)
for _ in range(5):
i = 0
for item in iter1:  # each data is a dictionary
@@ -700,7 +706,7 @@ def test_generator_reusedataset():
assert i == 64 * 2 * 3

data1 = data1.batch(2)
iter1 = data1.create_dict_iterator(output_numpy=True)
iter1 = data1.create_dict_iterator(num_epochs=5, output_numpy=True)
for _ in range(5):
i = 0
sample = 0

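The hunks above pin num_epochs to the number of passes each test actually makes, or to -1 for an iterator that is deliberately never shut down. A minimal sketch of the two patterns, assuming a small illustrative NumpySlicesDataset rather than the generator fixtures used in these tests:

import numpy as np
import mindspore.dataset as ds

data = ds.NumpySlicesDataset(np.arange(4), column_names=["col1"], shuffle=False)

# One pass: num_epochs=1 lets the pipeline shut down cleanly after the loop.
for _ in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
    pass

# Several passes: create the iterator once with a matching epoch count and reuse it;
# num_epochs=-1 would instead keep the pipeline alive indefinitely.
num_epochs = 3
iter1 = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
for _ in range(num_epochs):
    for _ in iter1:
        pass
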
@@ -60,7 +60,7 @@ def test_equalizer_biquad_pipeline():
# Filtered waveform by equalizer_biquad
dataset = dataset.map(input_columns=["col1"], operations=equalizer_biquad_op, num_parallel_workers=4)
i = 0
for item in dataset.create_dict_iterator(output_numpy=True):
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item["col1"], 0.0001, 0.0001)
i += 1

@@ -69,7 +69,7 @@ def test_fillop_string():

data = data.map(operations=fill_op, input_columns=["col"])
expected = np.array(['error', 'error'], dtype='S')
for data_row in data.create_tuple_iterator(output_numpy=True):
for data_row in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(data_row[0], expected)

@@ -82,7 +82,7 @@ def test_fillop_bytes():

data = data.map(operations=fill_op, input_columns=["col"])
expected = np.array([b'abc', b'abc', b'abc'], dtype='S')
for data_row in data.create_tuple_iterator(output_numpy=True):
for data_row in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(data_row[0], expected)

@@ -34,7 +34,7 @@ def test_flat_map_1():
data = data.flat_map(flat_map_func)

count = 0
for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
assert isinstance(d[0], np.ndarray)
count += 1
assert count == 52
@@ -60,7 +60,7 @@ def test_flat_map_2():
data = data.flat_map(flat_map_func_2)

count = 0
for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
assert isinstance(d[0], np.ndarray)
count += 1
assert count == 104

@@ -121,9 +121,9 @@ def test_generator_reset_3():
concat1 = branch2 + branch3
concat2 = branch1 + concat1.repeat(3).skip(5).take(15)

itr = concat2.create_dict_iterator(output_numpy=True)

num_epochs = 5
itr = concat2.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)

output = np.array([0])
golden = np.array([0])
expected = np.array([2, 1, 2, 1, 12, 22, 23, 10, 11, 12, 10, 11, 12, 22, 23, 10, 11, 12, 10])
@@ -164,7 +164,7 @@ def test_generator_reset_5():

num_epochs = 2
output = np.array([0])
itr = branch1.create_dict_iterator(output_numpy=True)
itr = branch1.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)

for _ in range(num_epochs):
for item in itr:

@@ -63,7 +63,7 @@ def test_highpass_biquad_pipeline():
dataset = dataset.map(
input_columns=["col1"], operations=highpass_biquad_op, num_parallel_workers=4)
i = 0
for item in dataset.create_dict_iterator(output_numpy=True):
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item["col1"], 0.0001, 0.0001)
i += 1

@@ -127,25 +127,25 @@ def test_iterator_weak_ref():
def test_iterator_exception():
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
try:
_ = data.create_dict_iterator(output_numpy="123")
_ = data.create_dict_iterator(num_epochs=1, output_numpy="123")
assert False
except TypeError as e:
assert "Argument output_numpy with value 123 is not of type" in str(e)

try:
_ = data.create_dict_iterator(output_numpy=123)
_ = data.create_dict_iterator(num_epochs=1, output_numpy=123)
assert False
except TypeError as e:
assert "Argument output_numpy with value 123 is not of type" in str(e)

try:
_ = data.create_tuple_iterator(output_numpy="123")
_ = data.create_tuple_iterator(num_epochs=1, output_numpy="123")
assert False
except TypeError as e:
assert "Argument output_numpy with value 123 is not of type" in str(e)

try:
_ = data.create_tuple_iterator(output_numpy=123)
_ = data.create_tuple_iterator(num_epochs=1, output_numpy=123)
assert False
except TypeError as e:
assert "Argument output_numpy with value 123 is not of type" in str(e)

@@ -61,7 +61,7 @@ def test_func_lfilter_pipeline():
# Filtered waveform by lfilter
dataset = dataset.map(input_columns=["channel"], operations=lfilter_op, num_parallel_workers=8)
i = 0
for data in dataset.create_dict_iterator(output_numpy=True):
for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :], data['channel'], 0.0001, 0.0001)
i += 1

@@ -62,7 +62,7 @@ def test_lowpass_biquad_pipeline():
dataset = dataset.map(
input_columns=["col1"], operations=lowpass_biquad_op, num_parallel_workers=4)
i = 0
for _ in dataset.create_dict_iterator(output_numpy=True):
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :],
_["col1"], 0.0001, 0.0001)
i += 1

@@ -843,7 +843,7 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
assert data_set.get_dataset_size() == 10
for _ in range(5):
num_iter = 0
for data in data_set.create_tuple_iterator(output_numpy=True):
for data in data_set.create_tuple_iterator(num_epochs=1, output_numpy=True):
logger.info("data is {}".format(data))
num_iter += 1
assert num_iter == 10
@@ -871,7 +871,7 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_
assert data_set.get_dataset_size() == 5
for _ in range(5):
num_iter = 0
for data in data_set.create_tuple_iterator(output_numpy=True):
for data in data_set.create_tuple_iterator(num_epochs=1, output_numpy=True):
logger.info("data is {}".format(data))
num_iter += 1
assert num_iter == 5

@@ -121,7 +121,7 @@ def test_one_hot_success():
trans = py_trans.Compose([one_hot_encode])
dataset = dataset.map(operations=trans, input_columns=["label"])

for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
for index, item in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
assert item["label"][index] == 1.0

def test_one_hot_success2():
@@ -146,7 +146,7 @@ def test_one_hot_success2():
trans = py_trans.Compose([one_hot_encode])
dataset = dataset.map(operations=trans, input_columns=["label"])

for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
for index, item in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
logger.info(item)
assert item["label"][0][index] == 1.0

@@ -175,7 +175,7 @@ def test_one_hot_success3():
trans = py_trans.Compose([one_hot_encode])
dataset = dataset.map(operations=trans, input_columns=["label"])

for item in dataset.create_dict_iterator(output_numpy=True):
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info(item)
for i in range(10):
assert item["label"][i][0][i] == 1.0
@@ -203,7 +203,7 @@ def test_one_hot_type_error():
dataset = dataset.map(operations=trans, input_columns=["label"])

try:
for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
for index, item in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
assert item["label"][index] == 1.0
except RuntimeError as e:
assert "the input numpy type should be int" in str(e)

@@ -39,7 +39,7 @@ def test_case_0():
data1 = data1.batch(2)

expected_data = np.array([[[1], [2]], [[3], [0]]])
for i, data_row in enumerate(data1.create_tuple_iterator(output_numpy=True)):
for i, data_row in enumerate(data1.create_tuple_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data_row[0], expected_data[i])

# Restore configuration

@@ -67,7 +67,8 @@ def test_shuffle():
data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
data2 = data2.shuffle(10000)

for d1, d2 in zip(data1.create_tuple_iterator(output_numpy=True), data2.create_tuple_iterator(output_numpy=True)):
for d1, d2 in zip(data1.create_tuple_iterator(num_epochs=1, output_numpy=True),
data2.create_tuple_iterator(num_epochs=1, output_numpy=True)):
for t1, t2 in zip(d1, d2):
np.testing.assert_array_equal(t1, t2)

@@ -77,7 +78,8 @@ def test_shuffle():
data2 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
data2 = data2.shuffle(10000)

for d1, d2 in zip(data1.create_tuple_iterator(output_numpy=True), data2.create_tuple_iterator(output_numpy=True)):
for d1, d2 in zip(data1.create_tuple_iterator(num_epochs=1, output_numpy=True),
data2.create_tuple_iterator(num_epochs=1, output_numpy=True)):
for t1, t2 in zip(d1, d2):
np.testing.assert_array_equal(t1, t2)

@@ -87,7 +89,8 @@ def test_shuffle():
data2 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.FILES)
data2 = data2.shuffle(10000)

for d1, d2 in zip(data1.create_tuple_iterator(output_numpy=True), data2.create_tuple_iterator(output_numpy=True)):
for d1, d2 in zip(data1.create_tuple_iterator(num_epochs=1, output_numpy=True),
data2.create_tuple_iterator(num_epochs=1, output_numpy=True)):
for t1, t2 in zip(d1, d2):
np.testing.assert_array_equal(t1, t2)

@@ -30,7 +30,7 @@ def pad_compare(array, pad_shape, pad_value, res):
data = data.map(operations=ops.PadEnd(pad_shape, pad_value))
else:
data = data.map(operations=ops.PadEnd(pad_shape))
for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(res, d[0])

@@ -191,9 +191,9 @@ def test_profiling_inline_ops_pipeline1():

try:
num_iter = 0
# Note: Do not explicitly set num_epochs argument in create_tuple_iterator() call
# Note: set num_epochs=2 in create_tuple_iterator(), so that EpochCtrl op is added to the pipeline
# Here i refers to index, d refers to data element
for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
for i, d in enumerate(data3.create_tuple_iterator(num_epochs=2, output_numpy=True)):
num_iter += 1
t = d
assert i == t[0][0]
@@ -326,7 +326,7 @@ def test_profiling_basic_pipeline():

try:
num_iter = 0
# Note: If create_tuple_iterator() is called with num_epochs>1, then EpochCtrlOp is added to the pipeline
# Note: If create_dict_iterator() is called with num_epochs>1, then EpochCtrlOp is added to the pipeline
for _ in data1.create_dict_iterator(num_epochs=2):
num_iter += 1

@@ -379,7 +379,7 @@ def test_profiling_cifar10_pipeline():

try:
num_iter = 0
# Note: If create_tuple_iterator() is called with num_epochs=1, then EpochCtrlOp is NOT added to the pipeline
# Note: If create_dict_iterator() is called with num_epochs=1, then EpochCtrlOp is NOT added to the pipeline
for _ in data1.create_dict_iterator(num_epochs=1):
num_iter += 1

@@ -426,7 +426,7 @@ def test_profiling_seq_pipelines_epochctrl3():
try:
# Test A - Call create_dict_iterator with num_epochs>1
num_iter = 0
# Note: If create_tuple_iterator() is called with num_epochs>1, then EpochCtrlOp is added to the pipeline
# Note: If create_dict_iterator() is called with num_epochs>1, then EpochCtrlOp is added to the pipeline
for _ in data1.create_dict_iterator(num_epochs=2):
num_iter += 1
assert num_iter == 2
@@ -437,7 +437,7 @@ def test_profiling_seq_pipelines_epochctrl3():

# Test B - Call create_dict_iterator with num_epochs=1
num_iter = 0
# Note: If create_tuple_iterator() is called with num_epochs=1,
# Note: If create_dict_iterator() is called with num_epochs=1,
# then EpochCtrlOp should not be NOT added to the pipeline
for _ in data1.create_dict_iterator(num_epochs=1):
num_iter += 1
@@ -470,7 +470,7 @@ def test_profiling_seq_pipelines_epochctrl2():
try:
# Test A - Call create_dict_iterator with num_epochs=1
num_iter = 0
# Note: If create_tuple_iterator() is called with num_epochs=1, then EpochCtrlOp is NOT added to the pipeline
# Note: If create_dict_iterator() is called with num_epochs=1, then EpochCtrlOp is NOT added to the pipeline
for _ in data2.create_dict_iterator(num_epochs=1):
num_iter += 1
assert num_iter == 4
@@ -481,7 +481,7 @@ def test_profiling_seq_pipelines_epochctrl2():

# Test B - Call create_dict_iterator with num_epochs>1
num_iter = 0
# Note: If create_tuple_iterator() is called with num_epochs>1,
# Note: If create_dict_iterator() is called with num_epochs>1,
# then EpochCtrlOp should be added to the pipeline
for _ in data2.create_dict_iterator(num_epochs=2):
num_iter += 1

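The comments in the profiling hunks above describe the behaviour these values rely on: with num_epochs=1 no EpochCtrl op is added and the pipeline is torn down after a single pass, while num_epochs>1 inserts EpochCtrl so the same iterator can serve the later passes. A rough sketch of that distinction, using an illustrative four-row dataset in place of the profiling pipelines:

import numpy as np
import mindspore.dataset as ds

data = ds.NumpySlicesDataset(np.arange(4), column_names=["col1"], shuffle=False)

# num_epochs=1: single pass, no EpochCtrl op, the pipeline shuts down afterwards.
assert sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True)) == 4

# num_epochs=2: EpochCtrl keeps the pipeline alive so the iterator serves both passes.
iter2 = data.create_dict_iterator(num_epochs=2, output_numpy=True)
for _ in range(2):
    assert sum(1 for _ in iter2) == 4
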
@@ -340,7 +340,7 @@ def test_func_with_yield_manifest_dataset_01():
data = data.map(operations=pass_func, input_columns=["image"], num_parallel_workers=1, python_multiprocessing=True)
num_iter = 0
try:
for _ in data.create_dict_iterator(output_numpy=True):
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
except RuntimeError as e:
assert "Can not pickle <class 'generator'> object, " in str(e)

@@ -47,7 +47,7 @@ def test_random_sharpness_py(degrees=(0.7, 0.7), plot=False):

ds_original = ds_original.batch(512)

for idx, (image, _) in enumerate(ds_original.create_tuple_iterator(output_numpy=True)):
for idx, (image, _) in enumerate(ds_original.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_original = np.transpose(image, (0, 2, 3, 1))
else:
@@ -71,7 +71,7 @@ def test_random_sharpness_py(degrees=(0.7, 0.7), plot=False):

ds_random_sharpness = ds_random_sharpness.batch(512)

for idx, (image, _) in enumerate(ds_random_sharpness.create_tuple_iterator(output_numpy=True)):
for idx, (image, _) in enumerate(ds_random_sharpness.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_random_sharpness = np.transpose(image, (0, 2, 3, 1))
else:
@@ -136,7 +136,7 @@ def test_random_sharpness_c(degrees=(1.6, 1.6), plot=False):

ds_original = ds_original.batch(512)

for idx, (image, _) in enumerate(ds_original.create_tuple_iterator(output_numpy=True)):
for idx, (image, _) in enumerate(ds_original.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_original = image
else:
@@ -159,7 +159,7 @@ def test_random_sharpness_c(degrees=(1.6, 1.6), plot=False):

ds_random_sharpness = ds_random_sharpness.batch(512)

for idx, (image, _) in enumerate(ds_random_sharpness.create_tuple_iterator(output_numpy=True)):
for idx, (image, _) in enumerate(ds_random_sharpness.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_random_sharpness = image
else:
@@ -226,7 +226,7 @@ def test_random_sharpness_c_py(degrees=(1.0, 1.0), plot=False):

ds_random_sharpness_py = ds_random_sharpness_py.batch(512)

for idx, (image, _) in enumerate(ds_random_sharpness_py.create_tuple_iterator(output_numpy=True)):
for idx, (image, _) in enumerate(ds_random_sharpness_py.create_tuple_iterator(num_epochs=1, output_numpy=True)):
if idx == 0:
images_random_sharpness_py = image

@@ -242,7 +242,10 @@ def test_random_sharpness_c_py(degrees=(1.0, 1.0), plot=False):

ds_images_random_sharpness_c = ds_images_random_sharpness_c.batch(512)

for idx, (image, _) in enumerate(ds_images_random_sharpness_c.create_tuple_iterator(output_numpy=True)):
for idx, (image, _) in enumerate(
ds_images_random_sharpness_c.create_tuple_iterator(
num_epochs=1,
output_numpy=True)):
if idx == 0:
images_random_sharpness_c = image

@@ -117,7 +117,7 @@ def test_nested_repeat1():
data = data.repeat(2)
data = data.repeat(3)

for i, d in enumerate(data.create_tuple_iterator(output_numpy=True)):
for i, d in enumerate(data.create_tuple_iterator(num_epochs=1, output_numpy=True)):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 2 * 3 * 3
@@ -129,7 +129,7 @@ def test_nested_repeat2():
data = data.repeat(1)
data = data.repeat(1)

for i, d in enumerate(data.create_tuple_iterator(output_numpy=True)):
for i, d in enumerate(data.create_tuple_iterator(num_epochs=1, output_numpy=True)):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 3
@@ -141,7 +141,7 @@ def test_nested_repeat3():
data = data.repeat(1)
data = data.repeat(2)

for i, d in enumerate(data.create_tuple_iterator(output_numpy=True)):
for i, d in enumerate(data.create_tuple_iterator(num_epochs=1, output_numpy=True)):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 2 * 3
@@ -153,7 +153,7 @@ def test_nested_repeat4():
data = data.repeat(2)
data = data.repeat(1)

for i, d in enumerate(data.create_tuple_iterator(output_numpy=True)):
for i, d in enumerate(data.create_tuple_iterator(num_epochs=1, output_numpy=True)):
assert i % 3 == d[0][0]

assert sum([1 for _ in data]) == 2 * 3

@@ -269,7 +269,7 @@ def test_voc_sampler_chain():
assert data1_size == 5

# Verify number of rows
assert sum([1 for _ in data1.create_dict_iterator(output_numpy=True)]) == 5
assert sum([1 for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True)]) == 5

# Verify dataset contents
res = []

@@ -307,8 +307,9 @@ def test_serdes_zip_dataset(remove_json_files=True):
assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

rows = 0
for d0, d3, d4 in zip(ds0.create_tuple_iterator(output_numpy=True), data3.create_tuple_iterator(output_numpy=True),
data4.create_tuple_iterator(output_numpy=True)):
for d0, d3, d4 in zip(ds0.create_tuple_iterator(num_epochs=1, output_numpy=True),
data3.create_tuple_iterator(num_epochs=1, output_numpy=True),
data4.create_tuple_iterator(num_epochs=1, output_numpy=True)):
num_cols = len(d0)
offset = 0
for t1 in d0:

@@ -57,7 +57,7 @@ def test_generator_skip():
ds1 = ds1.skip(3)

buf = []
for data in ds1.create_tuple_iterator(output_numpy=True):
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert len(buf) == 2
assert buf == [3, 4]
@@ -70,7 +70,7 @@ def test_skip_1():
ds1 = ds1.skip(7)

buf = []
for data in ds1.create_tuple_iterator(output_numpy=True):
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert buf == []

@@ -82,7 +82,7 @@ def test_skip_2():
ds1 = ds1.skip(0)

buf = []
for data in ds1.create_tuple_iterator(output_numpy=True):
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert len(buf) == 5
assert buf == [0, 1, 2, 3, 4]
@@ -98,7 +98,7 @@ def test_skip_repeat_1():
ds1 = ds1.skip(3)

buf = []
for data in ds1.create_tuple_iterator(output_numpy=True):
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert len(buf) == 7
assert buf == [3, 4, 0, 1, 2, 3, 4]
@@ -114,7 +114,7 @@ def test_skip_repeat_2():
ds1 = ds1.repeat(2)

buf = []
for data in ds1.create_tuple_iterator(output_numpy=True):
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert len(buf) == 4
assert buf == [3, 4, 3, 4]
@@ -133,7 +133,7 @@ def test_skip_repeat_3():
ds1 = ds1.repeat(3)

buf = []
for data in ds1.create_tuple_iterator(output_numpy=True):
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert len(buf) == 6
assert buf == [3, 4, 3, 4, 3, 4]
@@ -149,7 +149,7 @@ def test_skip_take_1():
ds1 = ds1.skip(2)

buf = []
for data in ds1.create_tuple_iterator(output_numpy=True):
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert len(buf) == 2
assert buf == [2, 3]
@@ -165,7 +165,7 @@ def test_skip_take_2():
ds1 = ds1.take(2)

buf = []
for data in ds1.create_tuple_iterator(output_numpy=True):
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert len(buf) == 2
assert buf == [2, 3]
@@ -182,7 +182,7 @@ def test_skip_filter_1():
dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)

buf = []
for item in dataset.create_tuple_iterator(output_numpy=True):
for item in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(item[0][0])
assert buf == [5, 6, 7, 8, 9, 10]

@@ -193,7 +193,7 @@ def test_skip_filter_2():
dataset = dataset.skip(5)

buf = []
for item in dataset.create_tuple_iterator(output_numpy=True):
for item in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(item[0][0])
assert buf == [5, 6, 7, 8, 9, 10]

@@ -28,7 +28,7 @@ def slice_compare(array, indexing, expected_array):
data = data.map(operations=ops.Slice(*indexing))
else:
data = data.map(operations=ops.Slice(indexing))
for d in data.create_dict_iterator(output_numpy=True):
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(expected_array, d['column_0'])

@@ -141,7 +141,7 @@ def test_slice_multiple_rows():
data = ds.GeneratorDataset(gen, column_names=["col"])
indexing = slice(1, 4)
data = data.map(operations=ops.Slice(indexing))
for (d, exp_d) in zip(data.create_dict_iterator(output_numpy=True), exp_dataset):
for (d, exp_d) in zip(data.create_dict_iterator(num_epochs=1, output_numpy=True), exp_dataset):
np.testing.assert_array_equal(exp_d, d['col'])

@@ -158,12 +158,12 @@ def test_slice_none_and_ellipsis():

data = ds.GeneratorDataset(gen, column_names=["col"])
data = data.map(operations=ops.Slice(None))
for (d, exp_d) in zip(data.create_dict_iterator(output_numpy=True), exp_dataset):
for (d, exp_d) in zip(data.create_dict_iterator(num_epochs=1, output_numpy=True), exp_dataset):
np.testing.assert_array_equal(exp_d, d['col'])

data = ds.GeneratorDataset(gen, column_names=["col"])
data = data.map(operations=ops.Slice(Ellipsis))
for (d, exp_d) in zip(data.create_dict_iterator(output_numpy=True), exp_dataset):
for (d, exp_d) in zip(data.create_dict_iterator(num_epochs=1, output_numpy=True), exp_dataset):
np.testing.assert_array_equal(exp_d, d['col'])

@@ -280,7 +280,7 @@ def test_out_of_bounds_slicing_str():
data = [b"1", b"2", b"3", b"4", b"5"]
data = ds.NumpySlicesDataset([data])
data = data.map(operations=ops.Slice(indexing))
for d in data.create_dict_iterator(output_numpy=True):
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(expected_array, d['column_0'])

@@ -161,7 +161,7 @@ def test_slice_patches_08():
dataset = dataset.map(input_columns=["image"], output_columns=["img0", "img1", "img2", "img3"],
column_order=["img0", "img1", "img2", "img3"],
operations=slice_patches_op)
for item in dataset.create_dict_iterator(output_numpy=True):
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
patch_shape = item['img0'].shape
assert patch_shape == (28, 41, 256)

@@ -185,7 +185,7 @@ def skip_test_slice_patches_11():
cols = ['img' + str(x) for x in range(10*13)]
dataset = dataset.map(input_columns=["image"], output_columns=cols,
column_order=cols, operations=slice_patches_op)
for item in dataset.create_dict_iterator(output_numpy=True):
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
patch_shape = item['img0'].shape
assert patch_shape == (700, 538, 256)

@@ -264,7 +264,7 @@ def test_simple_sync_wait_empty_condition_name():
dataset = dataset.batch(batch_size)

count = 0
for data in dataset.create_dict_iterator(output_numpy=True):
for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
data = {"loss": count}
dataset.sync_update(condition_name="", data=data)

@@ -25,7 +25,7 @@ def test_tensor_empty():

data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"])

for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(np.array([], dtype=np.int64), d[0])
np.testing.assert_array_equal(np.array([], dtype='S').reshape([0, 4]), d[1])
np.testing.assert_array_equal(np.array([1], dtype=np.float64), d[2])
@@ -46,7 +46,7 @@ def test_tensor_empty_map():

data = data.map(operations=func, input_columns=["col1", "col2", "col3"])

for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(np.array([1], dtype=np.int64), d[0])
np.testing.assert_array_equal(np.array(["Hi"], dtype='S'), d[1])
np.testing.assert_array_equal(np.array([], dtype=np.float64), d[2])
@@ -60,7 +60,7 @@ def test_tensor_empty_batch():

data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]).batch(2)

for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(np.array([], dtype=np.int64).reshape([2, 0]), d[0])
np.testing.assert_array_equal(np.array([], dtype='S').reshape([2, 0, 4]), d[1])
np.testing.assert_array_equal(np.array([[1], [1]], dtype=np.float64), d[2])

@@ -35,7 +35,7 @@ def compare(strings, dtype='S'):

data = ds.GeneratorDataset(gen, column_names=["col"])

for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(d[0], arr.astype('S'))

@@ -79,7 +79,7 @@ def test_batching_strings():
data = ds.GeneratorDataset(gen, column_names=["col"])
data = data.batch(2, drop_remainder=True)

for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(d[0], to_bytes(chinese[0:2]))

@@ -96,7 +96,7 @@ def test_map():

data = data.map(operations=split, input_columns=["col"])
expected = np.array(["ab", "cde", "121"], dtype='S')
for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(d[0], expected)

@@ -112,7 +112,7 @@ def test_map2():

data = data.map(operations=upper, input_columns=["col"])
expected = np.array(["AB CDE 121"], dtype='S')
for d in data.create_tuple_iterator(output_numpy=True):
for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
np.testing.assert_array_equal(d[0], expected)