diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index c37cec60aac..16e2328e103 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -177,7 +177,7 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) { child_num_rows = child_[0]->CalculateNumSamples(num_rows); } int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; - int64_t num_per_shard = std::ceil(num_rows * 1.0 / num_devices_); + int64_t num_per_shard = std::ceil(child_num_rows * 1.0 / num_devices_); return std::min(num_samples, num_per_shard); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc index e326aa76c96..9c1e15aa6ad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -137,12 +137,12 @@ Status SamplerRT::SetNumSamples(int64_t num_samples) { int64_t SamplerRT::GetNumSamples() { return num_samples_; } int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) { - int64_t childs = num_rows; + int64_t child_num_rows = num_rows; if (!child_.empty()) { - childs = child_[0]->CalculateNumSamples(num_rows); + child_num_rows = child_[0]->CalculateNumSamples(num_rows); } - return (num_samples_ > 0) ? std::min(childs, num_samples_) : childs; + return (num_samples_ > 0) ? 
std::min(child_num_rows, num_samples_) : child_num_rows; } Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index f57e1f6b681..4d997dbb23b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -102,6 +102,26 @@ Status SequentialSamplerRT::ResetSampler() { return Status::OK(); } +int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) { + // Holds the number of rows available for Sequential sampler. It can be the rows passed from its child sampler or the + // num_rows from the dataset + int64_t child_num_rows = num_rows; + if (!child_.empty()) { + child_num_rows = child_[0]->CalculateNumSamples(num_rows); + } + int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; + // For this sampler we need to take start_index into account. Because for example in the case we are given n rows + // and start_index != 0 and num_samples >= n then we can't return all the n rows. + if (child_num_rows - (start_index_ - current_id_) <= 0) { + return 0; + } + if (child_num_rows - (start_index_ - current_id_) < num_samples) + num_samples = child_num_rows - (start_index_ - current_id_) > num_samples + ? 
num_samples + : num_samples - (start_index_ - current_id_); + return num_samples; +} + void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { out << "\nSampler: SequentialSampler"; if (show_all) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h index 250f2903d3d..04b9d1b7975 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -49,6 +49,13 @@ class SequentialSamplerRT : public SamplerRT { // @return Status The status code returned Status GetNextSample(std::unique_ptr *out_buffer) override; + /// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers + /// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's larger than num_rows, + /// then num_samples_ is not returned at all. + /// \param[in] num_rows The total number of rows in the dataset + /// \return int64_t Calculated number of samples + int64_t CalculateNumSamples(int64_t num_rows) override; + // Printer for debugging purposes. 
// @param out - output stream to write to // @param show_all - bool to show detailed vs summary diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index ab8a3b964c7..80688bed5ed 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -188,6 +188,7 @@ def test_subset_sampler(): assert test_config(2, 3) == [2, 3, 4] assert test_config(3, 2) == [3, 4] assert test_config(4, 1) == [4] + assert test_config(4, None) == [4] def test_sampler_chain(): diff --git a/tests/ut/python/dataset/test_sampler_chain.py b/tests/ut/python/dataset/test_sampler_chain.py index baca36ece9d..6b68dcb0978 100644 --- a/tests/ut/python/dataset/test_sampler_chain.py +++ b/tests/ut/python/dataset/test_sampler_chain.py @@ -20,6 +20,18 @@ from util import save_and_check_md5 GENERATE_GOLDEN = False +IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train" +IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] +MNIST_DATA_DIR = "../data/dataset/testMnistData" +MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" +CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data" +COCO_DATA_DIR = "../data/dataset/testCOCO/train/" +ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" +VOC_DATA_DIR = "../data/dataset/testVOC2012" + def test_numpyslices_sampler_no_chain(): """ @@ -107,6 +119,166 @@ def test_numpyslices_sampler_chain2(): logger.info("dataset: {}".format(res)) +def test_imagefolder_sampler_chain(): + """ + Test ImageFolderDataset sampler chain + """ + logger.info("test_imagefolder_sampler_chain") + + sampler = ds.SequentialSampler(start_index=1, num_samples=3) + child_sampler = ds.PKSampler(2) + sampler.add_child(child_sampler) + data1 = 
ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, sampler=sampler) + # Verify dataset size + data1_size = data1.get_dataset_size() + logger.info("dataset size is: {}".format(data1_size)) + assert data1_size == 3 + # Verify number of rows + assert sum([1 for _ in data1]) == 3 + + # Verify dataset contents + res = [] + for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): + logger.info("item: {}".format(item)) + res.append(item) + logger.info("dataset: {}".format(res)) + + +def test_mnist_sampler_chain(): + """ + Test Mnist sampler chain + """ + logger.info("test_mnist_sampler_chain") + + sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1) + child_sampler = ds.RandomSampler(replacement=True, num_samples=4) + sampler.add_child(child_sampler) + data1 = ds.MnistDataset(MNIST_DATA_DIR, sampler=sampler) + + # Verify dataset size + data1_size = data1.get_dataset_size() + logger.info("dataset size is: {}".format(data1_size)) + assert data1_size == 3 + # Verify number of rows + assert sum([1 for _ in data1]) == 3 + + # Verify dataset contents + res = [] + for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): + logger.info("item: {}".format(item)) + res.append(item) + logger.info("dataset: {}".format(res)) + + +def test_manifest_sampler_chain(): + """ + Test Manifest sampler chain + """ + logger.info("test_manifest_sampler_chain") + + sampler = ds.RandomSampler(replacement=True, num_samples=2) + child_sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1) + sampler.add_child(child_sampler) + data1 = ds.ManifestDataset(MANIFEST_DATA_FILE, sampler=sampler) + + # Verify dataset size + data1_size = data1.get_dataset_size() + logger.info("dataset size is: {}".format(data1_size)) + assert data1_size == 2 + # Verify number of rows + assert sum([1 for _ in data1]) == 2 + + # Verify dataset contents + res = [] + for item in 
data1.create_tuple_iterator(num_epochs=1, output_numpy=True): + logger.info("item: {}".format(item)) + res.append(item) + logger.info("dataset: {}".format(res)) + + +def test_coco_sampler_chain(): + """ + Test Coco sampler chain + """ + logger.info("test_coco_sampler_chain") + + sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5) + child_sampler = ds.RandomSampler(replacement=True, num_samples=2) + sampler.add_child(child_sampler) + data1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True, + sampler=sampler) + + # Verify dataset size + data1_size = data1.get_dataset_size() + logger.info("dataset size is: {}".format(data1_size)) + assert data1_size == 1 + + # Verify number of rows + assert sum([1 for _ in data1]) == 1 + + # Verify dataset contents + res = [] + for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): + logger.info("item: {}".format(item)) + res.append(item) + logger.info("dataset: {}".format(res)) + + +def test_cifar_sampler_chain(): + """ + Test Cifar sampler chain + """ + logger.info("test_cifar_sampler_chain") + + sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5) + child_sampler = ds.RandomSampler(replacement=True, num_samples=4) + child_sampler2 = ds.SequentialSampler(start_index=0, num_samples=2) + child_sampler.add_child(child_sampler2) + sampler.add_child(child_sampler) + data1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, sampler=sampler) + # Verify dataset size + data1_size = data1.get_dataset_size() + logger.info("dataset size is: {}".format(data1_size)) + assert data1_size == 1 + + # Verify number of rows + assert sum([1 for _ in data1]) == 1 + + # Verify dataset contents + res = [] + for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): + logger.info("item: {}".format(item)) + res.append(item) + logger.info("dataset: {}".format(res)) + + +def test_voc_sampler_chain(): + """ + Test VOC sampler 
chain + """ + logger.info("test_voc_sampler_chain") + + sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5) + child_sampler = ds.SequentialSampler(start_index=0) + sampler.add_child(child_sampler) + data1 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", sampler=sampler) + + # Verify dataset size + data1_size = data1.get_dataset_size() + logger.info("dataset size is: {}".format(data1_size)) + assert data1_size == 5 + + # Verify number of rows + assert sum([1 for _ in data1]) == 5 + + # Verify dataset contents + res = [] + for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): + logger.info("item: {}".format(item)) + res.append(item) + logger.info("dataset: {}".format(res)) + + def test_numpyslices_sampler_chain_batch(): """ Test NumpySlicesDataset sampler chaining, with batch @@ -241,6 +413,12 @@ if __name__ == '__main__': test_numpyslices_sampler_no_chain() test_numpyslices_sampler_chain() test_numpyslices_sampler_chain2() + test_imagefolder_sampler_chain() + test_mnist_sampler_chain() + test_manifest_sampler_chain() + test_coco_sampler_chain() + test_cifar_sampler_chain() + test_voc_sampler_chain() test_numpyslices_sampler_chain_batch() test_sampler_chain_errors() test_manifest_sampler_chain_repeat()