fix num_samples in concatDataset

This commit is contained in:
liyong 2020-09-12 15:02:54 +08:00
parent a0e3fd6bf3
commit 16147669a6
3 changed files with 19 additions and 1 deletions

View File

@ -2438,6 +2438,8 @@ class ConcatDataset(DatasetOp):
self._sampler = _select_sampler(None, sampler, None, None, None)
cumulative_samples_nums = 0
for index, child in enumerate(self.children):
if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
raise ValueError("The parameter NumSamples of %s is not support to be set!" % (child))
if isinstance(child, BatchDataset):
raise TypeError("The parameter %s of concat should't be BatchDataset!" % (child))

View File

@ -1,6 +1,5 @@
{
"datasetType": "TF",
"numRows": 3,
"columns": {
"image": {
"type": "uint8",

View File

@ -213,6 +213,23 @@ def test_raise_error():
ds3.use_sampler(testsampler)
assert excinfo.type == 'ValueError'
def test_imagefolder_error():
DATA_DIR = "../data/dataset/testPK/data"
data = ds.ImageFolderDataset(DATA_DIR, num_samples=14)
data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
{'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
{'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
{'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
{'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
{'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]
data2 = ds.PaddedDataset(data1)
data3 = data + data2
with pytest.raises(ValueError) as excinfo:
testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
data3.use_sampler(testsampler)
assert excinfo.type == 'ValueError'
def test_imagefolder_padded():
DATA_DIR = "../data/dataset/testPK/data"