modify API comment and fix the __total_batch__ attribute being lost during deepcopy when sink_size is specified

ms_yan 2020-11-04 16:18:15 +08:00
parent 5887107264
commit ccaa601d25
1 changed file with 8 additions and 3 deletions

@@ -2166,6 +2166,8 @@ class MapDataset(DatasetOp):
         new_op.operations = self.operations
         new_op.dataset_size = self.dataset_size
         new_op.callbacks = self.callbacks
+        if hasattr(self, "__total_batch__"):
+            new_op.__total_batch__ = self.__total_batch__
         return new_op
     # Iterator bootstrap will be called on iterator construction.
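The fix itself is the two added lines: MapDataset defines its own __deepcopy__, so only the attributes it copies explicitly survive a deepcopy, and __total_batch__ (attached to the dataset object later, when sink_size is specified) was being dropped. A minimal, self-contained sketch of the pattern, using a placeholder class rather than the real MindSpore MapDataset:

import copy

class FakeMapDataset:
    """Placeholder standing in for MapDataset; not MindSpore code."""

    def __init__(self, operations):
        self.operations = operations

    def __deepcopy__(self, memodict):
        # A hand-written deepcopy only carries over the attributes it names.
        new_op = FakeMapDataset(copy.deepcopy(self.operations, memodict))
        # Guarded copy: __total_batch__ only exists if it was attached after
        # construction (e.g. when sink_size is specified), so check first.
        if hasattr(self, "__total_batch__"):
            new_op.__total_batch__ = self.__total_batch__
        return new_op

ds = FakeMapDataset(operations=[lambda x: x + 1])
ds.__total_batch__ = 100                           # attached after construction
assert copy.deepcopy(ds).__total_batch__ == 100    # now survives the copy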
@@ -3640,6 +3642,8 @@ class GeneratorDataset(MappableDataset):
         new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
         new_op.dataset_size = self.dataset_size
         new_op.sampler = copy.deepcopy(self.sampler)
+        if hasattr(self, "__total_batch__"):
+            new_op.__total_batch__ = self.__total_batch__
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
             if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
                                            samplers.RandomSampler, samplers.SubsetRandomSampler,
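GeneratorDataset gets the same guarded copy. Its __deepcopy__ also follows Python's copy protocol: copy.deepcopy(value, memodict) forwards the memo dictionary so anything already copied during the same pass is reused rather than duplicated. A small illustration of that protocol (the Node class is illustrative only, not MindSpore code):

import copy

class Node:
    def __init__(self, payload, shared):
        self.payload = payload
        self.shared = shared                       # may be shared across nodes

    def __deepcopy__(self, memodict):
        # Register the copy before recursing, then forward memodict so
        # shared (or cyclic) references resolve to a single copied object.
        new_node = Node.__new__(Node)
        memodict[id(self)] = new_node
        new_node.payload = copy.deepcopy(self.payload, memodict)
        new_node.shared = copy.deepcopy(self.shared, memodict)
        return new_node

shared_cfg = {"num_samples": 8}
a, b = Node("a", shared_cfg), Node("b", shared_cfg)
a_copy, b_copy = copy.deepcopy([a, b])
assert a_copy.shared is b_copy.shared              # shared object copied exactly once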
@@ -5705,10 +5709,11 @@ class NumpySlicesDataset(GeneratorDataset):
     Args:
         data (Union[list, tuple, dict]): Input of given data. Supported data types include: list, tuple, dict and other
-            NumPy formats. Input data will be sliced along the first dimension and generate additional rows.
-            Large data is not recommended to be loaded in this way as data is loading into memory.
+            NumPy formats. Input data will be sliced along the first dimension to generate additional rows. If the
+            input is a list, each row will have a single column; otherwise there will typically be multiple columns.
+            Loading large data this way is not recommended, since all of the data is loaded into memory.
         column_names (list[str], optional): List of column names of the dataset (default=None). If column_names is not
-            provided, when data is dict, column_names will be its keys, otherwise it will be like column_1, column_2 ...
+            provided, when data is dict, column_names will be its keys, otherwise they will be named column_0, column_1 ...
         num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images).
         num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
         shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
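A short usage sketch of the behavior the revised docstring describes: a list input yields a single column with the default column_0 name, while a dict input keeps one column per key. This assumes the usual mindspore.dataset import; the data values are toy examples:

import numpy as np
import mindspore.dataset as ds

# List input: sliced along the first dimension, one column per row,
# auto-named column_0 because no column_names is provided.
list_set = ds.NumpySlicesDataset([1, 2, 3], shuffle=False)

# Dict input: the keys become the column names, one column per key.
dict_set = ds.NumpySlicesDataset({"a": np.array([1, 2]),
                                  "b": np.array([3, 4])}, shuffle=False)

for row in dict_set.create_dict_iterator():
    print(row["a"], row["b"])                      # columns addressed by key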