add params checking of offset in distributed_sampler

add related comment to explain
add related UT
This commit is contained in:
Xiao Tianci 2020-10-28 10:12:26 +08:00
parent 9fc0218c56
commit 823e52f1dc
5 changed files with 63 additions and 4 deletions

View File

@ -142,6 +142,12 @@ bool DistributedSamplerObj::ValidateParams() {
return false;
}
if (offset_ > num_shards_) {
MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_
<< ", which should be no more than num_shards: " << num_shards_;
return false;
}
return true;
}

View File

@ -34,11 +34,12 @@ class DistributedSampler : public Sampler {
/// \param[in] shuffle Option to shuffle
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
/// result in different samples being picked
/// \param[in] offset The starting position which the elements in the dataset are send to.The application
/// scenario of this parameter is when the concatdataset is set distributedSampler
/// \param[in] offset The starting device id where the elements in the dataset are send to, which should be no more
/// than num_dev. The application scenario of this parameter is when the concatdataset is set distributedSampler
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
/// This option is not exposed in the python API. Current behavior is that the remainder will always
/// be handled by the first n shards, n being the corresponding device id.
/// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set,
/// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario.
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true);

View File

@ -223,7 +223,8 @@ class DistributedSampler(BuiltinSampler):
shard_id (int): Shard ID of the current shard within num_shards.
shuffle (bool, optional): If True, the indices are shuffled (default=True).
num_samples (int, optional): The number of samples to draw (default=None, all elements).
offset(int, optional): The starting sample ID where access to elements in the dataset begins (default=-1).
offset(int, optional): The starting shard ID where the elements in the dataset are sent to (default=-1), which
should be no more than num_shards.
Examples:
>>> import mindspore.dataset as ds
@ -238,6 +239,7 @@ class DistributedSampler(BuiltinSampler):
ValueError: If num_shards is not positive.
ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
ValueError: If shuffle is not a boolean value.
ValueError: If offset is greater than num_shards.
"""
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
@ -255,6 +257,10 @@ class DistributedSampler(BuiltinSampler):
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
if offset > num_shards:
raise ValueError("offset should be no more than num_shards={}, "
"but got offset={}".format(num_shards, offset))
self.num_shards = num_shards
self.shard_id = shard_id
self.shuffle = shuffle

View File

@ -116,3 +116,42 @@ TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights3);
EXPECT_EQ(sampl3, nullptr);
}
TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerSuccess.";
// Test basic setting of distributed_sampler
// num_shards=4, shard_id=0, shuffle=false, num_samplers=0, seed=0, offset=-1, even_dist=true
std::shared_ptr<SamplerObj> sampler = DistributedSampler(4, 0, false, 0, 0, -1, true);
EXPECT_NE(sampler, nullptr);
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
EXPECT_NE(ds, nullptr);
// Iterate the dataset and get each row
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto label = row["label"];
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 11);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail.";
// Test invalid offset setting of distributed_sampler
// offset=5 which is greater than num_shards=4
std::shared_ptr<SamplerObj> sampler = DistributedSampler(4, 0, false, 0, 0, 5, false);
EXPECT_EQ(sampler, nullptr);
}

View File

@ -236,6 +236,12 @@ def test_add_sampler_invalid_input():
assert "Conflicting arguments during sampler assignments" in str(info.value)
def test_distributed_sampler_invalid_offset():
with pytest.raises(ValueError) as info:
sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
assert "offset should be no more than num_shards" in str(info.value)
if __name__ == '__main__':
test_sequential_sampler(True)
test_random_sampler(True)
@ -245,3 +251,4 @@ if __name__ == '__main__':
test_subset_sampler()
test_sampler_chain()
test_add_sampler_invalid_input()
test_distributed_sampler_invalid_offset()