From 5469be2a97f10406660dd9521cf474d17263542d Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Mon, 25 May 2020 17:19:08 -0400 Subject: [PATCH] fixed bug for split, RandomSampler and some other cleanup add another test case typo merge conflict another PR changed testing behavior, updated test cases in this commit added input check for use_sampler addressed code review comments fixed pylint, not related to my changes fixed edge case of rounding in getting split sizes fix pylint --- mindspore/dataset/engine/datasets.py | 43 ++++++++++++++++++++------- tests/ut/python/dataset/test_split.py | 38 +++++++++++++++++++++++ 2 files changed, 71 insertions(+), 10 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 580e32c4021..3d76d775a02 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -609,7 +609,20 @@ class Dataset: absolute_sizes.append(absolute_size) absolute_sizes_sum = sum(absolute_sizes) - if absolute_sizes_sum != dataset_size: + + # if we still need more rows, give them to the first split. + # if we have too many rows, remove the extras from the first split that has + # enough rows. + size_difference = dataset_size - absolute_sizes_sum + if size_difference > 0: + absolute_sizes[0] += size_difference + else: + for i, _ in enumerate(absolute_sizes): + if absolute_sizes[i] + size_difference > 0: + absolute_sizes[i] += size_difference + break + + if sum(absolute_sizes) != dataset_size: raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}." .format(absolute_sizes_sum, dataset_size)) @@ -629,10 +642,15 @@ class Dataset: provided, the dataset will be split into n datasets of size s1, size s2, …, size sn respectively. If the sum of all sizes does not equal the original dataset size, an an error will occur. - If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n - Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size - of the original dataset. If after rounding, any size equals 0, an error will occur. - All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. + If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 + and must sum to 1, otherwise an error will occur. The dataset will be split into n + Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the + original dataset. + If after rounding: + -Any size equals 0, an error will occur. + -The sum of split sizes < K, the difference will be added to the first split. + -The sum of split sizes > K, the difference will be removed from the first large + enough split such that it will have atleast 1 row after removing the difference. randomize (bool, optional): determines whether or not to split the data randomly (default=True). If true, the data will be randomly split. Otherwise, each split will be created with consecutive rows from the dataset. @@ -1212,7 +1230,7 @@ class MappableDataset(SourceDataset): >>> data.use_sampler(new_sampler) """ if new_sampler is None: - raise TypeError("Input sampler could not be None.") + raise TypeError("Input sampler can not be None.") if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): raise TypeError("Input sampler is not an instance of a sampler.") @@ -1247,10 +1265,15 @@ class MappableDataset(SourceDataset): provided, the dataset will be split into n datasets of size s1, size s2, …, size sn respectively. If the sum of all sizes does not equal the original dataset size, an an error will occur. - If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n - Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size - of the original dataset. If after rounding, any size equals 0, an error will occur. - All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. + If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1 + and must sum to 1, otherwise an error will occur. The dataset will be split into n + Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the + original dataset. + If after rounding: + -Any size equals 0, an error will occur. + -The sum of split sizes < K, the difference will be added to the first split. + -The sum of split sizes > K, the difference will be removed from the first large + enough split such that it will have atleast 1 row after removing the difference. randomize (bool, optional): determines whether or not to split the data randomly (default=True). If true, the data will be randomly split. Otherwise, each split will be created with consecutive rows from the dataset. diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py index da772ed581c..b904e2e0168 100644 --- a/tests/ut/python/dataset/test_split.py +++ b/tests/ut/python/dataset/test_split.py @@ -554,6 +554,43 @@ def test_mappable_multi_split(): assert s2_output == [2] +def test_rounding(): + d = ds.ManifestDataset(manifest_file, shuffle=False) + + # under rounding + s1, s2 = d.split([0.5, 0.5], randomize=False) + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + assert s1_output == [0, 1, 2] + assert s2_output == [3, 4] + + # over rounding + s1, s2, s3 = d.split([0.15, 0.55, 0.3], randomize=False) + + s1_output = [] + for item in s1.create_dict_iterator(): + s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s2_output = [] + for item in s2.create_dict_iterator(): + s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + s3_output = [] + for item in s3.create_dict_iterator(): + s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())]) + + assert s1_output == [0] + assert s2_output == [1, 2] + assert s3_output == [3, 4] + + if __name__ == '__main__': test_unmappable_invalid_input() test_unmappable_split() @@ -569,3 +606,4 @@ if __name__ == '__main__': test_mappable_sharding() test_mappable_get_dataset_size() test_mappable_multi_split() + test_rounding()