forked from mindspore-Ecosystem/mindspore
!1457 fix 3 bug reports for split
Merge pull request !1457 from Peilin/splitOp-after-testing
This commit is contained in:
commit
4e8e82f24a
|
@ -71,7 +71,7 @@ if __name__ == '__main__':
|
|||
model = Model(network, loss, opt, {'acc': Accuracy()})
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
|
||||
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
|
||||
|
|
|
@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
}
|
||||
|
||||
Status RandomSampler::InitSampler() {
|
||||
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
|
||||
|
||||
rnd_.seed(seed_);
|
||||
|
||||
if (replacement_ == false) {
|
||||
num_samples_ = std::min(num_samples_, num_rows_);
|
||||
|
||||
shuffled_ids_.reserve(num_rows_);
|
||||
for (int64_t i = 0; i < num_rows_; i++) {
|
||||
shuffled_ids_.push_back(i);
|
||||
}
|
||||
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
|
||||
} else {
|
||||
num_samples_ = std::min(num_samples_, user_num_samples_);
|
||||
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -32,9 +32,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
|||
}
|
||||
|
||||
// Handshake and init child first.
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
|
||||
}
|
||||
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
|
|||
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
|
||||
|
||||
Status SubsetSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
|
||||
|
||||
num_samples_ = subset_size_;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
|
|||
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
|
||||
Schema, Shuffle, zip, RandomDataset
|
||||
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
||||
WeightedRandomSampler, Sampler
|
||||
WeightedRandomSampler, SubsetSampler, Sampler
|
||||
from .engine.serializer_deserializer import serialize, deserialize, show
|
||||
from .engine.graphdata import GraphData
|
||||
|
||||
|
|
|
@ -633,9 +633,9 @@ class Dataset:
|
|||
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
|
||||
of the original dataset. If after rounding, any size equals 0, an error will occur.
|
||||
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
|
||||
randomize (bool): determines whether or not to split the data randomly. If true, the data
|
||||
will be randomly split. Otherwise, each split will be created with consecutive rows
|
||||
from the dataset.
|
||||
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
|
||||
If true, the data will be randomly split. Otherwise, each split will be created with
|
||||
consecutive rows from the dataset.
|
||||
|
||||
Note:
|
||||
1. Dataset cannot be sharded if split is going to be called.
|
||||
|
@ -678,7 +678,8 @@ class Dataset:
|
|||
ds = copy.deepcopy(self)
|
||||
if randomize:
|
||||
# want to shuffle the same way every epoch before split
|
||||
ds = ds.shuffle()
|
||||
# in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
|
||||
ds = ds.shuffle(10000)
|
||||
ds.reshuffle_each_epoch = False
|
||||
|
||||
if rows_to_skip > 0:
|
||||
|
@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset):
|
|||
>>> new_sampler = ds.DistributedSampler(10, 2)
|
||||
>>> data.use_sampler(new_sampler)
|
||||
"""
|
||||
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
||||
raise TypeError("new_sampler is not an instance of a sampler.")
|
||||
|
||||
self.sampler = self.sampler.child_sampler
|
||||
self.add_sampler(new_sampler)
|
||||
|
||||
|
@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
|
|||
def is_sharded(self):
|
||||
raise NotImplementedError("MappableDataset must implement is_sharded.")
|
||||
|
||||
def _get_sampler_dataset_size(self):
|
||||
if self.sampler is not None:
|
||||
return self.sampler.get_dataset_size()
|
||||
|
||||
return None
|
||||
|
||||
@check_split
|
||||
def split(self, sizes, randomize=True):
|
||||
|
@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset):
|
|||
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
|
||||
of the original dataset. If after rounding, any size equals 0, an error will occur.
|
||||
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
|
||||
randomize (bool): determines whether or not to split the data randomly. If true, the data
|
||||
will be randomly split. Otherwise, each split will be created with consecutive rows
|
||||
from the dataset.
|
||||
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
|
||||
If true, the data will be randomly split. Otherwise, each split will be created with
|
||||
consecutive rows from the dataset.
|
||||
|
||||
Note:
|
||||
1. Dataset should not be sharded if split is going to be called. Instead, create a
|
||||
|
@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
|
|||
self.iterator = TupleIterator(self)
|
||||
|
||||
|
||||
|
||||
class RangeDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset that reads and parses datasets stored on disk in a range.
|
||||
|
@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset):
|
|||
else:
|
||||
num_samples = self.num_samples
|
||||
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
return get_num_rows(num_rows, self.num_shards)
|
||||
if rows_from_sampler is None:
|
||||
return rows_per_shard
|
||||
|
||||
return min(rows_from_sampler, rows_per_shard)
|
||||
|
||||
def num_classes(self):
|
||||
"""
|
||||
|
@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
|
|||
num_samples = self.num_samples
|
||||
|
||||
num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
return get_num_rows(num_rows, self.num_shards)
|
||||
if rows_from_sampler is None:
|
||||
return rows_per_shard
|
||||
|
||||
return min(rows_from_sampler, rows_per_shard)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
|
@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
return self._dataset_size
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
if rows_from_sampler is None:
|
||||
return self._dataset_size
|
||||
|
||||
return min(rows_from_sampler, self._dataset_size)
|
||||
|
||||
# manually set dataset_size as a temporary solution.
|
||||
def set_dataset_size(self, value):
|
||||
|
@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
|
|||
class_indexing = self.class_indexing
|
||||
|
||||
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
return get_num_rows(num_rows, self.num_shards)
|
||||
if rows_from_sampler is None:
|
||||
return rows_per_shard
|
||||
|
||||
return min(rows_from_sampler, rows_per_shard)
|
||||
|
||||
def num_classes(self):
|
||||
"""
|
||||
|
@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
|
|||
num_samples = self.num_samples
|
||||
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
return get_num_rows(num_rows, self.num_shards)
|
||||
if rows_from_sampler is None:
|
||||
return rows_per_shard
|
||||
|
||||
return min(rows_from_sampler, rows_per_shard)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
|
@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
|
|||
num_samples = self.num_samples
|
||||
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
return get_num_rows(num_rows, self.num_shards)
|
||||
if rows_from_sampler is None:
|
||||
return rows_per_shard
|
||||
|
||||
return min(rows_from_sampler, rows_per_shard)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
|
@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
return num_samples
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
if rows_from_sampler is None:
|
||||
return self.num_samples
|
||||
|
||||
return min(rows_from_sampler, self.num_samples)
|
||||
|
||||
def is_shuffled(self):
|
||||
return True
|
||||
|
@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
return self.num_samples
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
if rows_from_sampler is None:
|
||||
return self.num_samples
|
||||
|
||||
return min(rows_from_sampler, self.num_samples)
|
||||
|
||||
def get_class_indexing(self):
|
||||
"""
|
||||
|
|
|
@ -114,6 +114,9 @@ class Sampler:
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self._get_indices().size
|
||||
|
||||
|
||||
class BuiltinSampler:
|
||||
"""
|
||||
|
@ -146,6 +149,12 @@ class BuiltinSampler:
|
|||
def is_sharded(self):
|
||||
raise NotImplementedError("Sampler must implement is_sharded.")
|
||||
|
||||
def get_dataset_size(self):
|
||||
if self.child_sampler is not None:
|
||||
return self.child_sampler.get_dataset_size()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class DistributedSampler(BuiltinSampler):
|
||||
"""
|
||||
|
@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler):
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
class SequentialSampler(BuiltinSampler):
|
||||
"""
|
||||
|
@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler):
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self.subset_size
|
||||
|
||||
|
||||
class SubsetRandomSampler(BuiltinSampler):
|
||||
"""
|
||||
|
@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler):
|
|||
return cde.MindrecordSubsetRandomSampler(self.indices)
|
||||
|
||||
|
||||
def get_dataset_size(self):
|
||||
return len(indices)
|
||||
|
||||
|
||||
class WeightedRandomSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
|
||||
|
@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self.num_samples
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
@ -164,6 +165,35 @@ def test_python_sampler():
|
|||
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_subset_sampler():
|
||||
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
|
||||
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
|
||||
def test_config(num_samples, start_index, subset_size):
|
||||
sampler = ds.SubsetSampler(start_index, subset_size)
|
||||
d = ds.ManifestDataset(manifest_file, sampler=sampler)
|
||||
|
||||
res = []
|
||||
for item in d.create_dict_iterator():
|
||||
res.append(map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
return res
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
test_config(5, 0, 0)
|
||||
assert "subset_size <= 0" in str(info.value)
|
||||
|
||||
assert test_config(5, 0, 1) == [0]
|
||||
assert test_config(5, 0, 2) == [0, 1]
|
||||
assert test_config(5, 0, 3) == [0, 1, 2]
|
||||
assert test_config(5, 0, 4) == [0, 1, 2, 3]
|
||||
assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
|
||||
assert test_config(5, 1, 1) == [1]
|
||||
assert test_config(5, 2, 3) == [2, 3, 4]
|
||||
assert test_config(5, 3, 2) == [3, 4]
|
||||
assert test_config(5, 4, 1) == [4]
|
||||
|
||||
|
||||
def test_sampler_chain():
|
||||
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
|
||||
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
|
@ -190,10 +220,26 @@ def test_sampler_chain():
|
|||
assert test_config(5, 3) == [3]
|
||||
assert test_config(5, 4) == [4]
|
||||
|
||||
def test_add_sampler_invalid_input():
|
||||
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
|
||||
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
data1 = ds.ManifestDataset(manifest_file)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
data1.use_sampler(1)
|
||||
assert "not an instance of a sampler" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
data1.use_sampler("sampler")
|
||||
assert "not an instance of a sampler" in str(info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sequential_sampler(True)
|
||||
test_random_sampler(True)
|
||||
test_random_sampler_multi_iter(True)
|
||||
test_sampler_py_api()
|
||||
test_python_sampler()
|
||||
test_subset_sampler()
|
||||
test_sampler_chain()
|
||||
test_add_sampler_invalid_input()
|
||||
|
|
|
@ -23,7 +23,11 @@ from util import config_get_set_num_parallel_workers
|
|||
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
|
||||
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
|
||||
def split_with_invalid_inputs(d):
|
||||
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
|
||||
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
|
||||
"End of file.", "Good luck to everyone."]
|
||||
|
||||
def split_with_invalid_inputs(d):
|
||||
with pytest.raises(ValueError) as info:
|
||||
s1, s2 = d.split([])
|
||||
assert "sizes cannot be empty" in str(info.value)
|
||||
|
@ -68,8 +72,8 @@ def split_with_invalid_inputs(d):
|
|||
s1, s2 = d.split([0.05, 0.95])
|
||||
assert "percentage 0.05 is too small" in str(info.value)
|
||||
|
||||
|
||||
def test_unmappable_invalid_input():
|
||||
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
|
||||
d = ds.TextFileDataset(text_file_dataset_path)
|
||||
split_with_invalid_inputs(d)
|
||||
|
||||
|
@ -78,11 +82,10 @@ def test_unmappable_invalid_input():
|
|||
s1, s2 = d.split([4, 1])
|
||||
assert "dataset should not be sharded before split" in str(info.value)
|
||||
|
||||
|
||||
def test_unmappable_split():
|
||||
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
|
||||
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
|
||||
"End of file.", "Good luck to everyone."]
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
|
||||
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
|
||||
s1, s2 = d.split([4, 1], randomize=False)
|
||||
|
||||
|
@ -124,6 +127,142 @@ def test_unmappable_split():
|
|||
|
||||
assert s1_output == text_file_data[0:2]
|
||||
assert s2_output == text_file_data[2:]
|
||||
|
||||
# Restore configuration num_parallel_workers
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_unmappable_randomize_deterministic():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
|
||||
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
|
||||
s1, s2 = d.split([0.8, 0.2])
|
||||
|
||||
for _ in range(10):
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
# note no overlap
|
||||
assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
|
||||
assert s2_output == [text_file_data[3]]
|
||||
|
||||
# Restore configuration num_parallel_workers
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_unmappable_randomize_repeatable():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
|
||||
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
|
||||
s1, s2 = d.split([0.8, 0.2])
|
||||
|
||||
num_epochs = 5
|
||||
s1 = s1.repeat(num_epochs)
|
||||
s2 = s2.repeat(num_epochs)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
# note no overlap
|
||||
assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs
|
||||
assert s2_output == [text_file_data[3]] * num_epochs
|
||||
|
||||
# Restore configuration num_parallel_workers
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_unmappable_get_dataset_size():
|
||||
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
|
||||
s1, s2 = d.split([0.8, 0.2])
|
||||
|
||||
assert d.get_dataset_size() == 5
|
||||
assert s1.get_dataset_size() == 4
|
||||
assert s2.get_dataset_size() == 1
|
||||
|
||||
|
||||
def test_unmappable_multi_split():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
|
||||
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
|
||||
s1, s2 = d.split([4, 1])
|
||||
|
||||
s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(item["text"].item().decode("utf8"))
|
||||
assert s1_output == s1_correct_output
|
||||
|
||||
# no randomize in second split
|
||||
s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)
|
||||
|
||||
s1s1_output = []
|
||||
for item in s1s1.create_dict_iterator():
|
||||
s1s1_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s1s2_output = []
|
||||
for item in s1s2.create_dict_iterator():
|
||||
s1s2_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s1s3_output = []
|
||||
for item in s1s3.create_dict_iterator():
|
||||
s1s3_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
assert s1s1_output == [s1_correct_output[0]]
|
||||
assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
|
||||
assert s1s3_output == [s1_correct_output[3]]
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(item["text"].item().decode("utf8"))
|
||||
assert s2_output == [text_file_data[3]]
|
||||
|
||||
# randomize in second split
|
||||
# the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0]
|
||||
shuffled_ids = [2, 3, 1, 0]
|
||||
|
||||
s1s1, s1s2, s1s3 = s1.split([1, 2, 1])
|
||||
|
||||
s1s1_output = []
|
||||
for item in s1s1.create_dict_iterator():
|
||||
s1s1_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s1s2_output = []
|
||||
for item in s1s2.create_dict_iterator():
|
||||
s1s2_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s1s3_output = []
|
||||
for item in s1s3.create_dict_iterator():
|
||||
s1s3_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
assert s1s1_output == [s1_correct_output[shuffled_ids[0]]]
|
||||
assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]]
|
||||
assert s1s3_output == [s1_correct_output[shuffled_ids[3]]]
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(item["text"].item().decode("utf8"))
|
||||
assert s2_output == [text_file_data[3]]
|
||||
|
||||
# Restore configuration num_parallel_workers
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
@ -137,6 +276,7 @@ def test_mappable_invalid_input():
|
|||
s1, s2 = d.split([4, 1])
|
||||
assert "dataset should not be sharded before split" in str(info.value)
|
||||
|
||||
|
||||
def test_mappable_split_general():
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
d = d.take(5)
|
||||
|
@ -183,6 +323,7 @@ def test_mappable_split_general():
|
|||
assert s1_output == [0, 1]
|
||||
assert s2_output == [2, 3, 4]
|
||||
|
||||
|
||||
def test_mappable_split_optimized():
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
|
||||
|
@ -228,9 +369,9 @@ def test_mappable_split_optimized():
|
|||
assert s1_output == [0, 1]
|
||||
assert s2_output == [2, 3, 4]
|
||||
|
||||
|
||||
def test_mappable_randomize_deterministic():
|
||||
# set arbitrary seed for shard after split
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
|
@ -249,9 +390,9 @@ def test_mappable_randomize_deterministic():
|
|||
assert s1_output == [0, 1, 3, 4]
|
||||
assert s2_output == [2]
|
||||
|
||||
|
||||
def test_mappable_randomize_repeatable():
|
||||
# set arbitrary seed for shard after split
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
|
@ -273,9 +414,10 @@ def test_mappable_randomize_repeatable():
|
|||
assert s1_output == [0, 1, 3, 4] * num_epochs
|
||||
assert s2_output == [2] * num_epochs
|
||||
|
||||
|
||||
def test_mappable_sharding():
|
||||
# set arbitrary seed for repeatability for shard after split
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
num_epochs = 5
|
||||
|
@ -336,12 +478,94 @@ def test_mappable_sharding():
|
|||
assert s2_output == [2]
|
||||
assert d2s2_output == [2]
|
||||
|
||||
|
||||
def test_mappable_get_dataset_size():
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
s1, s2 = d.split([4, 1])
|
||||
|
||||
assert d.get_dataset_size() == 5
|
||||
assert s1.get_dataset_size() == 4
|
||||
assert s2.get_dataset_size() == 1
|
||||
|
||||
|
||||
def test_mappable_multi_split():
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
s1, s2 = d.split([4, 1])
|
||||
|
||||
s1_correct_output = [0, 1, 3, 4]
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
assert s1_output == s1_correct_output
|
||||
|
||||
# no randomize in second split
|
||||
s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)
|
||||
|
||||
s1s1_output = []
|
||||
for item in s1s1.create_dict_iterator():
|
||||
s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s1s2_output = []
|
||||
for item in s1s2.create_dict_iterator():
|
||||
s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s1s3_output = []
|
||||
for item in s1s3.create_dict_iterator():
|
||||
s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s1s1_output == [s1_correct_output[0]]
|
||||
assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
|
||||
assert s1s3_output == [s1_correct_output[3]]
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
assert s2_output == [2]
|
||||
|
||||
# randomize in second split
|
||||
# the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0]
|
||||
random_sampler_ids = [3, 1, 2, 0]
|
||||
|
||||
s1s1, s1s2, s1s3 = s1.split([1, 2, 1])
|
||||
|
||||
s1s1_output = []
|
||||
for item in s1s1.create_dict_iterator():
|
||||
s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s1s2_output = []
|
||||
for item in s1s2.create_dict_iterator():
|
||||
s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s1s3_output = []
|
||||
for item in s1s3.create_dict_iterator():
|
||||
s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]]
|
||||
assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]]
|
||||
assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]]
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
assert s2_output == [2]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_unmappable_invalid_input()
|
||||
test_unmappable_split()
|
||||
test_unmappable_randomize_deterministic()
|
||||
test_unmappable_randomize_repeatable()
|
||||
test_unmappable_get_dataset_size()
|
||||
test_unmappable_multi_split()
|
||||
test_mappable_invalid_input()
|
||||
test_mappable_split_general()
|
||||
test_mappable_split_optimized()
|
||||
test_mappable_randomize_deterministic()
|
||||
test_mappable_randomize_repeatable()
|
||||
test_mappable_sharding()
|
||||
test_mappable_get_dataset_size()
|
||||
test_mappable_multi_split()
|
||||
|
|
Loading…
Reference in New Issue