Added tests for chain samplers, fixed GetDatasetSize issue with chained samplers

This commit is contained in:
Mahdi 2020-12-16 09:56:45 -05:00
parent 1447879b9c
commit 98ea8fa6ea
6 changed files with 210 additions and 4 deletions

View File

@ -173,7 +173,7 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
}
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
int64_t num_per_shard = std::ceil(num_rows * 1.0 / num_devices_);
int64_t num_per_shard = std::ceil(child_num_rows * 1.0 / num_devices_);
return std::min(num_samples, num_per_shard);
}

View File

@ -133,12 +133,12 @@ Status SamplerRT::SetNumSamples(int64_t num_samples) {
int64_t SamplerRT::GetNumSamples() { return num_samples_; }
int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
int64_t childs = num_rows;
int64_t child_num_rows = num_rows;
if (!child_.empty()) {
childs = child_[0]->CalculateNumSamples(num_rows);
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
}
return (num_samples_ > 0) ? std::min(childs, num_samples_) : childs;
return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
}
Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {

View File

@ -97,6 +97,26 @@ Status SequentialSamplerRT::ResetSampler() {
return Status::OK();
}
int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
// Holds the number of rows available for Sequential sampler. It can be the rows passed from its child sampler or the
// num_rows from the dataset
int64_t child_num_rows = num_rows;
if (!child_.empty()) {
child_num_rows = child_[0]->CalculateNumSamples(num_rows);
}
int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
// For this sampler we need to take start_index into account. Because for example in the case we are given n rows
// and start_index != 0 and num_samples >= n then we can't return all the n rows.
if (child_num_rows - (start_index_ - current_id_) <= 0) {
return 0;
}
if (child_num_rows - (start_index_ - current_id_) < num_samples)
num_samples = child_num_rows - (start_index_ - current_id_) > num_samples
? num_samples
: num_samples - (start_index_ - current_id_);
return num_samples;
}
void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
out << "\nSampler: SequentialSampler";
if (show_all) {

View File

@ -49,6 +49,13 @@ class SequentialSamplerRT : public SamplerRT {
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
/// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers
/// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows,
/// then num_samples_ is not returned at all.
/// \param[in] num_rows The total number of rows in the dataset
/// \return int64_t Calculated number of samples
int64_t CalculateNumSamples(int64_t num_rows) override;
// Printer for debugging purposes.
// @param out - output stream to write to
// @param show_all - bool to show detailed vs summary

View File

@ -188,6 +188,7 @@ def test_subset_sampler():
assert test_config(2, 3) == [2, 3, 4]
assert test_config(3, 2) == [3, 4]
assert test_config(4, 1) == [4]
assert test_config(4, None) == [4]
def test_sampler_chain():

View File

@ -20,6 +20,18 @@ from util import save_and_check_md5
GENERATE_GOLDEN = False
IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
MNIST_DATA_DIR = "../data/dataset/testMnistData"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
VOC_DATA_DIR = "../data/dataset/testVOC2012"
def test_numpyslices_sampler_no_chain():
"""
@ -107,6 +119,166 @@ def test_numpyslices_sampler_chain2():
logger.info("dataset: {}".format(res))
def test_imagefolder_sampler_chain():
"""
Test ImageFolderDataset sampler chain
"""
logger.info("test_imagefolder_sampler_chain")
sampler = ds.SequentialSampler(start_index=1, num_samples=3)
child_sampler = ds.PKSampler(2)
sampler.add_child(child_sampler)
data1 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, sampler=sampler)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 3
# Verify number of rows
assert sum([1 for _ in data1]) == 3
# Verify dataset contents
res = []
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
logger.info("item: {}".format(item))
res.append(item)
logger.info("dataset: {}".format(res))
def test_mnist_sampler_chain():
"""
Test Mnist sampler chain
"""
logger.info("test_mnist_sampler_chain")
sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
sampler.add_child(child_sampler)
data1 = ds.MnistDataset(MNIST_DATA_DIR, sampler=sampler)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 3
# Verify number of rows
assert sum([1 for _ in data1]) == 3
# Verify dataset contents
res = []
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
logger.info("item: {}".format(item))
res.append(item)
logger.info("dataset: {}".format(res))
def test_manifest_sampler_chain():
"""
Test Manifest sampler chain
"""
logger.info("test_manifest_sampler_chain")
sampler = ds.RandomSampler(replacement=True, num_samples=2)
child_sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
sampler.add_child(child_sampler)
data1 = ds.ManifestDataset(MANIFEST_DATA_FILE, sampler=sampler)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 2
# Verify number of rows
assert sum([1 for _ in data1]) == 2
# Verify dataset contents
res = []
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
logger.info("item: {}".format(item))
res.append(item)
logger.info("dataset: {}".format(res))
def test_coco_sampler_chain():
"""
Test Coco sampler chain
"""
logger.info("test_coco_sampler_chain")
sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
child_sampler = ds.RandomSampler(replacement=True, num_samples=2)
sampler.add_child(child_sampler)
data1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
sampler=sampler)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 1
# Verify number of rows
assert sum([1 for _ in data1]) == 1
# Verify dataset contents
res = []
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
logger.info("item: {}".format(item))
res.append(item)
logger.info("dataset: {}".format(res))
def test_cifar_sampler_chain():
"""
Test Cifar sampler chain
"""
logger.info("test_cifar_sampler_chain")
sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
child_sampler2 = ds.SequentialSampler(start_index=0, num_samples=2)
child_sampler.add_child(child_sampler2)
sampler.add_child(child_sampler)
data1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, sampler=sampler)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 1
# Verify number of rows
assert sum([1 for _ in data1]) == 1
# Verify dataset contents
res = []
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
logger.info("item: {}".format(item))
res.append(item)
logger.info("dataset: {}".format(res))
def test_voc_sampler_chain():
"""
Test VOC sampler chain
"""
logger.info("test_voc_sampler_chain")
sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
child_sampler = ds.SequentialSampler(start_index=0)
sampler.add_child(child_sampler)
data1 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", sampler=sampler)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 5
# Verify number of rows
assert sum([1 for _ in data1]) == 5
# Verify dataset contents
res = []
for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
logger.info("item: {}".format(item))
res.append(item)
logger.info("dataset: {}".format(res))
def test_numpyslices_sampler_chain_batch():
"""
Test NumpySlicesDataset sampler chaining, with batch
@ -241,6 +413,12 @@ if __name__ == '__main__':
test_numpyslices_sampler_no_chain()
test_numpyslices_sampler_chain()
test_numpyslices_sampler_chain2()
test_imagefolder_sampler_chain()
test_mnist_sampler_chain()
test_manifest_sampler_chain()
test_coco_sampler_chain()
test_cifar_sampler_chain()
test_voc_sampler_chain()
test_numpyslices_sampler_chain_batch()
test_sampler_chain_errors()
test_manifest_sampler_chain_repeat()