Added tests for chain samplers, fixed GetDatasetSize issue with chained samplers
This commit is contained in:
parent
1447879b9c
commit
98ea8fa6ea
|
@ -173,7 +173,7 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
|
|||
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
|
||||
}
|
||||
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||
int64_t num_per_shard = std::ceil(num_rows * 1.0 / num_devices_);
|
||||
int64_t num_per_shard = std::ceil(child_num_rows * 1.0 / num_devices_);
|
||||
return std::min(num_samples, num_per_shard);
|
||||
}
|
||||
|
||||
|
|
|
@ -133,12 +133,12 @@ Status SamplerRT::SetNumSamples(int64_t num_samples) {
|
|||
int64_t SamplerRT::GetNumSamples() { return num_samples_; }
|
||||
|
||||
int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
|
||||
int64_t childs = num_rows;
|
||||
int64_t child_num_rows = num_rows;
|
||||
if (!child_.empty()) {
|
||||
childs = child_[0]->CalculateNumSamples(num_rows);
|
||||
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
|
||||
}
|
||||
|
||||
return (num_samples_ > 0) ? std::min(childs, num_samples_) : childs;
|
||||
return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||
}
|
||||
|
||||
Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {
|
||||
|
|
|
@ -97,6 +97,26 @@ Status SequentialSamplerRT::ResetSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
|
||||
// Holds the number of rows available for Sequential sampler. It can be the rows passed from its child sampler or the
|
||||
// num_rows from the dataset
|
||||
int64_t child_num_rows = num_rows;
|
||||
if (!child_.empty()) {
|
||||
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
|
||||
}
|
||||
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
|
||||
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows
|
||||
// and start_index != 0 and num_samples >= n then we can't return all the n rows.
|
||||
if (child_num_rows - (start_index_ - current_id_) <= 0) {
|
||||
return 0;
|
||||
}
|
||||
if (child_num_rows - (start_index_ - current_id_) < num_samples)
|
||||
num_samples = child_num_rows - (start_index_ - current_id_) > num_samples
|
||||
? num_samples
|
||||
: num_samples - (start_index_ - current_id_);
|
||||
return num_samples;
|
||||
}
|
||||
|
||||
void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
|
||||
out << "\nSampler: SequentialSampler";
|
||||
if (show_all) {
|
||||
|
|
|
@ -49,6 +49,13 @@ class SequentialSamplerRT : public SamplerRT {
|
|||
// @return Status The status code returned
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
/// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers
|
||||
/// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows,
|
||||
/// then num_samples_ is not returned at all.
|
||||
/// \param[in] num_rows The total number of rows in the dataset
|
||||
/// \return int64_t Calculated number of samples
|
||||
int64_t CalculateNumSamples(int64_t num_rows) override;
|
||||
|
||||
// Printer for debugging purposes.
|
||||
// @param out - output stream to write to
|
||||
// @param show_all - bool to show detailed vs summary
|
||||
|
|
|
@ -188,6 +188,7 @@ def test_subset_sampler():
|
|||
assert test_config(2, 3) == [2, 3, 4]
|
||||
assert test_config(3, 2) == [3, 4]
|
||||
assert test_config(4, 1) == [4]
|
||||
assert test_config(4, None) == [4]
|
||||
|
||||
|
||||
def test_sampler_chain():
|
||||
|
|
|
@ -20,6 +20,18 @@ from util import save_and_check_md5
|
|||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
|
||||
IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
|
||||
MNIST_DATA_DIR = "../data/dataset/testMnistData"
|
||||
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
|
||||
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"
|
||||
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
|
||||
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
|
||||
VOC_DATA_DIR = "../data/dataset/testVOC2012"
|
||||
|
||||
|
||||
def test_numpyslices_sampler_no_chain():
|
||||
"""
|
||||
|
@ -107,6 +119,166 @@ def test_numpyslices_sampler_chain2():
|
|||
logger.info("dataset: {}".format(res))
|
||||
|
||||
|
||||
def test_imagefolder_sampler_chain():
|
||||
"""
|
||||
Test ImageFolderDataset sampler chain
|
||||
"""
|
||||
logger.info("test_imagefolder_sampler_chain")
|
||||
|
||||
sampler = ds.SequentialSampler(start_index=1, num_samples=3)
|
||||
child_sampler = ds.PKSampler(2)
|
||||
sampler.add_child(child_sampler)
|
||||
data1 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, sampler=sampler)
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 3
|
||||
# Verify number of rows
|
||||
assert sum([1 for _ in data1]) == 3
|
||||
|
||||
# Verify dataset contents
|
||||
res = []
|
||||
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("item: {}".format(item))
|
||||
res.append(item)
|
||||
logger.info("dataset: {}".format(res))
|
||||
|
||||
|
||||
def test_mnist_sampler_chain():
|
||||
"""
|
||||
Test Mnist sampler chain
|
||||
"""
|
||||
logger.info("test_mnist_sampler_chain")
|
||||
|
||||
sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
|
||||
child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
|
||||
sampler.add_child(child_sampler)
|
||||
data1 = ds.MnistDataset(MNIST_DATA_DIR, sampler=sampler)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 3
|
||||
# Verify number of rows
|
||||
assert sum([1 for _ in data1]) == 3
|
||||
|
||||
# Verify dataset contents
|
||||
res = []
|
||||
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("item: {}".format(item))
|
||||
res.append(item)
|
||||
logger.info("dataset: {}".format(res))
|
||||
|
||||
|
||||
def test_manifest_sampler_chain():
|
||||
"""
|
||||
Test Manifest sampler chain
|
||||
"""
|
||||
logger.info("test_manifest_sampler_chain")
|
||||
|
||||
sampler = ds.RandomSampler(replacement=True, num_samples=2)
|
||||
child_sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
|
||||
sampler.add_child(child_sampler)
|
||||
data1 = ds.ManifestDataset(MANIFEST_DATA_FILE, sampler=sampler)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 2
|
||||
# Verify number of rows
|
||||
assert sum([1 for _ in data1]) == 2
|
||||
|
||||
# Verify dataset contents
|
||||
res = []
|
||||
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("item: {}".format(item))
|
||||
res.append(item)
|
||||
logger.info("dataset: {}".format(res))
|
||||
|
||||
|
||||
def test_coco_sampler_chain():
|
||||
"""
|
||||
Test Coco sampler chain
|
||||
"""
|
||||
logger.info("test_coco_sampler_chain")
|
||||
|
||||
sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
|
||||
child_sampler = ds.RandomSampler(replacement=True, num_samples=2)
|
||||
sampler.add_child(child_sampler)
|
||||
data1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
|
||||
sampler=sampler)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 1
|
||||
|
||||
# Verify number of rows
|
||||
assert sum([1 for _ in data1]) == 1
|
||||
|
||||
# Verify dataset contents
|
||||
res = []
|
||||
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("item: {}".format(item))
|
||||
res.append(item)
|
||||
logger.info("dataset: {}".format(res))
|
||||
|
||||
|
||||
def test_cifar_sampler_chain():
|
||||
"""
|
||||
Test Cifar sampler chain
|
||||
"""
|
||||
logger.info("test_cifar_sampler_chain")
|
||||
|
||||
sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
|
||||
child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
|
||||
child_sampler2 = ds.SequentialSampler(start_index=0, num_samples=2)
|
||||
child_sampler.add_child(child_sampler2)
|
||||
sampler.add_child(child_sampler)
|
||||
data1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, sampler=sampler)
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 1
|
||||
|
||||
# Verify number of rows
|
||||
assert sum([1 for _ in data1]) == 1
|
||||
|
||||
# Verify dataset contents
|
||||
res = []
|
||||
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("item: {}".format(item))
|
||||
res.append(item)
|
||||
logger.info("dataset: {}".format(res))
|
||||
|
||||
|
||||
def test_voc_sampler_chain():
|
||||
"""
|
||||
Test VOC sampler chain
|
||||
"""
|
||||
logger.info("test_voc_sampler_chain")
|
||||
|
||||
sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
|
||||
child_sampler = ds.SequentialSampler(start_index=0)
|
||||
sampler.add_child(child_sampler)
|
||||
data1 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", sampler=sampler)
|
||||
|
||||
# Verify dataset size
|
||||
data1_size = data1.get_dataset_size()
|
||||
logger.info("dataset size is: {}".format(data1_size))
|
||||
assert data1_size == 5
|
||||
|
||||
# Verify number of rows
|
||||
assert sum([1 for _ in data1]) == 5
|
||||
|
||||
# Verify dataset contents
|
||||
res = []
|
||||
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("item: {}".format(item))
|
||||
res.append(item)
|
||||
logger.info("dataset: {}".format(res))
|
||||
|
||||
|
||||
def test_numpyslices_sampler_chain_batch():
|
||||
"""
|
||||
Test NumpySlicesDataset sampler chaining, with batch
|
||||
|
@ -241,6 +413,12 @@ if __name__ == '__main__':
|
|||
test_numpyslices_sampler_no_chain()
|
||||
test_numpyslices_sampler_chain()
|
||||
test_numpyslices_sampler_chain2()
|
||||
test_imagefolder_sampler_chain()
|
||||
test_mnist_sampler_chain()
|
||||
test_manifest_sampler_chain()
|
||||
test_coco_sampler_chain()
|
||||
test_cifar_sampler_chain()
|
||||
test_voc_sampler_chain()
|
||||
test_numpyslices_sampler_chain_batch()
|
||||
test_sampler_chain_errors()
|
||||
test_manifest_sampler_chain_repeat()
|
||||
|
|
Loading…
Reference in New Issue