!34901 Add flatten for flat_map

Merge pull request !34901 from xiaotianci/fix_flat_map
This commit is contained in:
i-robot 2022-05-27 03:37:01 +00:00 committed by Gitee
commit 43b5ad6f12
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 102 additions and 28 deletions

View File

@ -732,18 +732,29 @@ class Dataset:
Dataset, dataset applied by the function.
Examples:
>>> # use NumpySlicesDataset as an example
>>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
>>> # 1) flat_map on one column dataset
>>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False)
>>>
>>> def flat_map_func(array):
>>> def repeat(array):
... # create a NumpySlicesDataset with the array
... dataset = ds.NumpySlicesDataset(array)
... data = ds.NumpySlicesDataset(array, shuffle=False)
... # repeat the dataset twice
... dataset = dataset.repeat(2)
... return dataset
... data = data.repeat(2)
... return data
>>>
>>> dataset = dataset.flat_map(flat_map_func)
>>> # [[0, 1], [0, 1], [2, 3], [2, 3]]
>>> dataset = dataset.flat_map(repeat)
>>> # [0, 1, 0, 1, 2, 3, 2, 3]
>>>
>>> # 2) flat_map on multi column dataset
>>> dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), shuffle=False)
>>> def plus_and_minus(col1, col2):
... # apply different methods on columns
... data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False)
... return data
>>> dataset = dataset.flat_map(plus_and_minus)
>>> # ([1, 2, 3, 4], [-1, -2, -3, -4])
Raises:
TypeError: If `func` is not a function.
@ -754,11 +765,11 @@ class Dataset:
logger.critical("func must be a function.")
raise TypeError("func must be a function.")
for row_data in self.create_tuple_iterator(output_numpy=True):
for row_data in self.create_tuple_iterator(num_epochs=1, output_numpy=True):
if dataset is None:
dataset = func(row_data)
dataset = func(*row_data)
else:
dataset += func(row_data)
dataset += func(*row_data)
if not isinstance(dataset, Dataset):
logger.critical("flat_map must return a Dataset object.")
@ -2382,8 +2393,7 @@ def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1):
"it's recommended to reduce memory usage by following methods:\n"
"1. reduce value of parameter max_rowsize or num_parallel_workers.\n"
"2. reduce prefetch size by set_prefetch_size().\n"
"3. disable shared memory by set_enable_shared_mem()."
.format(shm_estimate_usage, shm_available))
"3. disable shared memory by set_enable_shared_mem().".format(shm_estimate_usage, shm_available))
except FileNotFoundError:
raise RuntimeError("Expected /dev/shm to exist.")

View File

@ -20,13 +20,14 @@ DATA_FILE = "../data/dataset/test_flat_map/images1.txt"
INDEX_FILE = "../data/dataset/test_flat_map/image_index.txt"
def test_flat_map_1():
'''
DATA_FILE records the path of image folders, load the images from them.
'''
def test_flat_map_basic():
"""
Feature: flat_map
Description: test basic usage
Expectation: the result is as expected
"""
def flat_map_func(x):
data_dir = x[0].item().decode('utf8')
data_dir = x.item().decode('utf8')
d = ds.ImageFolderDataset(data_dir)
return d
@ -40,18 +41,19 @@ def test_flat_map_1():
assert count == 52
def test_flat_map_2():
'''
Flatten 3D structure data
'''
def test_flat_map_chain_call():
"""
Feature: flat_map
Description: test chain call
Expectation: the result is as expected
"""
def flat_map_func_1(x):
data_dir = x[0].item().decode('utf8')
data_dir = x.item().decode('utf8')
d = ds.ImageFolderDataset(data_dir)
return d
def flat_map_func_2(x):
text_file = x[0].item().decode('utf8')
text_file = x.item().decode('utf8')
d = ds.TextFileDataset(text_file)
d = d.flat_map(flat_map_func_1)
return d
@ -66,6 +68,68 @@ def test_flat_map_2():
assert count == 104
def test_flat_map_one_column():
    """
    Feature: flat_map
    Description: test with one column dataset
    Expectation: the result is as expected
    """
    # Case 1: each 1-D row is turned into a per-row dataset and repeated twice,
    # so the scalars come out interleaved: [0, 1, 0, 1, 2, 3, 2, 3].
    dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False)

    def repeat(array):
        # Build a dataset from this row's values and duplicate it.
        row_data = ds.NumpySlicesDataset(array, shuffle=False)
        return row_data.repeat(2)

    dataset = dataset.flat_map(repeat)

    expect = np.array([0, 1, 0, 1, 2, 3, 2, 3])
    for idx, row in enumerate(dataset.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(row[0], expect[idx])

    # Case 2: 2-D rows are flattened into their 1-D slices with +1 applied to each element.
    dataset = ds.NumpySlicesDataset([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], shuffle=False)

    def plus(array):
        # Element-wise increment, then slice along the first axis.
        return ds.NumpySlicesDataset(array + 1, shuffle=False)

    dataset = dataset.flat_map(plus)

    expect = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    for idx, row in enumerate(dataset.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(row[0], expect[idx])
def test_flat_map_multi_column():
    """
    Feature: flat_map
    Description: test with multi column dataset
    Expectation: the result is as expected
    """
    dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), column_names=["col1", "col2"],
                                    shuffle=False)

    def plus_and_minus(col1, col2):
        # Apply a different transform per column: +1 on col1, -1 on col2.
        return ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False)

    dataset = dataset.flat_map(plus_and_minus)

    expect_col1 = np.array([1, 2, 3, 4])
    expect_col2 = np.array([-1, -2, -3, -4])
    for idx, row in enumerate(dataset.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        # Each output row carries both columns, checked in lockstep.
        np.testing.assert_array_equal(row[0], expect_col1[idx])
        np.testing.assert_array_equal(row[1], expect_col2[idx])
if __name__ == "__main__":
test_flat_map_1()
test_flat_map_2()
test_flat_map_basic()
test_flat_map_chain_call()
test_flat_map_one_column()
test_flat_map_multi_column()