Fix GeneratorDataset.split so that it can be called when the dataset's input source is an iterator, and synchronize the fix to r2.0.0-alpha

This commit is contained in:
liu-yongqi-63 2023-01-16 15:33:19 +08:00
parent 2ad878fe9b
commit 1d4d150fad
3 changed files with 71 additions and 1 deletions

View File

@ -75,6 +75,7 @@
"mindspore/mindspore/python/mindspore/dataset/engine/__init__.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/dataset/engine/datasets.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/dataset/engine/datasets.py" "broad-except"
"mindspore/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py" "bad-super-call"
"mindspore/mindspore/python/mindspore/dataset/engine/graphdata.py" "super-init-not-called"
"mindspore/mindspore/python/mindspore/dataset/transforms/py_transforms_util.py" "broad-except"

View File

@ -643,7 +643,7 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
if isinstance(source, builtins.zip):
# Although zip is iteratable, it does not have the feature of repeated iteration, so pass it to the array.
# Although zip is iterable, it does not have the feature of repeated iteration, so pass it to the array.
self.source = [item for item in source]
else:
self.source = source
@ -735,6 +735,13 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
return self.sampler.is_sharded()
return False
def split(self, sizes, randomize=True):
    """
    Split this dataset by delegating to the split implementation that fits the source.

    Args:
        sizes: Forwarded unchanged to the selected parent `split`.
        randomize (bool): Forwarded unchanged to the selected parent `split`. Default: True.

    Returns:
        Whatever the selected parent `split` implementation returns.
    """
    source_supports_indexing = hasattr(self.source, "__getitem__")
    if not source_supports_indexing:
        # A source without __getitem__ cannot be accessed by index, so skip
        # MappableDataset in the MRO and use the next split implementation
        # (the split method of Dataset).
        return super(MappableDataset, self).split(sizes, randomize)
    # An indexable source can use the split method of MappableDataset.
    return super().split(sizes, randomize)
def parse(self, children=None):
if self.schema is None:
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,

View File

@ -2208,6 +2208,65 @@ def test_generator_with_single_numpy_with_yield():
assert count == 20
def test_generator_split_with_yield():
    """
    Feature: GeneratorDataset
    Description: Call split on a GeneratorDataset whose source is a generator function (yield mode)
    Expectation: The two resulting datasets report sizes of 51 and 13
    """
    source_dataset = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
    train_part, val_part = source_dataset.split([0.8, 0.2])

    # Check the train subset first, then the validation subset.
    for subset, expected_size in ((train_part, 51), (val_part, 13)):
        assert subset.get_dataset_size() == expected_size
def test_generator_split_with_getitem():
    """
    Feature: GeneratorDataset
    Description: Call split on a GeneratorDataset whose source is a DatasetGenerator instance (getitem mode)
    Expectation: The two resulting datasets report sizes of 8 and 2
    """
    source_dataset = ds.GeneratorDataset(DatasetGenerator(), ["data"], shuffle=False)
    train_part, val_part = source_dataset.split([0.8, 0.2])

    # Check the train subset first, then the validation subset.
    for subset, expected_size in ((train_part, 8), (val_part, 2)):
        assert subset.get_dataset_size() == expected_size
def test_generator_split_with_next():
    """
    Feature: GeneratorDataset
    Description: Call split on a GeneratorDataset whose source is an iterator object (next mode)
    Expectation: The two resulting datasets report sizes of 8 and 2
    """

    class GetDatasetGenerator:
        """Iterator source that returns the same sample ten times per pass."""

        def __init__(self, data):
            self.__data = data
            self.__count = 0

        def __len__(self):
            return 10

        def __iter__(self):
            # Reset the counter so every new pass starts from the beginning.
            self.__count = 0
            return self

        def __next__(self):
            if self.__count < 10:
                self.__count += 1
                return self.__data
            raise StopIteration

    sample = (np.array([[1, 2, 3], [2, 3, 4]]),)
    iterator_source = GetDatasetGenerator(sample)
    source_dataset = ds.GeneratorDataset(iterator_source, ["data"], shuffle=False)
    train_part, val_part = source_dataset.split([0.8, 0.2])

    # Check the train subset first, then the validation subset.
    for subset, expected_size in ((train_part, 8), (val_part, 2)):
        assert subset.get_dataset_size() == expected_size
if __name__ == "__main__":
test_generator_0()
test_generator_1()
@ -2261,3 +2320,6 @@ if __name__ == "__main__":
test_generator_with_single_numpy_with_yield()
test_generator_with_seed_5489_when_dist()
test_generator_with_set_seed_when_dist()
test_generator_split_with_yield()
test_generator_split_with_getitem()
test_generator_split_with_next()