forked from mindspore-Ecosystem/mindspore
!46531 enhance dataset sampler for r2.0_alpha
Merge pull request !46531 from guozhijian/update_sampler_r2.0_alpha
This commit is contained in:
commit
3d69682c7a
|
@ -227,7 +227,9 @@ Status GeneratorOp::operator()() {
|
|||
|
||||
// Restore exception to python
|
||||
e.restore();
|
||||
if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_) {
|
||||
|
||||
// Check whether the number of samples is sufficient only when the first epoch
|
||||
if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_ && op_current_epochs_ == 0) {
|
||||
if (generator_counter_ == 0) {
|
||||
std::string msg =
|
||||
"Unable to fetch data from GeneratorDataset, try iterate the source function of GeneratorDataset or check"
|
||||
|
|
|
@ -678,6 +678,10 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
|
|||
if hasattr(self.source, "__len__"):
|
||||
self.source_len = len(self.source)
|
||||
|
||||
# if user defined sampler, update the self.source_len
|
||||
if isinstance(self.sampler, samplers.Sampler) or hasattr(self.sampler, "__iter__"):
|
||||
self.source_len = len(list(sampler))
|
||||
|
||||
self.max_rowsize = max_rowsize
|
||||
self.sample_fn = None
|
||||
|
||||
|
|
|
@ -390,6 +390,173 @@ def test_sampler_list():
|
|||
msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.")
|
||||
|
||||
|
||||
def check_result(expected, result):
|
||||
for index, item in enumerate(result):
|
||||
assert str(expected[index][0]) == item
|
||||
|
||||
|
||||
def test_sampler_when_less_and_larger_index_ids():
|
||||
"""
|
||||
Feature: Sampler op
|
||||
Description: Test sampler with less and larger index ids
|
||||
Expectation: success
|
||||
"""
|
||||
|
||||
# sampler with less index ids
|
||||
class MySampler():
|
||||
def __iter__(self):
|
||||
for i in range(0, 10, 2):
|
||||
yield i
|
||||
|
||||
np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']
|
||||
|
||||
dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler())
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in dataset.create_tuple_iterator(num_epochs=1):
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 5
|
||||
check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])
|
||||
|
||||
epochs = 3
|
||||
ds_iter = dataset.create_tuple_iterator(num_epochs=epochs)
|
||||
for _ in range(epochs):
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in ds_iter:
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 5
|
||||
check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])
|
||||
|
||||
# sampler with larger index ids
|
||||
index = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]
|
||||
class MySampler2():
|
||||
def __iter__(self):
|
||||
for i in index:
|
||||
yield i
|
||||
|
||||
dataset2 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler2())
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in dataset2.create_tuple_iterator(num_epochs=1):
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 16
|
||||
check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])
|
||||
|
||||
epochs = 3
|
||||
ds_iter2 = dataset2.create_tuple_iterator(num_epochs=epochs)
|
||||
for _ in range(epochs):
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in ds_iter2:
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 16
|
||||
check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])
|
||||
|
||||
|
||||
def test_sampler_with_getitem_method():
|
||||
"""
|
||||
Feature: Sampler op
|
||||
Description: Test sampler with __getitem__ method
|
||||
Expectation: success
|
||||
"""
|
||||
|
||||
# sampler with equal index ids
|
||||
class MySampler():
|
||||
def __init__(self):
|
||||
self.index_ids = [3, 8, 7, 2, 0, 9, 11, 4, 5, 1, 6, 10]
|
||||
def __getitem__(self, index):
|
||||
return self.index_ids[index]
|
||||
def __len__(self):
|
||||
return len(self.index_ids)
|
||||
|
||||
np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']
|
||||
|
||||
dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler())
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in dataset.create_tuple_iterator(num_epochs=1):
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 12
|
||||
check_result(expected_data, ['d', 'i', 'h', 'c', 'a', 'j', 'l', 'e', 'f', 'b', 'g', 'k'])
|
||||
|
||||
epochs = 3
|
||||
ds_iter = dataset.create_tuple_iterator(num_epochs=epochs)
|
||||
for _ in range(epochs):
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in ds_iter:
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 12
|
||||
check_result(expected_data, ['d', 'i', 'h', 'c', 'a', 'j', 'l', 'e', 'f', 'b', 'g', 'k'])
|
||||
|
||||
# sampler with less index ids
|
||||
class MySampler2():
|
||||
def __init__(self):
|
||||
self.index_ids = [0, 2, 4, 6, 8]
|
||||
def __getitem__(self, index):
|
||||
return self.index_ids[index]
|
||||
def __len__(self):
|
||||
return len(self.index_ids)
|
||||
|
||||
np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']
|
||||
|
||||
dataset2 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler2())
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in dataset2.create_tuple_iterator(num_epochs=1):
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 5
|
||||
check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])
|
||||
|
||||
epochs = 3
|
||||
ds_iter2 = dataset2.create_tuple_iterator(num_epochs=epochs)
|
||||
for _ in range(epochs):
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in ds_iter2:
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 5
|
||||
check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])
|
||||
|
||||
# sampler with larger index ids
|
||||
class MySampler3():
|
||||
def __init__(self):
|
||||
self.index_ids = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]
|
||||
def __getitem__(self, index):
|
||||
return self.index_ids[index]
|
||||
def __len__(self):
|
||||
return len(self.index_ids)
|
||||
|
||||
dataset3 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler3())
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in dataset3.create_tuple_iterator(num_epochs=1):
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 16
|
||||
check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])
|
||||
|
||||
epochs = 3
|
||||
ds_iter3 = dataset3.create_tuple_iterator(num_epochs=epochs)
|
||||
for _ in range(epochs):
|
||||
count = 0
|
||||
expected_data = []
|
||||
for data in ds_iter3:
|
||||
count += 1
|
||||
expected_data.append(data)
|
||||
assert count == 16
|
||||
check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sequential_sampler(True)
|
||||
test_random_sampler(True)
|
||||
|
@ -402,3 +569,5 @@ if __name__ == '__main__':
|
|||
test_add_sampler_invalid_input()
|
||||
test_distributed_sampler_invalid_offset()
|
||||
test_sampler_list()
|
||||
test_sampler_when_less_and_larger_index_ids()
|
||||
test_sampler_with_getitem_method()
|
||||
|
|
Loading…
Reference in New Issue