forked from mindspore-Ecosystem/mindspore
!11854 Support list of IDs as a sampler
From: @hfarahat
commit 112b5829e7
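In short, this change lets mappable datasets accept a plain Python list (or any iterable) of sample IDs wherever a sampler is expected. A minimal usage sketch, assuming the test image folder used by the new test at the bottom of this diff:

import mindspore.dataset as ds

# A list of sample IDs now works directly as a sampler; internally it is
# wrapped in a SubsetSampler, so only rows 1, 3 and 5 are produced.
data = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5])
for row in data.create_dict_iterator():
    print(row["image"].shape, row["label"])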
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,6 @@ import os
 import numpy as np
 
 import mindspore._c_dataengine as cde
-from ..engine import samplers
 
 # POS_INT_MIN is used to limit values from starting from 0
 POS_INT_MIN = 1
@@ -290,8 +289,6 @@ def check_sampler_shuffle_shard_options(param_dict):
     num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
     num_samples = param_dict.get('num_samples')
 
-    type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")
-
     if sampler is not None:
         if shuffle is not None:
             raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
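For illustration, a hypothetical call that the check above rejects, since an explicit sampler and shuffle cannot be combined:

import mindspore.dataset as ds

# Raises RuntimeError: sampler and shuffle cannot be specified at the same time.
data = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5], shuffle=False)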
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -2708,7 +2708,7 @@ class ConcatDataset(Dataset):
 
         self.dataset_size = None
 
-        self._sampler = _select_sampler(None, sampler, None, None, None)
+        self._sampler = samplers.select_sampler(None, sampler, None, None, None)
         cumulative_samples_nums = 0
         for index, child in enumerate(self.children):
             if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
@@ -2990,65 +2990,6 @@ class RangeDataset(MappableDataset):
         return self.dataset_size
 
 
-def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
-    """
-    Create sampler based on user input.
-
-    Args:
-        num_samples (int): Number of samples.
-        input_sampler (Union[Iterable, Sampler]): Sampler from user.
-        shuffle (bool): Shuffle.
-        num_shards (int): Number of shard for sharding.
-        shard_id (int): Shard ID.
-        non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
-
-    Returns:
-        Sampler, sampler selected based on user input.
-    """
-    if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
-        return None
-
-    if input_sampler is not None:
-        # If the user provided a sampler, then it doesn't matter what the other args are because
-        # we are being asked specifically to use the given sampler.
-        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
-        # be None. Consider this example:
-        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
-        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
-        # In this case, the user has given different sample-related arguments that contradict each other.
-        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
-        if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
-                                       samplers.RandomSampler, samplers.SubsetRandomSampler,
-                                       samplers.WeightedRandomSampler, samplers.Sampler)) and
-                (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
-            raise ValueError(
-                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
-                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
-        return input_sampler
-    if shuffle is None:
-        if num_shards is not None:
-            # If shuffle is not specified, sharding enabled, use distributed random sampler
-            shuffle = True
-            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-        # If shuffle is not specified, sharding disabled, use random sampler
-        if num_samples is not None:
-            return samplers.RandomSampler(replacement=True, num_samples=num_samples)
-        return samplers.RandomSampler(num_samples=num_samples)
-    if shuffle is True:
-        if num_shards is not None:
-            # If shuffle enabled, sharding enabled, use distributed random sampler
-            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-        # If shuffle enabled, sharding disabled, use random sampler
-        if num_samples is not None:
-            return samplers.RandomSampler(replacement=True, num_samples=num_samples)
-        return samplers.RandomSampler(num_samples=num_samples)
-    if num_shards is not None:
-        # If shuffle disabled, sharding enabled, use distributed sequential sampler
-        return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-    # If shuffle disabled, sharding disabled, use sequential sampler
-    return samplers.SequentialSampler(num_samples=num_samples)
-
-
 class ImageFolderDataset(MappableDataset):
     """
     A source dataset that reads images from a tree of directories.
@@ -3144,7 +3085,7 @@ class ImageFolderDataset(MappableDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
 
         self.dataset_dir = dataset_dir
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.shuffle_level = shuffle
         self.extensions = replace_none(extensions, [])
@@ -3293,7 +3234,7 @@ class MnistDataset(MappableDataset):
 
         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.shuffle_level = shuffle
         self.num_shards = num_shards
@@ -3386,7 +3327,7 @@ class MindDataset(MappableDataset):
                               samplers.SequentialSampler)) is False:
             raise ValueError("The sampler is not supported yet.")
 
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
 
         self.padded_sample = padded_sample
@@ -3470,27 +3411,6 @@ def _generator_fn(generator, num_samples):
         yield val
 
 
-def _py_sampler_fn(sampler, num_samples, dataset):
-    """
-    Generator function wrapper for mappable dataset with Python sampler.
-    """
-    if num_samples is not None:
-        sampler_iter = iter(sampler)
-        for _ in range(num_samples):
-            try:
-                idx = next(sampler_iter)
-            except StopIteration:
-                return
-            val = dataset[idx]
-            # convert output tensors to ndarrays
-            yield tuple([np.array(x, copy=False) for x in val])
-    else:
-        for i in sampler:
-            val = dataset[i]
-            # convert output tensors to ndarrays
-            yield tuple([np.array(x, copy=False) for x in val])
-
-
 def _cpp_sampler_fn(sample_ids, dataset):
     """
     Generator function wrapper for mappable dataset with cpp sampler.
@@ -3518,31 +3438,6 @@ def _cpp_sampler_fn_mp(sample_ids, sample_fn):
     return sample_fn.process(sample_ids)
 
 
-def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
-    """
-    Multiprocessing generator function wrapper for mappable dataset with Python sampler.
-    """
-    indices = _fetch_py_sampler_indices(sampler, num_samples)
-    return sample_fn.process(indices)
-
-
-def _fetch_py_sampler_indices(sampler, num_samples):
-    """
-    Indice fetcher for Python sampler.
-    """
-    if num_samples is not None:
-        sampler_iter = iter(sampler)
-        ret = []
-        for _ in range(num_samples):
-            try:
-                val = next(sampler_iter)
-                ret.append(val)
-            except StopIteration:
-                break
-        return ret
-    return [i for i in sampler]
-
-
 def _fill_worker_indices(workers, indices, idx):
     """
     Worker index queue filler, fill worker index queue in round robin order.
@@ -3865,7 +3760,7 @@ class GeneratorDataset(MappableDataset):
                  python_multiprocessing=True):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.source = source
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.python_multiprocessing = python_multiprocessing
@@ -3912,26 +3807,11 @@ class GeneratorDataset(MappableDataset):
         if hasattr(self, "__total_batch__"):
             new_op.__total_batch__ = self.__total_batch__
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
-            if isinstance(new_op.sampler, samplers.BuiltinSampler):
-                if new_op.num_parallel_workers > 1:
-                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                    new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
-                else:
-                    new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
-            else:
-                # the sampler provided is not a built-in sampler, it is a list of sample_ids
-                new_op.sample_ids = new_op.sampler
-                # since list of sample_ids are not passed to c++, we need to find the proper len here
-                new_op.source_len = min(self.source_len, len(new_op.sample_ids)) if self.source_len != -1 else len(
-                    new_op.sample_ids)
-                new_op.source_len = min(self.source_len,
-                                        new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
-                new_op.sampler = None
-                if new_op.num_parallel_workers > 1:
-                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sample_ids, new_op.num_samples, sample_fn))
-                else:
-                    new_op.source = (lambda: _py_sampler_fn(new_op.sample_ids, new_op.num_samples, self.source))
+            if new_op.num_parallel_workers > 1:
+                sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
+                new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
+            else:
+                new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
+            new_op.sample_fn = sample_fn
         else:
             try:
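The simplification above works because select_sampler now normalizes every user input, including a list of IDs, into a built-in sampler, so GeneratorDataset only needs the C++ sampler path. A minimal sketch of that path's contract (a simplified rendering of _cpp_sampler_fn, not the full implementation):

import numpy as np

def _cpp_sampler_fn(sample_ids, dataset):
    # For each index produced by the C++-side sampler, fetch the row from the
    # random-accessible source and convert its columns to ndarrays.
    for i in sample_ids:
        val = dataset[i]
        yield tuple([np.array(x, copy=False) for x in val])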
@@ -4089,13 +3969,6 @@ class TFRecordDataset(SourceDataset):
         self.shuffle_level = shuffle
         self.shuffle_files = True
 
-        # The TF record dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler_shuffle = self.shuffle_files
-        sampler = None
-        self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
-                                       non_mappable=True)
         self.shard_equal_rows = replace_none(shard_equal_rows, False)
 
     def get_args(self):
@@ -4231,7 +4104,7 @@ class ManifestDataset(MappableDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
 
         self.dataset_file = dataset_file
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
 
         if class_indexing is not None and not isinstance(class_indexing, dict):
             raise RuntimeError("class_indexing must be a dictionary.")
@@ -4396,7 +4269,7 @@ class Cifar10Dataset(MappableDataset):
 
         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.shard_id = shard_id
@@ -4535,7 +4408,7 @@ class Cifar100Dataset(MappableDataset):
 
         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.shard_id = shard_id
@@ -4607,8 +4480,6 @@ class RandomDataset(SourceDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.schema = schema
         self.columns_list = replace_none(columns_list, [])
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True)
 
         self.num_samples = num_samples
         self.total_rows = total_rows
@@ -4900,7 +4771,7 @@ class VOCDataset(MappableDataset):
         self.task = replace_none(task, "Segmentation")
         self.usage = replace_none(usage, "train")
         self.class_indexing = class_indexing
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.decode = replace_none(decode, False)
         self.shuffle_level = shuffle
@@ -5092,7 +4963,7 @@ class CocoDataset(MappableDataset):
         self.dataset_dir = dataset_dir
         self.annotation_file = annotation_file
         self.task = replace_none(task, "Detection")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.decode = replace_none(decode, False)
         self.shuffle_level = shuffle
@@ -5224,7 +5095,7 @@ class CelebADataset(MappableDataset):
                  extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.dataset_dir = dataset_dir
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_parallel_workers = num_parallel_workers
         self.decode = replace_none(decode, False)
         self.extensions = replace_none(extensions, [])
@@ -5596,12 +5467,7 @@ class CSVDataset(SourceDataset):
         self.shuffle_files = True
 
         self.cache = cache
-        # The CSV dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, self.shuffle_files, num_shards, shard_id,
-                                       non_mappable=True)
 
         self.num_shards = replace_none(num_shards, 1)
         self.shard_id = replace_none(shard_id, 0)
         self.num_samples = replace_none(num_samples, 0)
@@ -5715,13 +5581,6 @@ class TextFileDataset(SourceDataset):
         self.shard_id = replace_none(shard_id, 0)
 
         self.cache = cache
-        # The text file dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler_shuffle = self.shuffle_files
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, sampler_shuffle, num_shards, shard_id,
-                                       non_mappable=True)
 
     def get_args(self):
         args = super().get_args()
@@ -25,6 +25,82 @@ import mindspore._c_dataengine as cde
 import mindspore.dataset as ds
 
 
+def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
+    """
+    Create sampler based on user input.
+
+    Args:
+        num_samples (int): Number of samples.
+        input_sampler (Union[Iterable, Sampler]): Sampler from user.
+        shuffle (bool): Shuffle.
+        num_shards (int): Number of shard for sharding.
+        shard_id (int): Shard ID.
+
+    Returns:
+        Sampler, sampler selected based on user input.
+    """
+
+    def _is_iterable(obj):
+        try:
+            iter(obj)
+        except TypeError:
+            return False
+        return True
+
+    def _get_sample_ids_as_list(sampler, number_of_samples=None):
+        if number_of_samples is None:
+            return list(sampler)
+
+        if isinstance(sampler, list):
+            return sampler[:number_of_samples]
+
+        return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]
+
+    if input_sampler is not None:
+        # If the user provided a sampler, then it doesn't matter what the other args are because
+        # we are being asked specifically to use the given sampler.
+        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
+        # be None. Consider this example:
+        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
+        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
+        # In this case, the user has given different sample-related arguments that contradict each other.
+        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
+        if (isinstance(input_sampler, BuiltinSampler) and
+                (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
+            raise ValueError(
+                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
+                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
+        if isinstance(input_sampler, BuiltinSampler):
+            return input_sampler
+        if _is_iterable(input_sampler):
+            return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
+        if isinstance(input_sampler, int):
+            return [input_sampler]
+        raise ValueError('Unsupported sampler object ({})'.format(input_sampler))
+    if shuffle is None:
+        if num_shards is not None:
+            # If shuffle is not specified, sharding enabled, use distributed random sampler
+            shuffle = True
+            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
+        # If shuffle is not specified, sharding disabled, use random sampler
+        if num_samples is not None:
+            return RandomSampler(replacement=True, num_samples=num_samples)
+        return RandomSampler(num_samples=num_samples)
+    if shuffle is True:
+        if num_shards is not None:
+            # If shuffle enabled, sharding enabled, use distributed random sampler
+            return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
+        # If shuffle enabled, sharding disabled, use random sampler
+        if num_samples is not None:
+            return RandomSampler(replacement=True, num_samples=num_samples)
+        return RandomSampler(num_samples=num_samples)
+    if num_shards is not None:
+        # If shuffle disabled, sharding enabled, use distributed sequential sampler
+        return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
+    # If shuffle disabled, sharding disabled, use sequential sampler
+    return SequentialSampler(num_samples=num_samples)
+
+
 class BuiltinSampler:
     """
     Base class for BuiltinSampler.
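A sketch of what the new select_sampler resolves each kind of input to, following the branches above. The import path is an assumption for illustration; inside the package the function is reached as samplers.select_sampler:

from mindspore.dataset.engine import samplers  # assumed import path

# A list or other iterable of IDs becomes a SubsetSampler over exactly those IDs.
s1 = samplers.select_sampler(None, [1, 3, 5], None, None, None)

# With num_samples set, only the first num_samples IDs are drawn from the
# iterable (zip stops at the shorter sequence, so generators are safe).
s2 = samplers.select_sampler(2, (i * 2 for i in range(100)), None, None, None)  # ids [0, 2]

# No sampler, sharding enabled, shuffle unspecified -> DistributedSampler with shuffle=True.
s3 = samplers.select_sampler(None, None, None, 4, 0)

# Nothing specified -> RandomSampler, i.e. mappable datasets shuffle by default.
s4 = samplers.select_sampler(None, None, None, None, None)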
@@ -17,6 +17,7 @@ import pytest
 
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import dataset_equal
 
 
 # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
@@ -265,6 +266,15 @@ def test_distributed_sampler_invalid_offset():
     assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)
 
 
+def test_sampler_list():
+    data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5])
+    data21 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(2).skip(1)
+    data22 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(4).skip(3)
+    data23 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(6).skip(5)
+
+    dataset_equal(data1, data21 + data22 + data23, 0)
+
+
 if __name__ == '__main__':
     test_sequential_sampler(True)
     test_random_sampler(True)
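For reference, each take(k).skip(k-1) pipeline in the test isolates exactly one row, so the concatenated result is rows 1, 3 and 5 of the unshuffled dataset, which is precisely what the list sampler should yield. A self-contained sketch of the idea over a hypothetical six-row dataset:

rows = ["r0", "r1", "r2", "r3", "r4", "r5"]  # stand-in for the unshuffled dataset

def take_skip(data, take_n, skip_n):
    # Mimic Dataset.take(n).skip(m): keep the first n rows, then drop the first m.
    return data[:take_n][skip_n:]

assert take_skip(rows, 2, 1) + take_skip(rows, 4, 3) + take_skip(rows, 6, 5) == ["r1", "r3", "r5"]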
@@ -276,3 +286,4 @@ if __name__ == '__main__':
     test_sampler_chain()
     test_add_sampler_invalid_input()
     test_distributed_sampler_invalid_offset()
+    test_sampler_list()