diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 978d480aa5e..35b03726cda 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -227,7 +227,9 @@ Status GeneratorOp::operator()() { // Restore exception to python e.restore(); - if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_) { + + // Check whether the number of samples is sufficient only in the first epoch + if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_ && op_current_epochs_ == 0) { if (generator_counter_ == 0) { std::string msg = "Unable to fetch data from GeneratorDataset, try iterate the source function of GeneratorDataset or check" diff --git a/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py b/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py index 23f78679f55..badedaad6f5 100644 --- a/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py +++ b/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py @@ -678,6 +678,10 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset): if hasattr(self.source, "__len__"): self.source_len = len(self.source) + # if user defined sampler, update the self.source_len + if isinstance(self.sampler, samplers.Sampler) or hasattr(self.sampler, "__iter__"): + self.source_len = len(list(self.sampler)) + self.max_rowsize = max_rowsize self.sample_fn = None diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 3214593b385..11ecb870786 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -390,6 +390,173 @@ def test_sampler_list(): msg="Type of indices element must be int, but got list[0]: [1 2], type: .") + + +def check_result(expected, result): + for index, item in
enumerate(result): + assert str(expected[index][0]) == item + + +def test_sampler_when_less_and_larger_index_ids(): + """ + Feature: Sampler op + Description: Test sampler with less and larger index ids + Expectation: success + """ + + # sampler with less index ids + class MySampler(): + def __iter__(self): + for i in range(0, 10, 2): + yield i + + np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l'] + + dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler()) + count = 0 + expected_data = [] + for data in dataset.create_tuple_iterator(num_epochs=1): + count += 1 + expected_data.append(data) + assert count == 5 + check_result(expected_data, ['a', 'c', 'e', 'g', 'i']) + + epochs = 3 + ds_iter = dataset.create_tuple_iterator(num_epochs=epochs) + for _ in range(epochs): + count = 0 + expected_data = [] + for data in ds_iter: + count += 1 + expected_data.append(data) + assert count == 5 + check_result(expected_data, ['a', 'c', 'e', 'g', 'i']) + + # sampler with larger index ids + index = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8] + class MySampler2(): + def __iter__(self): + for i in index: + yield i + + dataset2 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler2()) + count = 0 + expected_data = [] + for data in dataset2.create_tuple_iterator(num_epochs=1): + count += 1 + expected_data.append(data) + assert count == 16 + check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i']) + + epochs = 3 + ds_iter2 = dataset2.create_tuple_iterator(num_epochs=epochs) + for _ in range(epochs): + count = 0 + expected_data = [] + for data in ds_iter2: + count += 1 + expected_data.append(data) + assert count == 16 + check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i']) + + +def test_sampler_with_getitem_method(): + """ + Feature: Sampler op + Description: Test sampler with __getitem__ method + 
Expectation: success + """ + + # sampler with equal index ids + class MySampler(): + def __init__(self): + self.index_ids = [3, 8, 7, 2, 0, 9, 11, 4, 5, 1, 6, 10] + def __getitem__(self, index): + return self.index_ids[index] + def __len__(self): + return len(self.index_ids) + + np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l'] + + dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler()) + count = 0 + expected_data = [] + for data in dataset.create_tuple_iterator(num_epochs=1): + count += 1 + expected_data.append(data) + assert count == 12 + check_result(expected_data, ['d', 'i', 'h', 'c', 'a', 'j', 'l', 'e', 'f', 'b', 'g', 'k']) + + epochs = 3 + ds_iter = dataset.create_tuple_iterator(num_epochs=epochs) + for _ in range(epochs): + count = 0 + expected_data = [] + for data in ds_iter: + count += 1 + expected_data.append(data) + assert count == 12 + check_result(expected_data, ['d', 'i', 'h', 'c', 'a', 'j', 'l', 'e', 'f', 'b', 'g', 'k']) + + # sampler with less index ids + class MySampler2(): + def __init__(self): + self.index_ids = [0, 2, 4, 6, 8] + def __getitem__(self, index): + return self.index_ids[index] + def __len__(self): + return len(self.index_ids) + + np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l'] + + dataset2 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler2()) + count = 0 + expected_data = [] + for data in dataset2.create_tuple_iterator(num_epochs=1): + count += 1 + expected_data.append(data) + assert count == 5 + check_result(expected_data, ['a', 'c', 'e', 'g', 'i']) + + epochs = 3 + ds_iter2 = dataset2.create_tuple_iterator(num_epochs=epochs) + for _ in range(epochs): + count = 0 + expected_data = [] + for data in ds_iter2: + count += 1 + expected_data.append(data) + assert count == 5 + check_result(expected_data, ['a', 'c', 'e', 'g', 'i']) + + # sampler with larger index ids + class MySampler3(): + def __init__(self): + self.index_ids = [3, 4, 3, 2, 0, 
11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8] + def __getitem__(self, index): + return self.index_ids[index] + def __len__(self): + return len(self.index_ids) + + dataset3 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler3()) + count = 0 + expected_data = [] + for data in dataset3.create_tuple_iterator(num_epochs=1): + count += 1 + expected_data.append(data) + assert count == 16 + check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i']) + + epochs = 3 + ds_iter3 = dataset3.create_tuple_iterator(num_epochs=epochs) + for _ in range(epochs): + count = 0 + expected_data = [] + for data in ds_iter3: + count += 1 + expected_data.append(data) + assert count == 16 + check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i']) + + if __name__ == '__main__': test_sequential_sampler(True) test_random_sampler(True) @@ -402,3 +569,5 @@ if __name__ == '__main__': test_add_sampler_invalid_input() test_distributed_sampler_invalid_offset() test_sampler_list() + test_sampler_when_less_and_larger_index_ids() + test_sampler_with_getitem_method()