forked from mindspore-Ecosystem/mindspore
!34901 Add flatten for flat_map
Merge pull request !34901 from xiaotianci/fix_flat_map
This commit is contained in:
commit
43b5ad6f12
|
@ -732,18 +732,29 @@ class Dataset:
|
|||
Dataset, dataset applied by the function.
|
||||
|
||||
Examples:
|
||||
>>> # use NumpySlicesDataset as an example
|
||||
>>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
|
||||
>>> # 1) flat_map on one column dataset
|
||||
>>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False)
|
||||
>>>
|
||||
>>> def flat_map_func(array):
|
||||
>>> def repeat(array):
|
||||
... # create a NumpySlicesDataset with the array
|
||||
... dataset = ds.NumpySlicesDataset(array)
|
||||
... data = ds.NumpySlicesDataset(array, shuffle=False)
|
||||
... # repeat the dataset twice
|
||||
... dataset = dataset.repeat(2)
|
||||
... return dataset
|
||||
... data = data.repeat(2)
|
||||
... return data
|
||||
>>>
|
||||
>>> dataset = dataset.flat_map(flat_map_func)
|
||||
>>> # [[0, 1], [0, 1], [2, 3], [2, 3]]
|
||||
>>> dataset = dataset.flat_map(repeat)
|
||||
>>> # [0, 1, 0, 1, 2, 3, 2, 3]
|
||||
>>>
|
||||
>>> # 2) flat_map on multi column dataset
|
||||
>>> dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), shuffle=False)
|
||||
|
||||
>>> def plus_and_minus(col1, col2):
|
||||
... # apply different methods on columns
|
||||
... data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False)
|
||||
... return data
|
||||
|
||||
>>> dataset = dataset.flat_map(plus_and_minus)
|
||||
>>> # ([1, 2, 3, 4], [-1, -2, -3, -4])
|
||||
|
||||
Raises:
|
||||
TypeError: If `func` is not a function.
|
||||
|
@ -754,11 +765,11 @@ class Dataset:
|
|||
logger.critical("func must be a function.")
|
||||
raise TypeError("func must be a function.")
|
||||
|
||||
for row_data in self.create_tuple_iterator(output_numpy=True):
|
||||
for row_data in self.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
if dataset is None:
|
||||
dataset = func(row_data)
|
||||
dataset = func(*row_data)
|
||||
else:
|
||||
dataset += func(row_data)
|
||||
dataset += func(*row_data)
|
||||
|
||||
if not isinstance(dataset, Dataset):
|
||||
logger.critical("flat_map must return a Dataset object.")
|
||||
|
@ -2382,8 +2393,7 @@ def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1):
|
|||
"it's recommended to reduce memory usage by following methods:\n"
|
||||
"1. reduce value of parameter max_rowsize or num_parallel_workers.\n"
|
||||
"2. reduce prefetch size by set_prefetch_size().\n"
|
||||
"3. disable shared memory by set_enable_shared_mem()."
|
||||
.format(shm_estimate_usage, shm_available))
|
||||
"3. disable shared memory by set_enable_shared_mem().".format(shm_estimate_usage, shm_available))
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError("Expected /dev/shm to exist.")
|
||||
|
||||
|
|
|
@ -20,13 +20,14 @@ DATA_FILE = "../data/dataset/test_flat_map/images1.txt"
|
|||
INDEX_FILE = "../data/dataset/test_flat_map/image_index.txt"
|
||||
|
||||
|
||||
def test_flat_map_1():
|
||||
'''
|
||||
DATA_FILE records the path of image folders, load the images from them.
|
||||
'''
|
||||
|
||||
def test_flat_map_basic():
|
||||
"""
|
||||
Feature: flat_map
|
||||
Description: test basic usage
|
||||
Expectation: the result is as expected
|
||||
"""
|
||||
def flat_map_func(x):
|
||||
data_dir = x[0].item().decode('utf8')
|
||||
data_dir = x.item().decode('utf8')
|
||||
d = ds.ImageFolderDataset(data_dir)
|
||||
return d
|
||||
|
||||
|
@ -40,18 +41,19 @@ def test_flat_map_1():
|
|||
assert count == 52
|
||||
|
||||
|
||||
def test_flat_map_2():
|
||||
'''
|
||||
Flatten 3D structure data
|
||||
'''
|
||||
|
||||
def test_flat_map_chain_call():
|
||||
"""
|
||||
Feature: flat_map
|
||||
Description: test chain call
|
||||
Expectation: the result is as expected
|
||||
"""
|
||||
def flat_map_func_1(x):
|
||||
data_dir = x[0].item().decode('utf8')
|
||||
data_dir = x.item().decode('utf8')
|
||||
d = ds.ImageFolderDataset(data_dir)
|
||||
return d
|
||||
|
||||
def flat_map_func_2(x):
|
||||
text_file = x[0].item().decode('utf8')
|
||||
text_file = x.item().decode('utf8')
|
||||
d = ds.TextFileDataset(text_file)
|
||||
d = d.flat_map(flat_map_func_1)
|
||||
return d
|
||||
|
@ -66,6 +68,68 @@ def test_flat_map_2():
|
|||
assert count == 104
|
||||
|
||||
|
||||
def test_flat_map_one_column():
    """
    Feature: flat_map
    Description: test with one column dataset
    Expectation: the result is as expected
    """
    # Case 1: each source row is turned into a dataset that is repeated twice.
    dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False)

    def repeat(array):
        data = ds.NumpySlicesDataset(array, shuffle=False)
        data = data.repeat(2)
        return data

    dataset = dataset.flat_map(repeat)

    expect = np.array([0, 1, 0, 1, 2, 3, 2, 3])
    count = 0
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(row[0], expect[count])
        count += 1
    # Without this check a flat_map that yields too few rows (or none at all)
    # would pass silently, because the loop body would simply run fewer times.
    assert count == len(expect)

    # Case 2: each source row is a 2-D array; the mapped dataset adds one to it.
    dataset = ds.NumpySlicesDataset([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], shuffle=False)

    def plus(array):
        data = ds.NumpySlicesDataset(array + 1, shuffle=False)
        return data

    dataset = dataset.flat_map(plus)

    expect = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    count = 0
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(row[0], expect[count])
        count += 1
    # Same guard as above: make sure every expected row was actually produced.
    assert count == len(expect)
|
||||
|
||||
|
||||
def test_flat_map_multi_column():
    """
    Feature: flat_map
    Description: test with multi column dataset
    Expectation: the result is as expected
    """
    dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), column_names=["col1", "col2"],
                                    shuffle=False)

    def plus_and_minus(col1, col2):
        # Apply a different operation to each column of the incoming row.
        data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False)
        return data

    dataset = dataset.flat_map(plus_and_minus)

    expect_col1 = np.array([1, 2, 3, 4])
    expect_col2 = np.array([-1, -2, -3, -4])
    count = 0
    for row in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(row[0], expect_col1[count])
        np.testing.assert_array_equal(row[1], expect_col2[count])
        count += 1
    # Without this check a flat_map that yields too few rows (or none at all)
    # would pass silently, because the loop body would simply run fewer times.
    assert count == len(expect_col1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flat_map_1()
|
||||
test_flat_map_2()
|
||||
test_flat_map_basic()
|
||||
test_flat_map_chain_call()
|
||||
test_flat_map_one_column()
|
||||
test_flat_map_multi_column()
|
||||
|
|
Loading…
Reference in New Issue