!46531 enhance dataset sampler for r2.0_alpha

Merge pull request !46531 from guozhijian/update_sampler_r2.0_alpha
i-robot 2022-12-08 02:56:41 +00:00 committed by Gitee
commit 3d69682c7a
3 changed files with 176 additions and 1 deletion


@@ -227,7 +227,9 @@ Status GeneratorOp::operator()() {
      // Restore exception to python
      e.restore();
      if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_) {
      // Check whether the number of samples is sufficient only in the first epoch
      if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_ && op_current_epochs_ == 0) {
        if (generator_counter_ == 0) {
          std::string msg =
            "Unable to fetch data from GeneratorDataset, try iterate the source function of GeneratorDataset or check"


@@ -678,6 +678,10 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
        if hasattr(self.source, "__len__"):
            self.source_len = len(self.source)

        # if a user-defined sampler is given, update self.source_len from it
        if isinstance(self.sampler, samplers.Sampler) or hasattr(self.sampler, "__iter__"):
            self.source_len = len(list(sampler))

        self.max_rowsize = max_rowsize
        self.sample_fn = None
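
A note on len(list(sampler)): it drains the user-defined sampler once and counts the indices it yields, and that count replaces len(self.source) as self.source_len. A rough Python sketch of what that expression measures (the sampler class here is invented for illustration):

    class EvenIndexSampler:
        # Yields the even indices of a 12-element source: 0, 2, 4, 6, 8, 10.
        def __iter__(self):
            return iter(range(0, 12, 2))

    sampler = EvenIndexSampler()
    # With this change, the sampler's count (6) becomes self.source_len
    # instead of the source length (12).
    print(len(list(sampler)))   # 6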


@@ -390,6 +390,173 @@ def test_sampler_list():
msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.")


def check_result(expected, result):
    for index, item in enumerate(result):
        assert str(expected[index][0]) == item


def test_sampler_when_less_and_larger_index_ids():
    """
    Feature: Sampler op
    Description: Test sampler with fewer and more index ids than the dataset size
    Expectation: success
    """
    # sampler with fewer index ids
    class MySampler():
        def __iter__(self):
            for i in range(0, 10, 2):
                yield i

    np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']

    dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler())
    count = 0
    expected_data = []
    for data in dataset.create_tuple_iterator(num_epochs=1):
        count += 1
        expected_data.append(data)
    assert count == 5
    check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])

    epochs = 3
    ds_iter = dataset.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        count = 0
        expected_data = []
        for data in ds_iter:
            count += 1
            expected_data.append(data)
        assert count == 5
        check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])

    # sampler with more index ids
    index = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]

    class MySampler2():
        def __iter__(self):
            for i in index:
                yield i

    dataset2 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler2())
    count = 0
    expected_data = []
    for data in dataset2.create_tuple_iterator(num_epochs=1):
        count += 1
        expected_data.append(data)
    assert count == 16
    check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])

    epochs = 3
    ds_iter2 = dataset2.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        count = 0
        expected_data = []
        for data in ds_iter2:
            count += 1
            expected_data.append(data)
        assert count == 16
        check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])


def test_sampler_with_getitem_method():
    """
    Feature: Sampler op
    Description: Test sampler with __getitem__ method
    Expectation: success
    """
    # sampler with an equal number of index ids
    class MySampler():
        def __init__(self):
            self.index_ids = [3, 8, 7, 2, 0, 9, 11, 4, 5, 1, 6, 10]

        def __getitem__(self, index):
            return self.index_ids[index]

        def __len__(self):
            return len(self.index_ids)

    np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']

    dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler())
    count = 0
    expected_data = []
    for data in dataset.create_tuple_iterator(num_epochs=1):
        count += 1
        expected_data.append(data)
    assert count == 12
    check_result(expected_data, ['d', 'i', 'h', 'c', 'a', 'j', 'l', 'e', 'f', 'b', 'g', 'k'])

    epochs = 3
    ds_iter = dataset.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        count = 0
        expected_data = []
        for data in ds_iter:
            count += 1
            expected_data.append(data)
        assert count == 12
        check_result(expected_data, ['d', 'i', 'h', 'c', 'a', 'j', 'l', 'e', 'f', 'b', 'g', 'k'])

    # sampler with fewer index ids
    class MySampler2():
        def __init__(self):
            self.index_ids = [0, 2, 4, 6, 8]

        def __getitem__(self, index):
            return self.index_ids[index]

        def __len__(self):
            return len(self.index_ids)

    np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']

    dataset2 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler2())
    count = 0
    expected_data = []
    for data in dataset2.create_tuple_iterator(num_epochs=1):
        count += 1
        expected_data.append(data)
    assert count == 5
    check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])

    epochs = 3
    ds_iter2 = dataset2.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        count = 0
        expected_data = []
        for data in ds_iter2:
            count += 1
            expected_data.append(data)
        assert count == 5
        check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])

    # sampler with more index ids
    class MySampler3():
        def __init__(self):
            self.index_ids = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]

        def __getitem__(self, index):
            return self.index_ids[index]

        def __len__(self):
            return len(self.index_ids)

    dataset3 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler3())
    count = 0
    expected_data = []
    for data in dataset3.create_tuple_iterator(num_epochs=1):
        count += 1
        expected_data.append(data)
    assert count == 16
    check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])

    epochs = 3
    ds_iter3 = dataset3.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        count = 0
        expected_data = []
        for data in ds_iter3:
            count += 1
            expected_data.append(data)
        assert count == 16
        check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])


if __name__ == '__main__':
    test_sequential_sampler(True)
    test_random_sampler(True)

@@ -402,3 +569,5 @@ if __name__ == '__main__':
    test_add_sampler_invalid_input()
    test_distributed_sampler_invalid_offset()
    test_sampler_list()
    test_sampler_when_less_and_larger_index_ids()
    test_sampler_with_getitem_method()