forked from mindspore-Ecosystem/mindspore
!7922 add params check for offset in distributed_sampler
Merge pull request !7922 from xiaotianci/fix_distributed_sampler
This commit is contained in:
commit
163544795f
|
@ -142,6 +142,12 @@ bool DistributedSamplerObj::ValidateParams() {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (offset_ > num_shards_) {
|
||||
MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_
|
||||
<< ", which should be no more than num_shards: " << num_shards_;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -34,11 +34,12 @@ class DistributedSampler : public Sampler {
|
|||
/// \param[in] shuffle Option to shuffle
|
||||
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
|
||||
/// result in different samples being picked
|
||||
/// \param[in] offset The starting position which the elements in the dataset are send to.The application
|
||||
/// scenario of this parameter is when the concatdataset is set distributedSampler
|
||||
/// \param[in] offset The starting device id where the elements in the dataset are send to, which should be no more
|
||||
/// than num_dev. The application scenario of this parameter is when the concatdataset is set distributedSampler
|
||||
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
|
||||
/// This option is not exposed in the python API. Current behavior is that the remainder will always
|
||||
/// be handled by the first n shards, n being the corresponding device id.
|
||||
/// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set,
|
||||
/// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario.
|
||||
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true);
|
||||
|
||||
|
|
|
@ -223,7 +223,8 @@ class DistributedSampler(BuiltinSampler):
|
|||
shard_id (int): Shard ID of the current shard within num_shards.
|
||||
shuffle (bool, optional): If True, the indices are shuffled (default=True).
|
||||
num_samples (int, optional): The number of samples to draw (default=None, all elements).
|
||||
offset(int, optional): The starting sample ID where access to elements in the dataset begins (default=-1).
|
||||
offset(int, optional): The starting shard ID where the elements in the dataset are sent to (default=-1), which
|
||||
should be no more than num_shards.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -238,6 +239,7 @@ class DistributedSampler(BuiltinSampler):
|
|||
ValueError: If num_shards is not positive.
|
||||
ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
|
||||
ValueError: If shuffle is not a boolean value.
|
||||
ValueError: If offset is greater than num_shards.
|
||||
"""
|
||||
|
||||
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
|
||||
|
@ -255,6 +257,10 @@ class DistributedSampler(BuiltinSampler):
|
|||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
|
||||
if offset > num_shards:
|
||||
raise ValueError("offset should be no more than num_shards={}, "
|
||||
"but got offset={}".format(num_shards, offset))
|
||||
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.shuffle = shuffle
|
||||
|
|
|
@ -116,3 +116,42 @@ TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
|
|||
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights3);
|
||||
EXPECT_EQ(sampl3, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerSuccess.";
|
||||
// Test basic setting of distributed_sampler
|
||||
|
||||
// num_shards=4, shard_id=0, shuffle=false, num_samplers=0, seed=0, offset=-1, even_dist=true
|
||||
std::shared_ptr<SamplerObj> sampler = DistributedSampler(4, 0, false, 0, 0, -1, true);
|
||||
EXPECT_NE(sampler, nullptr);
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto label = row["label"];
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 11);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail.";
|
||||
// Test invalid offset setting of distributed_sampler
|
||||
|
||||
// offset=5 which is greater than num_shards=4
|
||||
std::shared_ptr<SamplerObj> sampler = DistributedSampler(4, 0, false, 0, 0, 5, false);
|
||||
EXPECT_EQ(sampler, nullptr);
|
||||
}
|
||||
|
|
|
@ -236,6 +236,12 @@ def test_add_sampler_invalid_input():
|
|||
assert "Conflicting arguments during sampler assignments" in str(info.value)
|
||||
|
||||
|
||||
def test_distributed_sampler_invalid_offset():
|
||||
with pytest.raises(ValueError) as info:
|
||||
sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
|
||||
assert "offset should be no more than num_shards" in str(info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sequential_sampler(True)
|
||||
test_random_sampler(True)
|
||||
|
@ -245,3 +251,4 @@ if __name__ == '__main__':
|
|||
test_subset_sampler()
|
||||
test_sampler_chain()
|
||||
test_add_sampler_invalid_input()
|
||||
test_distributed_sampler_invalid_offset()
|
||||
|
|
Loading…
Reference in New Issue