diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index e3a76fe72b7..e9ea14fd82f 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -609,7 +609,20 @@ class Dataset: absolute_sizes.append(absolute_size) absolute_sizes_sum = sum(absolute_sizes) - if absolute_sizes_sum != dataset_size: + + # if we still need more rows, give them to the first split. + # if we have too many rows, remove the extras from the first split that has + # enough rows. + size_difference = dataset_size - absolute_sizes_sum + if size_difference > 0: + absolute_sizes[0] += size_difference + else: + for i, _ in enumerate(absolute_sizes): + if absolute_sizes[i] + size_difference > 0: + absolute_sizes[i] += size_difference + break + + if sum(absolute_sizes) != dataset_size: raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}." .format(absolute_sizes_sum, dataset_size)) @@ -629,10 +642,15 @@ class Dataset: provided, the dataset will be split into n datasets of size s1, size s2, …, size sn respectively. If the sum of all sizes does not equal the original dataset size, an an error will occur. - If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n - Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size - of the original dataset. If after rounding, any size equals 0, an error will occur. - All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. + If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 + and must sum to 1, otherwise an error will occur. The dataset will be split into n + Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the + original dataset. + If after rounding: + -Any size equals 0, an error will occur. + -The sum of split sizes < K, the difference will be added to the first split. 
+ -The sum of split sizes > K, the difference will be removed from the first large + enough split such that it will have at least 1 row after removing the difference. randomize (bool, optional): determines whether or not to split the data randomly (default=True). If true, the data will be randomly split. Otherwise, each split will be created with consecutive rows from the dataset. @@ -1224,7 +1242,7 @@ class MappableDataset(SourceDataset): >>> data.use_sampler(new_sampler) """ if new_sampler is None: - raise TypeError("Input sampler could not be None.") + raise TypeError("Input sampler can not be None.") if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): raise TypeError("Input sampler is not an instance of a sampler.") @@ -1259,10 +1277,15 @@ class MappableDataset(SourceDataset): provided, the dataset will be split into n datasets of size s1, size s2, …, size sn respectively. If the sum of all sizes does not equal the original dataset size, an an error will occur. - If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n - Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size - of the original dataset. If after rounding, any size equals 0, an error will occur. - All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. + If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 + and must sum to 1, otherwise an error will occur. The dataset will be split into n + Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the + original dataset. + If after rounding: + -Any size equals 0, an error will occur. + -The sum of split sizes < K, the difference will be added to the first split. + -The sum of split sizes > K, the difference will be removed from the first large + enough split such that it will have at least 1 row after removing the difference. 
randomize (bool, optional): determines whether or not to split the data randomly (default=True). If true, the data will be randomly split. Otherwise, each split will be created with consecutive rows from the dataset. diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py index da772ed581c..b904e2e0168 100644 --- a/tests/ut/python/dataset/test_split.py +++ b/tests/ut/python/dataset/test_split.py @@ -554,6 +554,43 @@ def test_mappable_multi_split(): assert s2_output == [2] +def test_rounding(): + d = ds.ManifestDataset(manifest_file, shuffle=False) + + # under rounding + s1, s2 = d.split([0.5, 0.5], randomize=False) + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + assert s1_output == [0, 1, 2] + assert s2_output == [3, 4] + + # over rounding + s1, s2, s3 = d.split([0.15, 0.55, 0.3], randomize=False) + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s3_output = [] + for item in s3.create_dict_iterator(): + s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + assert s1_output == [0] + assert s2_output == [1, 2] + assert s3_output == [3, 4] + + if __name__ == '__main__': test_unmappable_invalid_input() test_unmappable_split() @@ -569,3 +606,4 @@ if __name__ == '__main__': test_mappable_sharding() test_mappable_get_dataset_size() test_mappable_multi_split() + test_rounding()