forked from mindspore-Ecosystem/mindspore
Fixed bugs in split and RandomSampler, plus some other cleanup
Cleanup dataset UT: restore config support
This commit is contained in:
parent 6be8929f62
commit 9228384304

@@ -71,7 +71,7 @@ if __name__ == '__main__':
model = Model(network, loss, opt, {'acc': Accuracy()})

print("============== Starting Training ==============")
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
@@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}

Status RandomSampler::InitSampler() {
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");

rnd_.seed(seed_);

if (replacement_ == false) {
num_samples_ = std::min(num_samples_, num_rows_);

shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {
shuffled_ids_.push_back(i);
}
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
} else {
num_samples_ = std::min(num_samples_, user_num_samples_);
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
}

CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;

return Status::OK();
}
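The reworked InitSampler clamps the sample count differently for the two modes: without replacement it is capped at the number of rows and drawn from a pre-shuffled id list, while with replacement it is capped at the user-requested count and drawn from a uniform distribution. A minimal Python sketch of that selection logic, for illustration only (this is not the C++ implementation, and the function name is made up):

import random

def random_sample_ids(num_rows, num_samples, replacement=False, seed=0):
    # Illustrative sketch of the clamping done in RandomSampler::InitSampler.
    rnd = random.Random(seed)
    if not replacement:
        num_samples = min(num_samples, num_rows)   # never more ids than rows exist
        ids = list(range(num_rows))
        rnd.shuffle(ids)                           # pre-shuffled ids, like shuffled_ids_
        return ids[:num_samples]
    # with replacement: each id drawn independently and uniformly, duplicates allowed
    return [rnd.randrange(num_rows) for _ in range(num_samples)]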
@@ -32,9 +32,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
}

// Handshake and init child first.
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}

CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}

Status SubsetSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");

num_samples_ = subset_size_;

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
WeightedRandomSampler, SubsetSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
from .engine.graphdata import GraphData
@@ -633,9 +633,9 @@ class Dataset:
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.

Note:
1. Dataset cannot be sharded if split is going to be called.
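The sizes documented above accept either absolute row counts or fractions; a short usage sketch consistent with the tests later in this commit (the wildcard path is the one those tests use):

import mindspore.dataset as ds

data = ds.TextFileDataset("../data/dataset/testTextFileDataset/*", shuffle=False)
train, test = data.split([0.8, 0.2])               # fractions must sum to 1
head, tail = data.split([4, 1], randomize=False)   # absolute row counts, consecutive rows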
@@ -678,7 +678,8 @@ class Dataset:
ds = copy.deepcopy(self)
if randomize:
# want to shuffle the same way every epoch before split
ds = ds.shuffle()
# in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
ds = ds.shuffle(10000)
ds.reshuffle_each_epoch = False

if rows_to_skip > 0:

@@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset):
>>> new_sampler = ds.DistributedSampler(10, 2)
>>> data.use_sampler(new_sampler)
"""
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("new_sampler is not an instance of a sampler.")

self.sampler = self.sampler.child_sampler
self.add_sampler(new_sampler)

@@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
def is_sharded(self):
raise NotImplementedError("MappableDataset must implement is_sharded.")

def _get_sampler_dataset_size(self):
if self.sampler is not None:
return self.sampler.get_dataset_size()

return None

@check_split
def split(self, sizes, randomize=True):
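The new _get_sampler_dataset_size helper is what the per-dataset get_dataset_size overrides later in this diff consult: each of them takes the minimum of the rows available per shard and whatever the attached sampler reports. A small illustration of the observable effect, assuming the manifest file used by the tests in this commit:

import mindspore.dataset as ds

manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"   # 5 rows
sampler = ds.SubsetSampler(1, 3)                                         # yields 3 rows
data = ds.ManifestDataset(manifest_file, sampler=sampler)
# get_dataset_size is now capped by the sampler: min(3, 5) -> 3
print(data.get_dataset_size())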
@@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset):
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.

Note:
1. Dataset should not be sharded if split is going to be called. Instead, create a
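The note above recommends splitting first and sharding afterwards. A condensed sketch of that split-then-shard order, with the seed and path taken from the tests in this commit and the sampler arguments following the use_sampler docstring example:

import mindspore.dataset as ds

ds.config.set_seed(53)
d = ds.ManifestDataset("../data/dataset/testManifestData/test5trainimgs.json", shuffle=False)
s1, s2 = d.split([4, 1])                      # split the unsharded dataset first
s1.use_sampler(ds.DistributedSampler(2, 0))   # then shard the split: 2 shards, shard id 0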
@@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
self.iterator = TupleIterator(self)



class RangeDataset(MappableDataset):
"""
A source dataset that reads and parses datasets stored on disk in a range.

@@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset):
else:
num_samples = self.num_samples
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()

return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard

return min(rows_from_sampler, rows_per_shard)

def num_classes(self):
"""

@@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
num_samples = self.num_samples

num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()

return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard

return min(rows_from_sampler, rows_per_shard)

def is_shuffled(self):
if self.shuffle_level is None:

@@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset):
Return:
Number, number of batches.
"""
return self._dataset_size
rows_from_sampler = self._get_sampler_dataset_size()

if rows_from_sampler is None:
return self._dataset_size

return min(rows_from_sampler, self._dataset_size)

# manually set dataset_size as a temporary solution.
def set_dataset_size(self, value):

@@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
class_indexing = self.class_indexing

num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()

return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard

return min(rows_from_sampler, rows_per_shard)

def num_classes(self):
"""

@@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
num_samples = self.num_samples

num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()

return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard

return min(rows_from_sampler, rows_per_shard)

def is_shuffled(self):
if self.shuffle_level is None:

@@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
num_samples = self.num_samples

num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()

return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard

return min(rows_from_sampler, rows_per_shard)

def is_shuffled(self):
if self.shuffle_level is None:

@@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
Return:
Number, number of batches.
"""
return num_samples
rows_from_sampler = self._get_sampler_dataset_size()

if rows_from_sampler is None:
return self.num_samples

return min(rows_from_sampler, self.num_samples)

def is_shuffled(self):
return True

@@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset):
Return:
Number, number of batches.
"""
return self.num_samples
rows_from_sampler = self._get_sampler_dataset_size()

if rows_from_sampler is None:
return self.num_samples

return min(rows_from_sampler, self.num_samples)

def get_class_indexing(self):
"""

@@ -114,6 +114,9 @@ class Sampler:

return self.child_sampler.is_sharded()

def get_dataset_size(self):
return self._get_indices().size


class BuiltinSampler:
"""
@@ -146,6 +149,12 @@ class BuiltinSampler:
def is_sharded(self):
raise NotImplementedError("Sampler must implement is_sharded.")

def get_dataset_size(self):
if self.child_sampler is not None:
return self.child_sampler.get_dataset_size()

return None


class DistributedSampler(BuiltinSampler):
"""
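The get_dataset_size added here gives every built-in sampler a way to report how many rows it will yield: the base class defers to a chained child sampler and otherwise returns None, while the concrete samplers below return their own counts. A small illustration, assuming the module-level SubsetSampler export added earlier in this diff:

import mindspore.dataset as ds

# A concrete sampler reports its own count; the base class would defer to a
# chained child sampler and return None when there is neither.
print(ds.SubsetSampler(0, 3).get_dataset_size())   # 3: its subset_size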
@@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler):

return self.child_sampler.is_sharded()

def get_dataset_size(self):
return self.num_samples


class SequentialSampler(BuiltinSampler):
"""

@@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler):

return self.child_sampler.is_sharded()

def get_dataset_size(self):
return self.subset_size


class SubsetRandomSampler(BuiltinSampler):
"""

@@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler):
return cde.MindrecordSubsetRandomSampler(self.indices)


def get_dataset_size(self):
return len(self.indices)


class WeightedRandomSampler(BuiltinSampler):
"""
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).

@@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler):
return False

return self.child_sampler.is_sharded()

def get_dataset_size(self):
return self.num_samples
@@ -30,6 +30,14 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_basic():
"""
Test basic configuration functions
"""
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
prefetch_size_original = ds.config.get_prefetch_size()
seed_original = ds.config.get_seed()

ds.config.load('../data/dataset/declient.cfg')

# assert ds.config.get_rows_per_buffer() == 32

@@ -50,6 +58,11 @@ def test_basic():
assert ds.config.get_prefetch_size() == 4
assert ds.config.get_seed() == 5

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_prefetch_size(prefetch_size_original)
ds.config.set_seed(seed_original)


def test_get_seed():
"""

@@ -62,6 +75,9 @@ def test_pipeline():
"""
Test that our configuration pipeline works when we set parameters at different locations in dataset code
"""
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()

data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(2)
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])

@@ -85,6 +101,9 @@ def test_pipeline():
except IOError:
logger.info("Error while deleting: {}".format(f))

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)


def test_deterministic_run_fail():
"""

@@ -92,6 +111,10 @@ def test_deterministic_run_fail():
"""
logger.info("test_deterministic_run_fail")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

# when we set the seed all operations within our dataset should be deterministic
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
@@ -120,12 +143,21 @@ def test_deterministic_run_fail():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_deterministic_run_pass():
"""
Test deterministic run with setting the seed
"""
logger.info("test_deterministic_run_pass")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
@@ -152,13 +184,23 @@ def test_deterministic_run_pass():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_seed_undeterministic():
"""
Test seed with num parallel workers in c, this test is expected to fail some of the time
"""
logger.info("test_seed_undeterministic")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)

# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

@@ -178,6 +220,10 @@ def test_seed_undeterministic():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_deterministic_run_distribution():
"""

@@ -185,6 +231,10 @@ def test_deterministic_run_distribution():
"""
logger.info("test_deterministic_run_distribution")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

# when we set the seed all operations within our dataset should be deterministic
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)

@@ -206,12 +256,21 @@ def test_deterministic_run_distribution():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_deterministic_python_seed():
"""
Test deterministic execution with seed in python
"""
logger.info("deterministic_random_crop_op_python_2")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)

@@ -242,12 +301,20 @@ def test_deterministic_python_seed():

np.testing.assert_equal(data1_output, data2_output)

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_deterministic_python_seed_multi_thread():
"""
Test deterministic execution with seed in python, this fails with multi-thread pyfunc run
"""
logger.info("deterministic_random_crop_op_python_2")

# Save original configuration values
seed_original = ds.config.get_seed()

ds.config.set_seed(0)
# when we set the seed all operations within our dataset should be deterministic
# First dataset

@@ -282,6 +349,9 @@ def test_deterministic_python_seed_multi_thread():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)

# Restore original configuration values
ds.config.set_seed(seed_original)


if __name__ == '__main__':
test_basic()

@@ -14,6 +14,8 @@
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers


DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"

@@ -38,7 +40,7 @@ def test_textline_dataset_all_file():


def test_textline_dataset_totext():
ds.config.set_num_parallel_workers(4)
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0
line = ["This is a text file.", "Another file.",

@@ -48,6 +50,8 @@ def test_textline_dataset_totext():
assert (str == line[count])
count += 1
assert (count == 5)
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_textline_dataset_num_samples():

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest

import mindspore.dataset as ds
from mindspore import log as logger
@@ -164,6 +165,35 @@ def test_python_sampler():
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]


def test_subset_sampler():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

def test_config(num_samples, start_index, subset_size):
sampler = ds.SubsetSampler(start_index, subset_size)
d = ds.ManifestDataset(manifest_file, sampler=sampler)

res = []
for item in d.create_dict_iterator():
res.append(map[(item["image"].shape[0], item["label"].item())])

return res

with pytest.raises(RuntimeError) as info:
test_config(5, 0, 0)
assert "subset_size <= 0" in str(info.value)

assert test_config(5, 0, 1) == [0]
assert test_config(5, 0, 2) == [0, 1]
assert test_config(5, 0, 3) == [0, 1, 2]
assert test_config(5, 0, 4) == [0, 1, 2, 3]
assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
assert test_config(5, 1, 1) == [1]
assert test_config(5, 2, 3) == [2, 3, 4]
assert test_config(5, 3, 2) == [3, 4]
assert test_config(5, 4, 1) == [4]


def test_sampler_chain():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

@@ -190,10 +220,26 @@ def test_sampler_chain():
assert test_config(5, 3) == [3]
assert test_config(5, 4) == [4]

def test_add_sampler_invalid_input():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
data1 = ds.ManifestDataset(manifest_file)

with pytest.raises(TypeError) as info:
data1.use_sampler(1)
assert "not an instance of a sampler" in str(info.value)

with pytest.raises(TypeError) as info:
data1.use_sampler("sampler")
assert "not an instance of a sampler" in str(info.value)


if __name__ == '__main__':
test_sequential_sampler(True)
test_random_sampler(True)
test_random_sampler_multi_iter(True)
test_sampler_py_api()
test_python_sampler()
test_subset_sampler()
test_sampler_chain()
test_add_sampler_invalid_input()
@@ -14,6 +14,8 @@
# ==============================================================================
import pytest
import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers


# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
# the label of each image is [0,0,0,1,1] each image can be uniquely identified

@@ -21,7 +23,11 @@ import mindspore.dataset as ds
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

def split_with_invalid_inputs(d):
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]

def split_with_invalid_inputs(d):
with pytest.raises(ValueError) as info:
s1, s2 = d.split([])
assert "sizes cannot be empty" in str(info.value)

@@ -66,8 +72,8 @@ def split_with_invalid_inputs(d):
s1, s2 = d.split([0.05, 0.95])
assert "percentage 0.05 is too small" in str(info.value)


def test_unmappable_invalid_input():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
d = ds.TextFileDataset(text_file_dataset_path)
split_with_invalid_inputs(d)

@@ -76,11 +82,10 @@ def test_unmappable_invalid_input():
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)


def test_unmappable_split():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
ds.config.set_num_parallel_workers(4)
original_num_parallel_workers = config_get_set_num_parallel_workers(4)

d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1], randomize=False)
@@ -123,6 +128,145 @@ def test_unmappable_split():
assert s1_output == text_file_data[0:2]
assert s2_output == text_file_data[2:]

# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_randomize_deterministic():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)

# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)

d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])

for _ in range(10):
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))

s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))

# note no overlap
assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
assert s2_output == [text_file_data[3]]

# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_randomize_repeatable():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)

# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)

d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])

num_epochs = 5
s1 = s1.repeat(num_epochs)
s2 = s2.repeat(num_epochs)

s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))

s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))

# note no overlap
assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs
assert s2_output == [text_file_data[3]] * num_epochs

# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_get_dataset_size():
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])

assert d.get_dataset_size() == 5
assert s1.get_dataset_size() == 4
assert s2.get_dataset_size() == 1


def test_unmappable_multi_split():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)

# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)

d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1])

s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]

s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
assert s1_output == s1_correct_output

# no randomize in second split
s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)

s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(item["text"].item().decode("utf8"))

s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(item["text"].item().decode("utf8"))

s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(item["text"].item().decode("utf8"))

assert s1s1_output == [s1_correct_output[0]]
assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
assert s1s3_output == [s1_correct_output[3]]

s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s2_output == [text_file_data[3]]

# randomize in second split
# the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0]
shuffled_ids = [2, 3, 1, 0]

s1s1, s1s2, s1s3 = s1.split([1, 2, 1])

s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(item["text"].item().decode("utf8"))

s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(item["text"].item().decode("utf8"))

s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(item["text"].item().decode("utf8"))

assert s1s1_output == [s1_correct_output[shuffled_ids[0]]]
assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]]
assert s1s3_output == [s1_correct_output[shuffled_ids[3]]]

s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s2_output == [text_file_data[3]]

# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file)
split_with_invalid_inputs(d)
@@ -132,6 +276,7 @@ def test_mappable_invalid_input():
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)


def test_mappable_split_general():
d = ds.ManifestDataset(manifest_file, shuffle=False)
d = d.take(5)

@@ -178,6 +323,7 @@ def test_mappable_split_general():
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]


def test_mappable_split_optimized():
d = ds.ManifestDataset(manifest_file, shuffle=False)

@@ -223,9 +369,9 @@ def test_mappable_split_optimized():
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]


def test_mappable_randomize_deterministic():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)

d = ds.ManifestDataset(manifest_file, shuffle=False)

@@ -244,9 +390,9 @@ def test_mappable_randomize_deterministic():
assert s1_output == [0, 1, 3, 4]
assert s2_output == [2]


def test_mappable_randomize_repeatable():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)

d = ds.ManifestDataset(manifest_file, shuffle=False)

@@ -268,9 +414,10 @@ def test_mappable_randomize_repeatable():
assert s1_output == [0, 1, 3, 4] * num_epochs
assert s2_output == [2] * num_epochs


def test_mappable_sharding():
# set arbitrary seed for repeatability for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)

num_epochs = 5
@@ -331,12 +478,94 @@ def test_mappable_sharding():
assert s2_output == [2]
assert d2s2_output == [2]


def test_mappable_get_dataset_size():
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([4, 1])

assert d.get_dataset_size() == 5
assert s1.get_dataset_size() == 4
assert s2.get_dataset_size() == 1


def test_mappable_multi_split():
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)

d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([4, 1])

s1_correct_output = [0, 1, 3, 4]

s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == s1_correct_output

# no randomize in second split
s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)

s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

assert s1s1_output == [s1_correct_output[0]]
assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
assert s1s3_output == [s1_correct_output[3]]

s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s2_output == [2]

# randomize in second split
# the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0]
random_sampler_ids = [3, 1, 2, 0]

s1s1, s1s2, s1s3 = s1.split([1, 2, 1])

s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]]
assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]]
assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]]

s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s2_output == [2]


if __name__ == '__main__':
test_unmappable_invalid_input()
test_unmappable_split()
test_unmappable_randomize_deterministic()
test_unmappable_randomize_repeatable()
test_unmappable_get_dataset_size()
test_unmappable_multi_split()
test_mappable_invalid_input()
test_mappable_split_general()
test_mappable_split_optimized()
test_mappable_randomize_deterministic()
test_mappable_randomize_repeatable()
test_mappable_sharding()
test_mappable_get_dataset_size()
test_mappable_multi_split()
@@ -15,11 +15,11 @@

import hashlib
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import os

# import jsbeautifier
import mindspore.dataset as ds
from mindspore import log as logger

# These are the column names defined in the testTFTestAllTypes dataset
@@ -221,3 +221,26 @@ def visualize(image_original, image_transformed):
plt.title("Transformed image")

plt.show()


def config_get_set_seed(seed_new):
"""
Get and return the original configuration seed value.
Set the new configuration seed value.
"""
seed_original = ds.config.get_seed()
ds.config.set_seed(seed_new)
logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
return seed_original


def config_get_set_num_parallel_workers(num_parallel_workers_new):
"""
Get and return the original configuration num_parallel_workers value.
Set the new configuration num_parallel_workers value.
"""
num_parallel_workers_original = ds.config.get_num_parallel_workers()
ds.config.set_num_parallel_workers(num_parallel_workers_new)
logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
num_parallel_workers_new))
return num_parallel_workers_original
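These two helpers let a test swap in a configuration value with one call and restore the original at the end, which is the pattern the updated UTs above follow. A usage sketch (the dataset path and expected size come from the split tests in this commit):

import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers, config_get_set_seed

def test_something():
    original_workers = config_get_set_num_parallel_workers(4)   # save old value, set 4
    original_seed = config_get_set_seed(53)                     # save old seed, set 53

    data = ds.TextFileDataset("../data/dataset/testTextFileDataset/*", shuffle=False)
    assert data.get_dataset_size() == 5

    # Restore the original configuration so later tests are unaffected
    ds.config.set_num_parallel_workers(original_workers)
    ds.config.set_seed(original_seed)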