diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc
index a56add0dc2d..9307f4dc259 100644
--- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc
@@ -142,6 +142,12 @@ bool DistributedSamplerObj::ValidateParams() {
     return false;
   }
 
+  if (offset_ > num_shards_) {
+    MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_
+                  << ", which should be no more than num_shards: " << num_shards_;
+    return false;
+  }
+
   return true;
 }
 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
index c5db9862b16..d9425b052e2 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
@@ -34,11 +34,12 @@ class DistributedSampler : public Sampler {
   /// \param[in] shuffle Option to shuffle
   /// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
   ///     result in different samples being picked
-  /// \param[in] offset The starting position which the elements in the dataset are send to.The application
-  ///     scenario of this parameter is when the concatdataset is set distributedSampler
+  /// \param[in] offset The starting device id where the elements in the dataset are sent to, which should be no
+  ///     more than num_dev. The typical use case of this parameter is a ConcatDataset with a DistributedSampler.
   /// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
   ///     This option is not exposed in the python API. Current behavior is that the remainder will always
-  ///     be handled by the first n shards, n being the corresponding device id.
+  ///     be handled by the first n shards, n being the corresponding device id. Note that when offset is set,
+  ///     even_dist is forced to false so that the remaining rows of the ConcatDataset can be distributed.
   DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
                      uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true);
 
diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py
index 6ff1c3413ba..fb1077eb74c 100644
--- a/mindspore/dataset/engine/samplers.py
+++ b/mindspore/dataset/engine/samplers.py
@@ -223,7 +223,8 @@ class DistributedSampler(BuiltinSampler):
         shard_id (int): Shard ID of the current shard within num_shards.
         shuffle (bool, optional): If True, the indices are shuffled (default=True).
         num_samples (int, optional): The number of samples to draw (default=None, all elements).
-        offset(int, optional): The starting sample ID where access to elements in the dataset begins (default=-1).
+        offset(int, optional): The starting shard ID to which elements in the dataset are sent (default=-1). It
+            should be no more than num_shards.
 
     Examples:
         >>> import mindspore.dataset as ds
@@ -238,6 +239,7 @@ class DistributedSampler(BuiltinSampler):
         ValueError: If num_shards is not positive.
         ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
         ValueError: If shuffle is not a boolean value.
+        ValueError: If offset is greater than num_shards.
     """
 
     def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
@@ -255,6 +257,10 @@ class DistributedSampler(BuiltinSampler):
             raise ValueError("num_samples should be a positive integer "
                              "value, but got num_samples={}".format(num_samples))
 
+        if offset > num_shards:
+            raise ValueError("offset should be no more than num_shards={}, "
+                             "but got offset={}".format(num_shards, offset))
+
         self.num_shards = num_shards
         self.shard_id = shard_id
         self.shuffle = shuffle
diff --git a/tests/ut/cpp/dataset/c_api_samplers_test.cc b/tests/ut/cpp/dataset/c_api_samplers_test.cc
index 47b8a43170c..2b246f19a3b 100644
--- a/tests/ut/cpp/dataset/c_api_samplers_test.cc
+++ b/tests/ut/cpp/dataset/c_api_samplers_test.cc
@@ -116,3 +116,42 @@ TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
   std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights3);
   EXPECT_EQ(sampl3, nullptr);
 }
+
+TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerSuccess.";
+  // Test basic setting of distributed_sampler
+
+  // num_shards=4, shard_id=0, shuffle=false, num_samples=0, seed=0, offset=-1, even_dist=true
+  std::shared_ptr<SamplerObj> sampler = DistributedSampler(4, 0, false, 0, 0, -1, true);
+  EXPECT_NE(sampler, nullptr);
+
+  // Create an ImageFolder Dataset
+  std::string folder_path = datasets_root_path_ + "/testPK/data/";
+  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
+  EXPECT_NE(ds, nullptr);
+
+  // Iterate the dataset and get each row
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    i++;
+    auto label = row["label"];
+    iter->GetNextRow(&row);
+  }
+
+  EXPECT_EQ(i, 11);
+  iter->Stop();
+}
+
+TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail.";
+  // Test invalid offset setting of distributed_sampler
+
+  // offset=5 which is greater than num_shards=4
+  std::shared_ptr<SamplerObj> sampler = DistributedSampler(4, 0, false, 0, 0, 5, false);
+  EXPECT_EQ(sampler, nullptr);
+}
diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py
index 2216fdacdba..ab8a3b964c7 100644
--- a/tests/ut/python/dataset/test_sampler.py
+++ b/tests/ut/python/dataset/test_sampler.py
@@ -236,6 +236,12 @@ def test_add_sampler_invalid_input():
     assert "Conflicting arguments during sampler assignments" in str(info.value)
 
 
+def test_distributed_sampler_invalid_offset():
+    with pytest.raises(ValueError) as info:
+        sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
+    assert "offset should be no more than num_shards" in str(info.value)
+
+
 if __name__ == '__main__':
     test_sequential_sampler(True)
     test_random_sampler(True)
@@ -245,3 +251,4 @@ if __name__ == '__main__':
     test_subset_sampler()
     test_sampler_chain()
     test_add_sampler_invalid_input()
+    test_distributed_sampler_invalid_offset()