forked from OSSInnovation/mindspore
set dataset_size in generator when source has len
This commit is contained in:
parent
b57d4ea2f3
commit
3b42c360b6
|
@ -3157,6 +3157,9 @@ class GeneratorDataset(MappableDataset):
|
|||
self.column_names.append(col["name"])
|
||||
self.column_types.append(DataType(col["type"]))
|
||||
|
||||
if source is not None and hasattr(source, "__len__"):
|
||||
self._dataset_size = len(source)
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["source"] = self.source
|
||||
|
@ -3177,6 +3180,7 @@ class GeneratorDataset(MappableDataset):
|
|||
return self._dataset_size
|
||||
if self._dataset_size is None:
|
||||
return None
|
||||
|
||||
return min(rows_from_sampler, self._dataset_size)
|
||||
|
||||
# manually set dataset_size as a temporary solution.
|
||||
|
|
Loading…
Reference in New Issue