forked from mindspore-Ecosystem/mindspore
!4868 fix: concat with none sample dataset
Merge pull request !4868 from guozhijian/fix_concat_with_zero_dataset
This commit is contained in:
commit
bccb92adf7
|
@ -2310,6 +2310,7 @@ class ConcatDataset(DatasetOp):
|
|||
|
||||
Raises:
|
||||
TypeError: If dataset is not an instance of Dataset.
|
||||
ValueError: If there are no samples in one of the datasets.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets):
|
||||
|
@ -2324,15 +2325,19 @@ class ConcatDataset(DatasetOp):
|
|||
data.parent.append(self)
|
||||
|
||||
self.children_sizes_ = [c.get_dataset_size() for c in self.children]
|
||||
"""
|
||||
_children_flag_and_nums: A list of pairs <int, int>. The first element of each pair is a flag
that indicates whether the child dataset is mappable; the second element is its length.
|
||||
"""
|
||||
child_index = 0
|
||||
for item in self.children_sizes_:
|
||||
if item == 0:
|
||||
raise ValueError("There is no samples in the %dth dataset. Please make sure there are "
|
||||
"valid samples in the dataset" % child_index)
|
||||
child_index += 1
|
||||
|
||||
# _children_flag_and_nums: a list of pairs <int, int>. The first element of each pair is a flag
# indicating whether the child dataset is mappable; the second element is its length.
|
||||
self._children_flag_and_nums = []
|
||||
"""
|
||||
_children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize
|
||||
the valid position of the dataset corresponding to the subscript when sampling
|
||||
"""
|
||||
|
||||
# _children_start_end_index_: a list of pairs <int, int> marking the valid start and end sample
# positions of the child dataset at the same index when sampling.
|
||||
self._children_start_end_index_ = []
|
||||
for index, child in enumerate(self.children):
|
||||
tem_list = [-1, -1]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from io import BytesIO
|
||||
import copy
|
||||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -412,6 +413,46 @@ def test_Mindrecord_Padded(remove_mindrecord_file):
|
|||
result_list.append(tem_list)
|
||||
assert result_list == verify_list
|
||||
|
||||
def test_clue_padded_and_skip_with_0_samples():
    """
    Test concatenating onto a pipeline that has been skipped down to zero samples.

    Builds a CLUE/AFQMC dataset (3 samples), pads it with a one-sample
    PaddedDataset, shards it with a DistributedSampler (2 shards -> 2 samples),
    then skips both remaining samples so the pipeline yields nothing.
    Concatenating anything onto the zero-sample pipeline must raise ValueError
    ("There is no samples in the ..." from ConcatDataset).

    Raises:
        AssertionError: if any intermediate sample count is wrong.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
    count = 0
    # NOTE(review): assumes the test fixture train.json holds exactly 3 rows.
    for _ in data.create_dict_iterator():
        count += 1
    assert count == 3

    # Keep an untouched copy to concatenate against after the skip.
    data_copy1 = copy.deepcopy(data)

    # np.bytes_ is the NumPy-2.0-compatible spelling of the removed
    # np.string_ alias; on older NumPy the two names are the same type.
    sample = {"label": np.array(1, np.bytes_),
              "sentence1": np.array(1, np.bytes_),
              "sentence2": np.array(1, np.bytes_)}
    samples = [sample]
    padded_ds = ds.PaddedDataset(samples)
    dataset = data + padded_ds
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    dataset.use_sampler(testsampler)
    # 4 samples split across 2 shards -> 2 samples on this shard.
    assert dataset.get_dataset_size() == 2
    count = 0
    for _ in dataset.create_dict_iterator():  # `_`: avoid shadowing `data`
        count += 1
    assert count == 2

    dataset = dataset.skip(count=2)  # pipeline now yields no samples
    count = 0
    for _ in dataset.create_dict_iterator():
        count += 1
    assert count == 0

    # Concatenating onto a zero-sample dataset must be rejected; the message
    # text is pinned by ConcatDataset's ValueError.
    with pytest.raises(ValueError, match="There is no samples in the "):
        dataset = dataset.concat(data_copy1)
        count = 0
        for _ in dataset.create_dict_iterator():
            count += 1
        assert count == 2
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_TFRecord_Padded()
|
||||
|
|
Loading…
Reference in New Issue