forked from mindspore-Ecosystem/mindspore
!7922 add params check for offset in distributed_sampler
Merge pull request !7922 from xiaotianci/fix_distributed_sampler
This commit is contained in:
commit
163544795f
|
@ -142,6 +142,12 @@ bool DistributedSamplerObj::ValidateParams() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (offset_ > num_shards_) {
|
||||||
|
MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_
|
||||||
|
<< ", which should be no more than num_shards: " << num_shards_;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,11 +34,12 @@ class DistributedSampler : public Sampler {
|
||||||
/// \param[in] shuffle Option to shuffle
|
/// \param[in] shuffle Option to shuffle
|
||||||
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
|
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
|
||||||
/// result in different samples being picked
|
/// result in different samples being picked
|
||||||
/// \param[in] offset The starting position which the elements in the dataset are send to.The application
|
/// \param[in] offset The starting device id where the elements in the dataset are send to, which should be no more
|
||||||
/// scenario of this parameter is when the concatdataset is set distributedSampler
|
/// than num_dev. The application scenario of this parameter is when the concatdataset is set distributedSampler
|
||||||
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
|
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
|
||||||
/// This option is not exposed in the python API. Current behavior is that the remainder will always
|
/// This option is not exposed in the python API. Current behavior is that the remainder will always
|
||||||
/// be handled by the first n shards, n being the corresponding device id.
|
/// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set,
|
||||||
|
/// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario.
|
||||||
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||||
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true);
|
uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, bool even_dist = true);
|
||||||
|
|
||||||
|
|
|
@ -223,7 +223,8 @@ class DistributedSampler(BuiltinSampler):
|
||||||
shard_id (int): Shard ID of the current shard within num_shards.
|
shard_id (int): Shard ID of the current shard within num_shards.
|
||||||
shuffle (bool, optional): If True, the indices are shuffled (default=True).
|
shuffle (bool, optional): If True, the indices are shuffled (default=True).
|
||||||
num_samples (int, optional): The number of samples to draw (default=None, all elements).
|
num_samples (int, optional): The number of samples to draw (default=None, all elements).
|
||||||
offset(int, optional): The starting sample ID where access to elements in the dataset begins (default=-1).
|
offset(int, optional): The starting shard ID where the elements in the dataset are sent to (default=-1), which
|
||||||
|
should be no more than num_shards.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore.dataset as ds
|
>>> import mindspore.dataset as ds
|
||||||
|
@ -238,6 +239,7 @@ class DistributedSampler(BuiltinSampler):
|
||||||
ValueError: If num_shards is not positive.
|
ValueError: If num_shards is not positive.
|
||||||
ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
|
ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
|
||||||
ValueError: If shuffle is not a boolean value.
|
ValueError: If shuffle is not a boolean value.
|
||||||
|
ValueError: If offset is greater than num_shards.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
|
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
|
||||||
|
@ -255,6 +257,10 @@ class DistributedSampler(BuiltinSampler):
|
||||||
raise ValueError("num_samples should be a positive integer "
|
raise ValueError("num_samples should be a positive integer "
|
||||||
"value, but got num_samples={}".format(num_samples))
|
"value, but got num_samples={}".format(num_samples))
|
||||||
|
|
||||||
|
if offset > num_shards:
|
||||||
|
raise ValueError("offset should be no more than num_shards={}, "
|
||||||
|
"but got offset={}".format(num_shards, offset))
|
||||||
|
|
||||||
self.num_shards = num_shards
|
self.num_shards = num_shards
|
||||||
self.shard_id = shard_id
|
self.shard_id = shard_id
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
|
|
@ -116,3 +116,42 @@ TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
|
||||||
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights3);
|
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights3);
|
||||||
EXPECT_EQ(sampl3, nullptr);
|
EXPECT_EQ(sampl3, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerSuccess.";
|
||||||
|
// Test basic setting of distributed_sampler
|
||||||
|
|
||||||
|
// num_shards=4, shard_id=0, shuffle=false, num_samplers=0, seed=0, offset=-1, even_dist=true
|
||||||
|
std::shared_ptr<SamplerObj> sampler = DistributedSampler(4, 0, false, 0, 0, -1, true);
|
||||||
|
EXPECT_NE(sampler, nullptr);
|
||||||
|
|
||||||
|
// Create an ImageFolder Dataset
|
||||||
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
i++;
|
||||||
|
auto label = row["label"];
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 11);
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail.";
|
||||||
|
// Test invalid offset setting of distributed_sampler
|
||||||
|
|
||||||
|
// offset=5 which is greater than num_shards=4
|
||||||
|
std::shared_ptr<SamplerObj> sampler = DistributedSampler(4, 0, false, 0, 0, 5, false);
|
||||||
|
EXPECT_EQ(sampler, nullptr);
|
||||||
|
}
|
||||||
|
|
|
@ -236,6 +236,12 @@ def test_add_sampler_invalid_input():
|
||||||
assert "Conflicting arguments during sampler assignments" in str(info.value)
|
assert "Conflicting arguments during sampler assignments" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_distributed_sampler_invalid_offset():
|
||||||
|
with pytest.raises(ValueError) as info:
|
||||||
|
sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5)
|
||||||
|
assert "offset should be no more than num_shards" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_sequential_sampler(True)
|
test_sequential_sampler(True)
|
||||||
test_random_sampler(True)
|
test_random_sampler(True)
|
||||||
|
@ -245,3 +251,4 @@ if __name__ == '__main__':
|
||||||
test_subset_sampler()
|
test_subset_sampler()
|
||||||
test_sampler_chain()
|
test_sampler_chain()
|
||||||
test_add_sampler_invalid_input()
|
test_add_sampler_invalid_input()
|
||||||
|
test_distributed_sampler_invalid_offset()
|
||||||
|
|
Loading…
Reference in New Issue