Fixing issues in distributed sampler

Added Distributed sampler option

Fix style
This commit is contained in:
Eric 2020-07-27 02:08:04 -04:00
parent 1165b27f41
commit 8c018da468
6 changed files with 175 additions and 23 deletions

View File

@ -31,8 +31,8 @@ SamplerObj::SamplerObj() {}
/// Function to create a Distributed Sampler.
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
int64_t num_samples, uint32_t seed) {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed);
int64_t num_samples, uint32_t seed, bool even_dist) {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, even_dist);
// Input validation
if (!sampler->ValidateParams()) {
return nullptr;
@ -95,8 +95,13 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
// DistributedSampler
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed)
: num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {}
uint32_t seed, bool even_dist)
: num_shards_(num_shards),
shard_id_(shard_id),
shuffle_(shuffle),
num_samples_(num_samples),
seed_(seed),
even_dist_(even_dist) {}
bool DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) {
@ -118,7 +123,8 @@ bool DistributedSamplerObj::ValidateParams() {
}
std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_);
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
even_dist_);
}
// PKSampler

View File

@ -24,13 +24,14 @@
namespace mindspore {
namespace dataset {
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed)
uint32_t seed, bool even_dist)
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
cnt_(0),
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
device_id_(dev_id),
num_devices_(num_dev),
shuffle_(shuffle) {}
shuffle_(shuffle),
even_dist_(even_dist) {}
Status DistributedSampler::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
@ -43,7 +44,15 @@ Status DistributedSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
"fail to init DistributedSampler");
rnd_.seed(seed_++);
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
if (even_dist_) {
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
} else {
int64_t mod = num_rows_ % num_devices_;
samples_per_buffer_ = num_rows_ / num_devices_;
if (mod > device_id_) {
samples_per_buffer_++;
}
}
samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
if (shuffle_ == true) {
shuffle_vec_.reserve(num_rows_);

View File

@ -27,26 +27,32 @@ namespace mindspore {
namespace dataset {
class DistributedSampler : public Sampler {
public:
// @param num_samples
// @param int64_t num_dev
// @param int64_t dev_id
// @param bool shuffle
/// \brief Constructor
/// \param[in] num_samples The total number of rows in the dataset
/// \param[in] num_dev Total number of shards for the distributed sampler
/// \param[in] dev_id Device id of the shard
/// \param[in] shuffle Option to shuffle
/// \param seed Seed used when shuffling; defaults to max unsigned int (a different seed in the sampler
/// will result in different samples being picked)
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
/// This option is not exposed in the Python API. When false, the remainder rows are assigned to the
/// first n shards, where n is the remainder of num_rows divided by the number of shards.
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max());
uint32_t seed = std::numeric_limits<uint32_t>::max(), bool even_dist = true);
// default destructor
/// \brief default destructor
~DistributedSampler() = default;
// @param std::unique_ptr<DataBuffer> * pBuffer
// @param int32_t workerId
// @return - The error code return
/// \param std::unique_ptr<DataBuffer> * pBuffer
/// \param int32_t workerId
/// \return Status code
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
// Init sampler, called by base class or python
/// Init sampler, called by base class or python
Status InitSampler() override;
// for next epoch of sampleIds
// @return - The error code return
/// \brief for next epoch of sampleIds
/// \return Status code
Status ResetSampler() override;
void Print(std::ostream &out, bool show_all) const override;
@ -59,6 +65,7 @@ class DistributedSampler : public Sampler {
bool shuffle_;
std::mt19937 rnd_;
std::vector<int64_t> shuffle_vec_;
bool even_dist_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -52,9 +52,12 @@ class WeightedRandomSamplerObj;
/// \param[in] shuffle - If true, the indices are shuffled.
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \param[in] seed - The seed in use when shuffle is true.
/// \param[in] even_dist - If true, each shard returns the same number of rows (default to true).
/// If false, the rows returned by the shards do not overlap (every row goes to exactly one shard).
/// \return Shared pointer to the current Sampler.
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true,
int64_t num_samples = 0, uint32_t seed = 1);
int64_t num_samples = 0, uint32_t seed = 1,
bool even_dist = true);
/// Function to create a PK Sampler.
/// \notes Samples K elements for each P class in the dataset.
@ -100,7 +103,8 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
/* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj {
public:
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed);
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
bool even_dist);
~DistributedSamplerObj() = default;
@ -114,6 +118,7 @@ class DistributedSamplerObj : public SamplerObj {
bool shuffle_;
int64_t num_samples_;
uint32_t seed_;
bool even_dist_;
};
class PKSamplerObj : public SamplerObj {

View File

@ -92,7 +92,9 @@ SET(DE_UT_SRCS
tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc
epoch_ctrl_op_test.cc
swap_red_blue_test.cc
sentence_piece_vocab_op_test.cc
swap_red_blue_test.cc
distributed_sampler_test.cc
)
if (ENABLE_PYTHON)

View File

@ -0,0 +1,123 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "gtest/gtest.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "utils/log_adapter.h"
#include <vector>
#include <unordered_set>
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
// Test fixture for DistributedSampler unit tests.
class MindDataTestDistributedSampler : public UT::Common {
public:
// Minimal RandomAccessOp stub: it only reports a fixed row count so the
// sampler under test can be handshaken without building a real dataset op.
class DummyRandomAccessOp : public RandomAccessOp {
public:
// \param[in] num_rows Number of rows the fake dataset claims to have.
DummyRandomAccessOp(uint64_t num_rows) {
// row count is in base class as protected member
// GetNumRowsInDataset does not need an override, the default from base class is fine.
num_rows_ = num_rows;
}
};
};
TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) {
  // 7 rows split across 2 shards with even_dist == false: with
  // 7 % 2 == 1 remainder row, shard 0 is expected to yield 4 ids.
  uint64_t num_samples = 7;

  // Sampler for shard 0 of 2, no shuffle, seed 0, uneven distribution.
  DistributedSampler m_sampler(num_samples, 2, 0, false, 0, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  std::unique_ptr<DataBuffer> buffer;
  ASSERT_EQ(m_sampler.GetNextSample(&buffer), Status::OK());

  // Collect every sample id from the returned buffer.
  TensorRow row;
  buffer->PopRow(&row);
  std::vector<uint64_t> sample_ids;
  for (const auto &tensor : row) {
    for (auto itr = tensor->begin<uint64_t>(); itr != tensor->end<uint64_t>(); ++itr) {
      sample_ids.push_back(*itr);
    }
  }
  ASSERT_EQ(sample_ids.size(), 4);

  // The next buffer must be the end-of-epoch marker.
  ASSERT_EQ(m_sampler.GetNextSample(&buffer), Status::OK());
  ASSERT_EQ(buffer->eoe(), true);
}
TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) {
  // 7 rows split across 2 shards with even_dist == false: the single
  // remainder row goes to shard 0, so shard 1 is expected to yield 3 ids.
  uint64_t num_samples = 7;

  // Sampler for shard 1 of 2, no shuffle, seed 0, uneven distribution.
  DistributedSampler m_sampler(num_samples, 2, 1, false, 0, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  std::unique_ptr<DataBuffer> buffer;
  ASSERT_EQ(m_sampler.GetNextSample(&buffer), Status::OK());

  // Collect every sample id from the returned buffer.
  TensorRow row;
  buffer->PopRow(&row);
  std::vector<uint64_t> sample_ids;
  for (const auto &tensor : row) {
    for (auto itr = tensor->begin<uint64_t>(); itr != tensor->end<uint64_t>(); ++itr) {
      sample_ids.push_back(*itr);
    }
  }
  ASSERT_EQ(sample_ids.size(), 3);

  // The next buffer must be the end-of-epoch marker.
  ASSERT_EQ(m_sampler.GetNextSample(&buffer), Status::OK());
  ASSERT_EQ(buffer->eoe(), true);
}
TEST_F(MindDataTestDistributedSampler, TestThreeShards) {
  // 2 rows split across 3 shards with even_dist == false: the two rows go
  // to shards 0 and 1, so shard 2 is expected to yield no ids at all.
  uint64_t num_samples = 2;

  // Sampler for shard 2 of 3, no shuffle, seed 0, uneven distribution.
  DistributedSampler m_sampler(num_samples, 3, 2, false, 0, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  std::unique_ptr<DataBuffer> buffer;
  ASSERT_EQ(m_sampler.GetNextSample(&buffer), Status::OK());

  // Collect every sample id from the returned buffer.
  TensorRow row;
  buffer->PopRow(&row);
  std::vector<uint64_t> sample_ids;
  for (const auto &tensor : row) {
    for (auto itr = tensor->begin<uint64_t>(); itr != tensor->end<uint64_t>(); ++itr) {
      sample_ids.push_back(*itr);
    }
  }
  ASSERT_EQ(sample_ids.size(), 0);

  // The next buffer must be the end-of-epoch marker.
  ASSERT_EQ(m_sampler.GetNextSample(&buffer), Status::OK());
  ASSERT_EQ(buffer->eoe(), true);
}