forked from mindspore-Ecosystem/mindspore
Fix GeneratorDataset.split so it can be called normally when the input source is an iterator; synchronize the change to r2.0.0-alpha
This commit is contained in:
parent
2ad878fe9b
commit
1d4d150fad
|
@ -75,6 +75,7 @@
|
|||
"mindspore/mindspore/python/mindspore/dataset/engine/__init__.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/dataset/engine/datasets.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/dataset/engine/datasets.py" "broad-except"
|
||||
"mindspore/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py" "bad-super-call"
|
||||
"mindspore/mindspore/python/mindspore/dataset/engine/graphdata.py" "super-init-not-called"
|
||||
"mindspore/mindspore/python/mindspore/dataset/transforms/py_transforms_util.py" "broad-except"
|
||||
|
||||
|
|
|
@ -643,7 +643,7 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
|
|||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
|
||||
if isinstance(source, builtins.zip):
|
||||
# Although zip is iterable, it does not have the feature of repeated iteration, so pass it to the array.
|
||||
# Although zip is iterable, it does not have the feature of repeated iteration, so pass it to the array.
|
||||
self.source = [item for item in source]
|
||||
else:
|
||||
self.source = source
|
||||
|
@ -735,6 +735,13 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
|
|||
return self.sampler.is_sharded()
|
||||
return False
|
||||
|
||||
def split(self, sizes, randomize=True):
|
||||
if hasattr(self.source, "__getitem__"):
|
||||
# If the source has __getitem__ attribute, call the split method of MappableDataset.
|
||||
# Otherwise, call the split method of Dataset.
|
||||
return super().split(sizes, randomize)
|
||||
return super(MappableDataset, self).split(sizes, randomize)
|
||||
|
||||
def parse(self, children=None):
|
||||
if self.schema is None:
|
||||
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
|
||||
|
|
|
@ -2208,6 +2208,65 @@ def test_generator_with_single_numpy_with_yield():
|
|||
assert count == 20
|
||||
|
||||
|
||||
def test_generator_split_with_yield():
    """
    Feature: GeneratorDataset
    Description: When GeneratorDataset calls split, it can be split if the input is in yield mode
    Expectation: The dataset is processed as expected
    """
    data = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
    train_set, val_set = data.split([0.8, 0.2])
    # generator_1d produces 64 samples; an 80/20 split yields 51 and 13.
    assert train_set.get_dataset_size() == 51
    assert val_set.get_dataset_size() == 13
|
||||
|
||||
|
||||
def test_generator_split_with_getitem():
    """
    Feature: GeneratorDataset
    Description: When GeneratorDataset calls split, it can be split if the input is in getitem mode
    Expectation: The dataset is processed as expected
    """
    source = DatasetGenerator()
    data = ds.GeneratorDataset(source, ["data"], shuffle=False)
    train_set, val_set = data.split([0.8, 0.2])
    # 10 samples split 80/20 -> 8 and 2.
    assert train_set.get_dataset_size() == 8
    assert val_set.get_dataset_size() == 2
|
||||
|
||||
|
||||
def test_generator_split_with_next():
    """
    Feature: GeneratorDataset
    Description: When GeneratorDataset calls split, it can be split if the input is in next mode
    Expectation: The dataset is processed as expected
    """

    class GetDatasetGenerator:
        """Iterator-style source: 10 copies of the same sample, restartable via __iter__."""

        def __init__(self, data):
            self._payload = data
            self._cursor = 0

        def __next__(self):
            # Exhaust after 10 items; __iter__ rewinds for repeated epochs.
            if self._cursor < 10:
                self._cursor += 1
                return self._payload
            raise StopIteration

        def __iter__(self):
            self._cursor = 0
            return self

        def __len__(self):
            return 10

    sample = (np.array([[1, 2, 3], [2, 3, 4]]),)
    source = GetDatasetGenerator(sample)
    data = ds.GeneratorDataset(source, ["data"], shuffle=False)
    train_set, val_set = data.split([0.8, 0.2])
    # 10 samples split 80/20 -> 8 and 2.
    assert train_set.get_dataset_size() == 8
    assert val_set.get_dataset_size() == 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_generator_0()
|
||||
test_generator_1()
|
||||
|
@ -2261,3 +2320,6 @@ if __name__ == "__main__":
|
|||
test_generator_with_single_numpy_with_yield()
|
||||
test_generator_with_seed_5489_when_dist()
|
||||
test_generator_with_set_seed_when_dist()
|
||||
test_generator_split_with_yield()
|
||||
test_generator_split_with_getitem()
|
||||
test_generator_split_with_next()
|
||||
|
|
Loading…
Reference in New Issue