Fixing issues in distributed sampler
Added Distributed sampler option Fix style
This commit is contained in:
parent
1165b27f41
commit
8c018da468
|
@ -31,8 +31,8 @@ SamplerObj::SamplerObj() {}
|
|||
|
||||
/// Function to create a Distributed Sampler.
|
||||
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
|
||||
int64_t num_samples, uint32_t seed) {
|
||||
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed);
|
||||
int64_t num_samples, uint32_t seed, bool even_dist) {
|
||||
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, even_dist);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
return nullptr;
|
||||
|
@ -95,8 +95,13 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
|
|||
|
||||
// DistributedSampler
|
||||
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
|
||||
uint32_t seed)
|
||||
: num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {}
|
||||
uint32_t seed, bool even_dist)
|
||||
: num_shards_(num_shards),
|
||||
shard_id_(shard_id),
|
||||
shuffle_(shuffle),
|
||||
num_samples_(num_samples),
|
||||
seed_(seed),
|
||||
even_dist_(even_dist) {}
|
||||
|
||||
bool DistributedSamplerObj::ValidateParams() {
|
||||
if (num_shards_ <= 0) {
|
||||
|
@ -118,7 +123,8 @@ bool DistributedSamplerObj::ValidateParams() {
|
|||
}
|
||||
|
||||
std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
|
||||
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_);
|
||||
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
|
||||
even_dist_);
|
||||
}
|
||||
|
||||
// PKSampler
|
||||
|
|
|
@ -24,13 +24,14 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||
uint32_t seed)
|
||||
uint32_t seed, bool even_dist)
|
||||
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
|
||||
cnt_(0),
|
||||
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
|
||||
device_id_(dev_id),
|
||||
num_devices_(num_dev),
|
||||
shuffle_(shuffle) {}
|
||||
shuffle_(shuffle),
|
||||
even_dist_(even_dist) {}
|
||||
|
||||
Status DistributedSampler::InitSampler() {
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
|
@ -43,7 +44,15 @@ Status DistributedSampler::InitSampler() {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
|
||||
"fail to init DistributedSampler");
|
||||
rnd_.seed(seed_++);
|
||||
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
|
||||
if (even_dist_) {
|
||||
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
|
||||
} else {
|
||||
int64_t mod = num_rows_ % num_devices_;
|
||||
samples_per_buffer_ = num_rows_ / num_devices_;
|
||||
if (mod > device_id_) {
|
||||
samples_per_buffer_++;
|
||||
}
|
||||
}
|
||||
samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
|
||||
if (shuffle_ == true) {
|
||||
shuffle_vec_.reserve(num_rows_);
|
||||
|
|
|
@ -27,26 +27,32 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
class DistributedSampler : public Sampler {
|
||||
public:
|
||||
// @param num_samples
|
||||
// @param int64_t num_dev
|
||||
// @param int64_t dev_id
|
||||
// @param bool shuffle
|
||||
/// \brief Constructor
|
||||
/// \param[in] num_samples The total number of rows in the dataset
|
||||
/// \param[in] num_dev Total number of shards for the distributed sampler
|
||||
/// \param[in] dev_id Device id of the shard
|
||||
/// \param[in] shuffle Option to shuffle
|
||||
/// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
|
||||
/// result in different samples being picked
|
||||
/// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
|
||||
/// This option is not exposed in the python API. Current behavior is that the remainder will always
|
||||
/// be handled by the first n shards, n being the corresponding device id.
|
||||
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||
uint32_t seed = std::numeric_limits<uint32_t>::max());
|
||||
uint32_t seed = std::numeric_limits<uint32_t>::max(), bool even_dist = true);
|
||||
|
||||
// default destructor
|
||||
/// \brief default destructor
|
||||
~DistributedSampler() = default;
|
||||
|
||||
// @param std::unique_ptr<DataBuffer> * pBuffer
|
||||
// @param int32_t workerId
|
||||
// @return - The error code return
|
||||
/// \param std::unique_ptr<DataBuffer> * pBuffer
|
||||
/// \param int32_t workerId
|
||||
/// \return Status code
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// Init sampler, called by base class or python
|
||||
/// Init sampler, called by base class or python
|
||||
Status InitSampler() override;
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
/// \brief for next epoch of sampleIds
|
||||
/// \return Status code
|
||||
Status ResetSampler() override;
|
||||
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
@ -59,6 +65,7 @@ class DistributedSampler : public Sampler {
|
|||
bool shuffle_;
|
||||
std::mt19937 rnd_;
|
||||
std::vector<int64_t> shuffle_vec_;
|
||||
bool even_dist_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -52,9 +52,12 @@ class WeightedRandomSamplerObj;
|
|||
/// \param[in] shuffle - If true, the indices are shuffled.
|
||||
/// \param[in] num_samples - The number of samples to draw (default to all elements).
|
||||
/// \param[in] seed - The seed in use when shuffle is true.
|
||||
/// \param[in] even_dist - If true, each shard would return the same number of rows (default to true).
|
||||
/// If false the total rows returned by all the shards would not have overlap.
|
||||
/// \return Shared pointer to the current Sampler.
|
||||
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true,
|
||||
int64_t num_samples = 0, uint32_t seed = 1);
|
||||
int64_t num_samples = 0, uint32_t seed = 1,
|
||||
bool even_dist = true);
|
||||
|
||||
/// Function to create a PK Sampler.
|
||||
/// \notes Samples K elements for each P class in the dataset.
|
||||
|
@ -100,7 +103,8 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vecto
|
|||
/* ####################################### Derived Sampler classes ################################# */
|
||||
class DistributedSamplerObj : public SamplerObj {
|
||||
public:
|
||||
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed);
|
||||
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
|
||||
bool even_dist);
|
||||
|
||||
~DistributedSamplerObj() = default;
|
||||
|
||||
|
@ -114,6 +118,7 @@ class DistributedSamplerObj : public SamplerObj {
|
|||
bool shuffle_;
|
||||
int64_t num_samples_;
|
||||
uint32_t seed_;
|
||||
bool even_dist_;
|
||||
};
|
||||
|
||||
class PKSamplerObj : public SamplerObj {
|
||||
|
|
|
@ -92,7 +92,9 @@ SET(DE_UT_SRCS
|
|||
tensor_op_fusion_pass_test.cc
|
||||
sliding_window_op_test.cc
|
||||
epoch_ctrl_op_test.cc
|
||||
swap_red_blue_test.cc
|
||||
sentence_piece_vocab_op_test.cc
|
||||
swap_red_blue_test.cc
|
||||
distributed_sampler_test.cc
|
||||
)
|
||||
|
||||
if (ENABLE_PYTHON)
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
class MindDataTestDistributedSampler : public UT::Common {
|
||||
public:
|
||||
class DummyRandomAccessOp : public RandomAccessOp {
|
||||
public:
|
||||
DummyRandomAccessOp(uint64_t num_rows) {
|
||||
// row count is in base class as protected member
|
||||
// GetNumRowsInDataset does not need an override, the default from base class is fine.
|
||||
num_rows_ = num_rows;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) {
|
||||
// num samples to draw.
|
||||
uint64_t num_samples = 7;
|
||||
|
||||
// create sampler with replacement = true
|
||||
DistributedSampler m_sampler(num_samples, 2, 0, false, 0, false);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
out.push_back(*it);
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(4, out.size());
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) {
|
||||
// num samples to draw.
|
||||
uint64_t num_samples = 7;
|
||||
|
||||
// create sampler with replacement = true
|
||||
DistributedSampler m_sampler(num_samples, 2, 1, false, 0, false);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
out.push_back(*it);
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(3, out.size());
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestDistributedSampler, TestThreeShards) {
|
||||
// num samples to draw.
|
||||
uint64_t num_samples = 2;
|
||||
|
||||
// create sampler with replacement = true
|
||||
DistributedSampler m_sampler(num_samples, 3, 2, false, 0, false);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(num_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
out.push_back(*it);
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(0, out.size());
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
Loading…
Reference in New Issue