forked from mindspore-Ecosystem/mindspore
!1638 fixed rounding edge case in split
Merge pull request !1638 from Peilin/splitOp-after-testing
This commit is contained in: commit 7878743400
@@ -609,7 +609,20 @@ class Dataset:
             absolute_sizes.append(absolute_size)

         absolute_sizes_sum = sum(absolute_sizes)
-        if absolute_sizes_sum != dataset_size:
+
+        # if we still need more rows, give them to the first split.
+        # if we have too many rows, remove the extras from the first split that has
+        # enough rows.
+        size_difference = dataset_size - absolute_sizes_sum
+        if size_difference > 0:
+            absolute_sizes[0] += size_difference
+        else:
+            for i, _ in enumerate(absolute_sizes):
+                if absolute_sizes[i] + size_difference > 0:
+                    absolute_sizes[i] += size_difference
+                    break
+
+        if sum(absolute_sizes) != dataset_size:
             raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
                                .format(absolute_sizes_sum, dataset_size))

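To see the new branch in isolation: a minimal sketch of the redistribution step (the standalone helper name `_adjust_rounded_sizes` is hypothetical; the body mirrors the added lines):

```python
def _adjust_rounded_sizes(absolute_sizes, dataset_size):
    # rounding left the splits short or long; reconcile them with the dataset size
    size_difference = dataset_size - sum(absolute_sizes)
    if size_difference > 0:
        # too few rows: pad the first split
        absolute_sizes[0] += size_difference
    else:
        # too many rows: shrink the first split that stays non-empty
        for i, _ in enumerate(absolute_sizes):
            if absolute_sizes[i] + size_difference > 0:
                absolute_sizes[i] += size_difference
                break
    return absolute_sizes

print(_adjust_rounded_sizes([2, 2], 5))     # under rounding: [3, 2]
print(_adjust_rounded_sizes([1, 3, 2], 5))  # over rounding:  [1, 2, 2]
```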
@@ -629,10 +642,15 @@ class Dataset:
                 provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
                 respectively. If the sum of all sizes does not equal the original dataset size, an
                 error will occur.
-                If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
-                datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
-                of the original dataset. If after rounding, any size equals 0, an error will occur.
-                All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
+                If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
+                and must sum to 1, otherwise an error will occur. The dataset will be split into n
+                datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
+                original dataset.
+                If after rounding:
+                    - Any size equals 0, an error will occur.
+                    - The sum of split sizes < K, the difference will be added to the first split.
+                    - The sum of split sizes > K, the difference will be removed from the first large
+                      enough split such that it will have at least 1 row after removing the difference.
             randomize (bool, optional): determines whether or not to split the data randomly (default=True).
                 If true, the data will be randomly split. Otherwise, each split will be created with
                 consecutive rows from the dataset.
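In use, the documented float behaviour looks like this (a sketch; the manifest path and the 80/20 ratio are illustrative, not taken from the patch):

```python
import mindspore.dataset as ds

manifest_file = "data.manifest"  # hypothetical path, for illustration

data = ds.ManifestDataset(manifest_file, shuffle=False)

# fractional sizes: round(0.8*K) and round(0.2*K) rows of the K-row dataset,
# with any rounding difference reconciled as described above
train, test = data.split([0.8, 0.2], randomize=False)
```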
@@ -1224,7 +1242,7 @@ class MappableDataset(SourceDataset):
             >>> data.use_sampler(new_sampler)
         """
         if new_sampler is None:
-            raise TypeError("Input sampler could not be None.")
+            raise TypeError("Input sampler can not be None.")
         if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
            raise TypeError("Input sampler is not an instance of a sampler.")

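The guarded call in use (a sketch; `SequentialSampler` stands in for any built-in sampler, and the dataset is illustrative):

```python
import mindspore.dataset as ds

data = ds.ManifestDataset("data.manifest", shuffle=False)  # hypothetical path

data.use_sampler(ds.SequentialSampler())  # accepted: a BuiltinSampler instance
data.use_sampler(None)                    # raises TypeError: Input sampler can not be None.
```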
@@ -1259,10 +1277,15 @@ class MappableDataset(SourceDataset):
                 provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
                 respectively. If the sum of all sizes does not equal the original dataset size, an
                 error will occur.
-                If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
-                datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
-                of the original dataset. If after rounding, any size equals 0, an error will occur.
-                All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
+                If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
+                and must sum to 1, otherwise an error will occur. The dataset will be split into n
+                datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
+                original dataset.
+                If after rounding:
+                    - Any size equals 0, an error will occur.
+                    - The sum of split sizes < K, the difference will be added to the first split.
+                    - The sum of split sizes > K, the difference will be removed from the first large
+                      enough split such that it will have at least 1 row after removing the difference.
             randomize (bool, optional): determines whether or not to split the data randomly (default=True).
                 If true, the data will be randomly split. Otherwise, each split will be created with
                 consecutive rows from the dataset.
@@ -554,6 +554,43 @@ def test_mappable_multi_split():
     assert s2_output == [2]


+def test_rounding():
+    d = ds.ManifestDataset(manifest_file, shuffle=False)
+
+    # under rounding
+    s1, s2 = d.split([0.5, 0.5], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1_output == [0, 1, 2]
+    assert s2_output == [3, 4]
+
+    # over rounding
+    s1, s2, s3 = d.split([0.15, 0.55, 0.3], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s3_output = []
+    for item in s3.create_dict_iterator():
+        s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1_output == [0]
+    assert s2_output == [1, 2]
+    assert s3_output == [3, 4]
+
+
 if __name__ == '__main__':
     test_unmappable_invalid_input()
     test_unmappable_split()
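The expected outputs follow from Python 3's half-to-even rounding plus the new redistribution step; spelled out for the 5-row manifest (a sketch of the arithmetic, not code from the patch):

```python
K = 5  # rows in the test manifest

# under rounding: [0.5, 0.5]
sizes = [round(0.5 * K), round(0.5 * K)]  # round(2.5) == 2 (ties go to even)
assert sizes == [2, 2]
# sum is 4 < K: the first split absorbs the missing row -> [3, 2]

# over rounding: [0.15, 0.55, 0.3]
sizes = [round(0.15 * K), round(0.55 * K), round(0.3 * K)]
assert sizes == [1, 3, 2]
# sum is 6 > K: the first split that stays non-empty after giving up
# the extra row is the second one -> [1, 2, 2]
```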
@@ -569,3 +606,4 @@ if __name__ == '__main__':
     test_mappable_sharding()
     test_mappable_get_dataset_size()
     test_mappable_multi_split()
+    test_rounding()