forked from mindspore-Ecosystem/mindspore
fix num_samples in concatDataset
This commit is contained in:
parent
a0e3fd6bf3
commit
16147669a6
|
@ -2438,6 +2438,8 @@ class ConcatDataset(DatasetOp):
|
|||
self._sampler = _select_sampler(None, sampler, None, None, None)
|
||||
cumulative_samples_nums = 0
|
||||
for index, child in enumerate(self.children):
|
||||
if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
|
||||
raise ValueError("The parameter NumSamples of %s is not support to be set!" % (child))
|
||||
|
||||
if isinstance(child, BatchDataset):
|
||||
raise TypeError("The parameter %s of concat should't be BatchDataset!" % (child))
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
{
|
||||
"datasetType": "TF",
|
||||
"numRows": 3,
|
||||
"columns": {
|
||||
"image": {
|
||||
"type": "uint8",
|
||||
|
|
|
@ -213,6 +213,23 @@ def test_raise_error():
|
|||
ds3.use_sampler(testsampler)
|
||||
assert excinfo.type == 'ValueError'
|
||||
|
||||
def test_imagefolder_error():
|
||||
DATA_DIR = "../data/dataset/testPK/data"
|
||||
data = ds.ImageFolderDataset(DATA_DIR, num_samples=14)
|
||||
|
||||
data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
|
||||
{'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
|
||||
{'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
|
||||
{'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
|
||||
{'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
|
||||
{'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]
|
||||
|
||||
data2 = ds.PaddedDataset(data1)
|
||||
data3 = data + data2
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
|
||||
data3.use_sampler(testsampler)
|
||||
assert excinfo.type == 'ValueError'
|
||||
|
||||
def test_imagefolder_padded():
|
||||
DATA_DIR = "../data/dataset/testPK/data"
|
||||
|
|
Loading…
Reference in New Issue