Fixing issues in distributed sampler
Added Distributed sampler option Fix style
This commit is contained in:
parent
1165b27f41
commit
8c018da468
|
@ -31,8 +31,8 @@ SamplerObj::SamplerObj() {}
|
||||||
|
|
||||||
/// Function to create a Distributed Sampler.
|
/// Function to create a Distributed Sampler.
|
||||||
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
|
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
|
||||||
int64_t num_samples, uint32_t seed) {
|
int64_t num_samples, uint32_t seed, bool even_dist) {
|
||||||
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed);
|
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, even_dist);
|
||||||
// Input validation
|
// Input validation
|
||||||
if (!sampler->ValidateParams()) {
|
if (!sampler->ValidateParams()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -95,8 +95,13 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
|
||||||
|
|
||||||
// DistributedSampler
|
// DistributedSampler
|
||||||
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
|
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
|
||||||
uint32_t seed)
|
uint32_t seed, bool even_dist)
|
||||||
: num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {}
|
: num_shards_(num_shards),
|
||||||
|
shard_id_(shard_id),
|
||||||
|
shuffle_(shuffle),
|
||||||
|
num_samples_(num_samples),
|
||||||
|
seed_(seed),
|
||||||
|
even_dist_(even_dist) {}
|
||||||
|
|
||||||
bool DistributedSamplerObj::ValidateParams() {
|
bool DistributedSamplerObj::ValidateParams() {
|
||||||
if (num_shards_ <= 0) {
|
if (num_shards_ <= 0) {
|
||||||
|
@ -118,7 +123,8 @@ bool DistributedSamplerObj::ValidateParams() {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
|
std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
|
||||||
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_);
|
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
|
||||||
|
even_dist_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// PKSampler
|
// PKSampler
|
||||||
|
|
|
@ -24,13 +24,14 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||||
uint32_t seed)
|
uint32_t seed, bool even_dist)
|
||||||
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
|
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
|
||||||
cnt_(0),
|
cnt_(0),
|
||||||
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
|
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
|
||||||
device_id_(dev_id),
|
device_id_(dev_id),
|
||||||
num_devices_(num_dev),
|
num_devices_(num_dev),
|
||||||
shuffle_(shuffle) {}
|
shuffle_(shuffle),
|
||||||
|
even_dist_(even_dist) {}
|
||||||
|
|
||||||
Status DistributedSampler::InitSampler() {
|
Status DistributedSampler::InitSampler() {
|
||||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||||
|
@ -43,7 +44,15 @@ Status DistributedSampler::InitSampler() {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
|
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
|
||||||
"fail to init DistributedSampler");
|
"fail to init DistributedSampler");
|
||||||
rnd_.seed(seed_++);
|
rnd_.seed(seed_++);
|
||||||
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
|
if (even_dist_) {
|
||||||
|
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
|
||||||
|
} else {
|
||||||
|
int64_t mod = num_rows_ % num_devices_;
|
||||||
|
samples_per_buffer_ = num_rows_ / num_devices_;
|
||||||
|
if (mod > device_id_) {
|
||||||
|
samples_per_buffer_++;
|
||||||
|
}
|
||||||
|
}
|
||||||
samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
|
samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
|
||||||
if (shuffle_ == true) {
|
if (shuffle_ == true) {
|
||||||
shuffle_vec_.reserve(num_rows_);
|
shuffle_vec_.reserve(num_rows_);
|
||||||
|
|
|
@ -27,26 +27,32 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
class DistributedSampler : public Sampler {
|
class DistributedSampler : public Sampler {
|
||||||
public:
|
public:
|
||||||
// @param num_samples
|
/// \brief Constructor
|
||||||
// @param int64_t num_dev
|
/// \param[in] num_samples The total number of rows in the dataset
|
||||||
// @param int64_t dev_id
|
/// \param[in] num_dev Total number of shards for the distributed sampler
|
||||||
// @param bool shuffle
|
/// \param[in] dev_id Device id of the shard
|
||||||
|
/// \param[in] shuffle Option to shuffle
|
||||||
|
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
|
||||||
|
/// result in different samples being picked
|
||||||
|
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
|
||||||
|
/// This option is not exposed in the python API. Current behavior is that the remainder will always
|
||||||
|
/// be handled by the first n shards, n being the corresponding device id.
|
||||||
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||||
uint32_t seed = std::numeric_limits<uint32_t>::max());
|
uint32_t seed = std::numeric_limits<uint32_t>::max(), bool even_dist = true);
|
||||||
|
|
||||||
// default destructor
|
/// \brief default destructor
|
||||||
~DistributedSampler() = default;
|
~DistributedSampler() = default;
|
||||||
|
|
||||||
// @param std::unique_ptr<DataBuffer> * pBuffer
|
/// \param std::unique_ptr<DataBuffer> * pBuffer
|
||||||
// @param int32_t workerId
|
/// \param int32_t workerId
|
||||||
// @return - The error code return
|
/// \return Status code
|
||||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||||
|
|
||||||
// Init sampler, called by base class or python
|
/// Init sampler, called by base class or python
|
||||||
Status InitSampler() override;
|
Status InitSampler() override;
|
||||||
|
|
||||||
// for next epoch of sampleIds
|
/// \brief for next epoch of sampleIds
|
||||||
// @return - The error code return
|
/// \return Status code
|
||||||
Status ResetSampler() override;
|
Status ResetSampler() override;
|
||||||
|
|
||||||
void Print(std::ostream &out, bool show_all) const override;
|
void Print(std::ostream &out, bool show_all) const override;
|
||||||
|
@ -59,6 +65,7 @@ class DistributedSampler : public Sampler {
|
||||||
bool shuffle_;
|
bool shuffle_;
|
||||||
std::mt19937 rnd_;
|
std::mt19937 rnd_;
|
||||||
std::vector<int64_t> shuffle_vec_;
|
std::vector<int64_t> shuffle_vec_;
|
||||||
|
bool even_dist_;
|
||||||
};
|
};
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -52,9 +52,12 @@ class WeightedRandomSamplerObj;
|
||||||
/// \param[in] shuffle - If true, the indices are shuffled.
|
/// \param[in] shuffle - If true, the indices are shuffled.
|
||||||
/// \param[in] num_samples - The number of samples to draw (default to all elements).
|
/// \param[in] num_samples - The number of samples to draw (default to all elements).
|
||||||
/// \param[in] seed - The seed in use when shuffle is true.
|
/// \param[in] seed - The seed in use when shuffle is true.
|
||||||
|
/// \param[in] even_dist - If true, each shard would return the same number of rows (default to true).
|
||||||
|
/// If false the total rows returned by all the shards would not have overlap.
|
||||||
/// \return Shared pointer to the current Sampler.
|
/// \return Shared pointer to the current Sampler.
|
||||||
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true,
|
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true,
|
||||||
int64_t num_samples = 0, uint32_t seed = 1);
|
int64_t num_samples = 0, uint32_t seed = 1,
|
||||||
|
bool even_dist = true);
|
||||||
|
|
||||||
/// Function to create a PK Sampler.
|
/// Function to create a PK Sampler.
|
||||||
/// \notes Samples K elements for each P class in the dataset.
|
/// \notes Samples K elements for each P class in the dataset.
|
||||||
|
@ -100,7 +103,8 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
|
||||||
/* ####################################### Derived Sampler classes ################################# */
|
/* ####################################### Derived Sampler classes ################################# */
|
||||||
class DistributedSamplerObj : public SamplerObj {
|
class DistributedSamplerObj : public SamplerObj {
|
||||||
public:
|
public:
|
||||||
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed);
|
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
|
||||||
|
bool even_dist);
|
||||||
|
|
||||||
~DistributedSamplerObj() = default;
|
~DistributedSamplerObj() = default;
|
||||||
|
|
||||||
|
@ -114,6 +118,7 @@ class DistributedSamplerObj : public SamplerObj {
|
||||||
bool shuffle_;
|
bool shuffle_;
|
||||||
int64_t num_samples_;
|
int64_t num_samples_;
|
||||||
uint32_t seed_;
|
uint32_t seed_;
|
||||||
|
bool even_dist_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PKSamplerObj : public SamplerObj {
|
class PKSamplerObj : public SamplerObj {
|
||||||
|
|
|
@ -92,7 +92,9 @@ SET(DE_UT_SRCS
|
||||||
tensor_op_fusion_pass_test.cc
|
tensor_op_fusion_pass_test.cc
|
||||||
sliding_window_op_test.cc
|
sliding_window_op_test.cc
|
||||||
epoch_ctrl_op_test.cc
|
epoch_ctrl_op_test.cc
|
||||||
swap_red_blue_test.cc
|
sentence_piece_vocab_op_test.cc
|
||||||
|
swap_red_blue_test.cc
|
||||||
|
distributed_sampler_test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
if (ENABLE_PYTHON)
|
if (ENABLE_PYTHON)
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/constants.h"
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/engine/data_buffer.h"
|
||||||
|
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||||
|
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::LogStream;
|
||||||
|
|
||||||
|
class MindDataTestDistributedSampler : public UT::Common {
|
||||||
|
public:
|
||||||
|
class DummyRandomAccessOp : public RandomAccessOp {
|
||||||
|
public:
|
||||||
|
DummyRandomAccessOp(uint64_t num_rows) {
|
||||||
|
// row count is in base class as protected member
|
||||||
|
// GetNumRowsInDataset does not need an override, the default from base class is fine.
|
||||||
|
num_rows_ = num_rows;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) {
|
||||||
|
// num samples to draw.
|
||||||
|
uint64_t num_samples = 7;
|
||||||
|
|
||||||
|
// create sampler with replacement = true
|
||||||
|
DistributedSampler m_sampler(num_samples, 2, 0, false, 0, false);
|
||||||
|
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||||
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
|
std::unique_ptr<DataBuffer> db;
|
||||||
|
TensorRow row;
|
||||||
|
std::vector<uint64_t> out;
|
||||||
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||||
|
db->PopRow(&row);
|
||||||
|
for (const auto &t : row) {
|
||||||
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||||
|
out.push_back(*it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(4, out.size());
|
||||||
|
|
||||||
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||||
|
ASSERT_EQ(db->eoe(), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) {
|
||||||
|
// num samples to draw.
|
||||||
|
uint64_t num_samples = 7;
|
||||||
|
|
||||||
|
// create sampler with replacement = true
|
||||||
|
DistributedSampler m_sampler(num_samples, 2, 1, false, 0, false);
|
||||||
|
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||||
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
|
std::unique_ptr<DataBuffer> db;
|
||||||
|
TensorRow row;
|
||||||
|
std::vector<uint64_t> out;
|
||||||
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||||
|
db->PopRow(&row);
|
||||||
|
for (const auto &t : row) {
|
||||||
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||||
|
out.push_back(*it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(3, out.size());
|
||||||
|
|
||||||
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||||
|
ASSERT_EQ(db->eoe(), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestDistributedSampler, TestThreeShards) {
|
||||||
|
// num samples to draw.
|
||||||
|
uint64_t num_samples = 2;
|
||||||
|
|
||||||
|
// create sampler with replacement = true
|
||||||
|
DistributedSampler m_sampler(num_samples, 3, 2, false, 0, false);
|
||||||
|
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||||
|
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||||
|
|
||||||
|
std::unique_ptr<DataBuffer> db;
|
||||||
|
TensorRow row;
|
||||||
|
std::vector<uint64_t> out;
|
||||||
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||||
|
db->PopRow(&row);
|
||||||
|
for (const auto &t : row) {
|
||||||
|
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||||
|
out.push_back(*it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(0, out.size());
|
||||||
|
|
||||||
|
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||||
|
ASSERT_EQ(db->eoe(), true);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue