!1638 fixed rounding edge case in split

Merge pull request !1638 from Peilin/splitOp-after-testing
This commit is contained in:
mindspore-ci-bot 2020-05-30 05:31:33 +08:00 committed by Gitee
commit 7878743400
2 changed files with 71 additions and 10 deletions

View File

@ -609,7 +609,20 @@ class Dataset:
absolute_sizes.append(absolute_size) absolute_sizes.append(absolute_size)
absolute_sizes_sum = sum(absolute_sizes) absolute_sizes_sum = sum(absolute_sizes)
if absolute_sizes_sum != dataset_size:
# if we still need more rows, give them to the first split.
# if we have too many rows, remove the extras from the first split that has
# enough rows.
size_difference = dataset_size - absolute_sizes_sum
if size_difference > 0:
absolute_sizes[0] += size_difference
else:
for i, _ in enumerate(absolute_sizes):
if absolute_sizes[i] + size_difference > 0:
absolute_sizes[i] += size_difference
break
if sum(absolute_sizes) != dataset_size:
raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}." raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
.format(absolute_sizes_sum, dataset_size)) .format(absolute_sizes_sum, dataset_size))
@ -629,10 +642,15 @@ class Dataset:
provided, the dataset will be split into n datasets of size s1, size s2, , size sn provided, the dataset will be split into n datasets of size s1, size s2, , size sn
respectively. If the sum of all sizes does not equal the original dataset size, an respectively. If the sum of all sizes does not equal the original dataset size, an
an error will occur. an error will occur.
If a list of floats [f1, f2, , fn] is provided, the dataset will be split into n If a list of floats [f1, f2, , fn] is provided, all floats must be between 0 and 1
Datasets of size f1*K, f2*K, , fn*K (rounded to nearest integer) where K is the size and must sum to 1, otherwise an error will occur. The dataset will be split into n
of the original dataset. If after rounding, any size equals 0, an error will occur. Datasets of size round(f1*K), round(f2*K), , round(fn*K) where K is the size of the
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. original dataset.
If after rounding:
-Any size equals 0, an error will occur.
-The sum of split sizes < K, the difference will be added to the first split.
-The sum of split sizes > K, the difference will be removed from the first large
enough split such that it will have at least 1 row after removing the difference.
randomize (bool, optional): determines whether or not to split the data randomly (default=True). randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset. consecutive rows from the dataset.
@ -1224,7 +1242,7 @@ class MappableDataset(SourceDataset):
>>> data.use_sampler(new_sampler) >>> data.use_sampler(new_sampler)
""" """
if new_sampler is None: if new_sampler is None:
raise TypeError("Input sampler could not be None.") raise TypeError("Input sampler can not be None.")
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("Input sampler is not an instance of a sampler.") raise TypeError("Input sampler is not an instance of a sampler.")
@ -1259,10 +1277,15 @@ class MappableDataset(SourceDataset):
provided, the dataset will be split into n datasets of size s1, size s2, , size sn provided, the dataset will be split into n datasets of size s1, size s2, , size sn
respectively. If the sum of all sizes does not equal the original dataset size, an respectively. If the sum of all sizes does not equal the original dataset size, an
an error will occur. an error will occur.
If a list of floats [f1, f2, , fn] is provided, the dataset will be split into n If a list of floats [f1, f2, , fn] is provided, all floats must be between 0 and 1
Datasets of size f1*K, f2*K, , fn*K (rounded to nearest integer) where K is the size and must sum to 1, otherwise an error will occur. The dataset will be split into n
of the original dataset. If after rounding, any size equals 0, an error will occur. Datasets of size round(f1*K), round(f2*K), , round(fn*K) where K is the size of the
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. original dataset.
If after rounding:
-Any size equals 0, an error will occur.
-The sum of split sizes < K, the difference will be added to the first split.
-The sum of split sizes > K, the difference will be removed from the first large
enough split such that it will have at least 1 row after removing the difference.
randomize (bool, optional): determines whether or not to split the data randomly (default=True). randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset. consecutive rows from the dataset.

View File

@ -554,6 +554,43 @@ def test_mappable_multi_split():
assert s2_output == [2] assert s2_output == [2]
def test_rounding():
    """Verify split() handles fractional sizes whose rounded sum mismatches
    the dataset size (the manifest dataset has 5 rows).

    Rows are identified by looking up (image byte-length, label) in
    manifest_map, which maps back to the row's original index.
    """
    data = ds.ManifestDataset(manifest_file, shuffle=False)

    def collect(split):
        # Gather the original row indices produced by one split, in order.
        rows = []
        for item in split.create_dict_iterator():
            rows.append(manifest_map[(item["image"].shape[0], item["label"].item())])
        return rows

    # Under-rounding: round(0.5*5) + round(0.5*5) = 2 + 2 < 5, so the
    # leftover row is added to the first split (sizes become 3 and 2).
    first, second = data.split([0.5, 0.5], randomize=False)
    assert collect(first) == [0, 1, 2]
    assert collect(second) == [3, 4]

    # Over-rounding: round(0.15*5) + round(0.55*5) + round(0.3*5) = 1 + 3 + 2 > 5,
    # so one row is removed from the first split that can spare it while keeping
    # at least one row; the first split (size 1) cannot, so the second split
    # shrinks from 3 to 2 (sizes become 1, 2, 2).
    first, second, third = data.split([0.15, 0.55, 0.3], randomize=False)
    assert collect(first) == [0]
    assert collect(second) == [1, 2]
    assert collect(third) == [3, 4]
if __name__ == '__main__': if __name__ == '__main__':
test_unmappable_invalid_input() test_unmappable_invalid_input()
test_unmappable_split() test_unmappable_split()
@ -569,3 +606,4 @@ if __name__ == '__main__':
test_mappable_sharding() test_mappable_sharding()
test_mappable_get_dataset_size() test_mappable_get_dataset_size()
test_mappable_multi_split() test_mappable_multi_split()
test_rounding()