!183 Mindspore.dataset CPP sampler for GeneratorDataset

Merge pull request !183 from JunhanHu/cpp_sampler
This commit is contained in:
mindspore-ci-bot 2020-04-16 22:30:59 +08:00 committed by Gitee
commit cf026096a6
31 changed files with 432 additions and 127 deletions

View File

@ -517,7 +517,7 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<Datase
std::string key = py::str(arg.first); std::string key = py::str(arg.first);
py::handle value = arg.second; py::handle value = arg.second;
if (!value.is_none()) { if (!value.is_none()) {
if (key == "generator_function") { if (key == "source") {
py::object obj = py::cast(&value); py::object obj = py::cast(&value);
if (!py::isinstance<py::function>(obj)) { if (!py::isinstance<py::function>(obj)) {
std::string err_msg = "Error: generator is invalid or not set."; std::string err_msg = "Error: generator is invalid or not set.";

View File

@ -388,7 +388,16 @@ void bindTensorOps4(py::module *m) {
} }
void bindSamplerOps(py::module *m) { void bindSamplerOps(py::module *m) {
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler"); (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
.def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices", [](Sampler &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
});
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator"); (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler") (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")

View File

@ -491,6 +491,8 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
// return data as numpy, should return status // return data as numpy, should return status
Status Tensor::GetDataAsNumpy(py::array *data) { Status Tensor::GetDataAsNumpy(py::array *data) {
RETURN_UNEXPECTED_IF_NULL(data_);
RETURN_UNEXPECTED_IF_NULL(data);
if (type_ == DataType::DE_BOOL) { if (type_ == DataType::DE_BOOL) {
*data = py::array_t<bool>(shape_.AsVector(), reinterpret_cast<bool *>(data_)); *data = py::array_t<bool>(shape_.AsVector(), reinterpret_cast<bool *>(data_));
} else if (type_ == DataType::DE_INT8) { } else if (type_ == DataType::DE_INT8) {

View File

@ -100,7 +100,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() {
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1)));
TaskManager::FindMe()->Post(); TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(ParseImageAttrInfo()); RETURN_IF_NOT_OK(ParseImageAttrInfo());
RETURN_IF_NOT_OK(sampler_->Init(this)); RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK(); return Status::OK();
} }

View File

@ -240,7 +240,7 @@ Status CifarOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status CifarOp::InitSampler() { Status CifarOp::InitSampler() {
RETURN_IF_NOT_OK(sampler_->Init(this)); RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK(); return Status::OK();
} }

View File

@ -258,7 +258,7 @@ Status ImageFolderOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status ImageFolderOp::InitSampler() { Status ImageFolderOp::InitSampler() {
RETURN_IF_NOT_OK(sampler_->Init(this)); RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK(); return Status::OK();
} }

View File

@ -254,7 +254,7 @@ Status ManifestOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status ManifestOp::InitSampler() { Status ManifestOp::InitSampler() {
RETURN_IF_NOT_OK(sampler_->Init(this)); RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK(); return Status::OK();
} }

View File

@ -205,7 +205,7 @@ Status MnistOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status MnistOp::InitSampler() { Status MnistOp::InitSampler() {
RETURN_IF_NOT_OK(sampler_->Init(this)); RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK(); return Status::OK();
} }

View File

@ -31,8 +31,9 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
num_devices_(num_dev), num_devices_(num_dev),
shuffle_(shuffle) {} shuffle_(shuffle) {}
Status DistributedSampler::Init(const RandomAccessOp *op) { Status DistributedSampler::InitSampler() {
RETURN_IF_NOT_OK(Sampler::Init(op)); CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
"fail to init DistributedSampler"); "fail to init DistributedSampler");
rnd_.seed(seed_++); rnd_.seed(seed_++);

View File

@ -41,10 +41,8 @@ class DistributedSampler : public Sampler {
// @return - The error code return // @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
// first handshake between StorageOp and Sampler // Init sampler, called by base class or python
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() Status InitSampler() override;
// @return
Status Init(const RandomAccessOp *) override;
// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return

View File

@ -28,9 +28,7 @@ PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
num_pk_samples_(0), num_pk_samples_(0),
samples_per_class_(val) {} samples_per_class_(val) {}
Status PKSampler::Init(const RandomAccessOp *op) { Status PKSampler::InitSampler() {
RETURN_UNEXPECTED_IF_NULL(op);
RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
labels_.reserve(label_to_ids_.size()); labels_.reserve(label_to_ids_.size());
for (const auto &pair : label_to_ids_) { for (const auto &pair : label_to_ids_) {
if (pair.second.empty() == false) { if (pair.second.empty() == false) {
@ -79,5 +77,13 @@ Status PKSampler::Reset() {
rnd_.seed(seed_++); rnd_.seed(seed_++);
return Status::OK(); return Status::OK();
} }
Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
RETURN_UNEXPECTED_IF_NULL(op);
RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
RETURN_IF_NOT_OK(InitSampler());
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -45,7 +45,10 @@ class PKSampler : public Sampler { // NOT YET FINISHED
// first handshake between StorageOp and Sampler // first handshake between StorageOp and Sampler
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// @return // @return
Status Init(const RandomAccessOp *op) override; Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
// init sampler, to be called by python or Handshake
Status InitSampler() override;
// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return

View File

@ -49,10 +49,9 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
return Status::OK(); return Status::OK();
} }
Status RandomSampler::Init(const RandomAccessOp *op) { Status RandomSampler::InitSampler() {
RETURN_IF_NOT_OK(Sampler::Init(op));
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "Fail to init RandomSampler"); CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
if (replacement_ == false) { if (replacement_ == false) {
shuffled_ids_.reserve(num_rows_); shuffled_ids_.reserve(num_rows_);

View File

@ -42,10 +42,8 @@ class RandomSampler : public Sampler {
// @return - The error code return // @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
// first handshake between StorageOp and Sampler // meant to be called by base class or python
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() Status InitSampler() override;
// @return
Status Init(const RandomAccessOp *op) override;
// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return

View File

@ -20,12 +20,13 @@ namespace dataset {
Sampler::Sampler(int64_t samples_per_buffer) Sampler::Sampler(int64_t samples_per_buffer)
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
Status Sampler::Init(const RandomAccessOp *op) { Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr && samples_per_buffer_ > 0, "Fail to init Sampler()\n"); CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_)); RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
// It's up to the derived class to check the validity of the two args // It's up to the derived class to check the validity of the two args
// Because some sampler only needs one of the arg (weighted_random_sampler) // Because some sampler only needs one of the arg (weighted_random_sampler)
RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback
return Status::OK(); return Status::OK();
} }
@ -42,5 +43,49 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
(void)(*sample_ids)->StartAddr(); // allocate memory in case user forgets! (void)(*sample_ids)->StartAddr(); // allocate memory in case user forgets!
return Status::OK(); return Status::OK();
} }
Status Sampler::GetAllIdsThenReset(py::array *data) {
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> sample_ids;
// check samples_per_buffer is properly set and doesn't overflow
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid");
// A call to derived class to get sample ids wrapped inside a buffer
RETURN_IF_NOT_OK(GetNextBuffer(&db));
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0));
// check this buffer is not a ctrl buffer
CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received");
{
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
}
try {
RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
} catch (const std::runtime_error &e) {
return Status(StatusCode::kPyFuncException, e.what());
}
}
// perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch
RETURN_IF_NOT_OK(GetNextBuffer(&db));
CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
// Reset Sampler since this is the end of the epoch
RETURN_IF_NOT_OK(Reset());
return Status::OK();
}
Status Sampler::SetNumSamples(int64_t num_samples) {
CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0");
num_samples_ = num_samples;
return Status::OK();
}
Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0");
num_rows_ = num_rows;
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -78,14 +78,26 @@ class Sampler : public DatasetOp {
// @return - The error code return // @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0; Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0;
// return all ids in one epoch as a numpy array, then call reset
Status GetAllIdsThenReset(py::array *data);
// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return
Status Reset() override = 0; Status Reset() override = 0;
// first handshake between StorageOp and Sampler. Base class init will call both GetNumRows and GetNumSamples // setter function for num_rows_
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() Status SetNumRowsInDataset(int64_t num_rows);
// setter function for num_samples_
Status SetNumSamples(int64_t num_samples);
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
// @return // @return
virtual Status Init(const RandomAccessOp *op); virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
// initialize sampler and perform checks on certain vars
virtual Status InitSampler() { return Status::OK(); }
// Not meant to be called // Not meant to be called
// @return // @return

View File

@ -41,9 +41,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
return Status::OK(); return Status::OK();
} }
Status SequentialSampler::Init(const RandomAccessOp *op) { Status SequentialSampler::InitSampler() {
RETURN_UNEXPECTED_IF_NULL(op);
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK(); return Status::OK();

View File

@ -32,10 +32,8 @@ class SequentialSampler : public Sampler {
// Destructor. // Destructor.
~SequentialSampler() = default; ~SequentialSampler() = default;
// Initialize the sampler. // init sampler, called by python
// @param op Status InitSampler() override;
// @return Status
Status Init(const RandomAccessOp *op) override;
// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return // @return - The error code return

View File

@ -31,9 +31,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
: Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} : Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
// Initialized this Sampler. // Initialized this Sampler.
Status SubsetRandomSampler::Init(const RandomAccessOp *op) { Status SubsetRandomSampler::InitSampler() {
// Calling base class init. CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
RETURN_IF_NOT_OK(Sampler::Init(op));
// Initialize random generator with seed from config manager // Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed()); rand_gen_.seed(GetSeed());

View File

@ -38,9 +38,8 @@ class SubsetRandomSampler : public Sampler {
~SubsetRandomSampler() = default; ~SubsetRandomSampler() = default;
// Initialize the sampler. // Initialize the sampler.
// @param op (Not used in this sampler)
// @return Status // @return Status
Status Init(const RandomAccessOp *op) override; Status InitSampler() override;
// Reset the internal variable to the initial state and reshuffle the indices. // Reset the internal variable to the initial state and reshuffle the indices.
// @return Status // @return Status

View File

@ -29,21 +29,21 @@ namespace dataset {
// Constructor. // Constructor.
WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement, WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement,
int64_t samples_per_buffer) int64_t samples_per_buffer)
: Sampler(samples_per_buffer), weights_(weights), replacement_(replacement), sample_id_(0), buffer_id_(0) { : Sampler(samples_per_buffer),
num_samples_ = num_samples; // this variable is defined in base class sampler weights_(weights),
} replacement_(replacement),
sample_id_(0),
buffer_id_(0),
user_num_samples_(num_samples) {}
// Initialized this Sampler. // Initialized this Sampler.
Status WeightedRandomSampler::Init(const RandomAccessOp *op) { Status WeightedRandomSampler::InitSampler() {
RETURN_UNEXPECTED_IF_NULL(op); CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
// Initialize random generator with seed from config manager // Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed()); rand_gen_.seed(GetSeed());
samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init WeightedRandomSampler");
if (!replacement_) { if (!replacement_) {
exp_dist_ = std::make_unique<std::exponential_distribution<>>(1); exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
@ -65,8 +65,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
} }
// Partial sort the first `numSamples` elements. // Partial sort the first `numSamples` elements.
std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end()); std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end());
for (int64_t i = 0; i < num_samples_; i++) { for (int64_t i = 0; i < user_num_samples_; i++) {
onepass_ids_.push_back(val_idx[i].second); onepass_ids_.push_back(val_idx[i].second);
} }
} }
@ -91,11 +91,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); "number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
} }
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { if (!replacement_ && (weights_.size() < static_cast<size_t>(user_num_samples_))) {
RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
} }
if (sample_id_ == num_samples_) { if (sample_id_ == user_num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else { } else {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
@ -103,8 +103,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
int64_t last_id = sample_id_ + samples_per_buffer_; int64_t last_id = sample_id_ + samples_per_buffer_;
// Handling the return all samples at once, and when last draw is not a full batch. // Handling the return all samples at once, and when last draw is not a full batch.
if (last_id > num_samples_) { if (last_id > user_num_samples_) {
last_id = num_samples_; last_id = user_num_samples_;
} }
// Allocate tensor. // Allocate tensor.

View File

@ -43,7 +43,7 @@ class WeightedRandomSampler : public Sampler {
// Initialize the sampler. // Initialize the sampler.
// @param op (Not used in this sampler) // @param op (Not used in this sampler)
// @return Status // @return Status
Status Init(const RandomAccessOp *op) override; Status InitSampler() override;
// Reset the internal variable to the initial state and reshuffle the indices. // Reset the internal variable to the initial state and reshuffle the indices.
Status Reset() override; Status Reset() override;
@ -69,6 +69,9 @@ class WeightedRandomSampler : public Sampler {
// Random engine and device // Random engine and device
std::mt19937 rand_gen_; std::mt19937 rand_gen_;
// num_samples from user
int64_t user_num_samples_;
// Discrete distribution for generating weighted random numbers with replacement. // Discrete distribution for generating weighted random numbers with replacement.
std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_; std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;

View File

@ -220,7 +220,7 @@ Status VOCOp::ParseImageIds() {
} }
Status VOCOp::InitSampler() { Status VOCOp::InitSampler() {
RETURN_IF_NOT_OK(sampler_->Init(this)); RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK(); return Status::OK();
} }

View File

@ -1758,14 +1758,70 @@ class MindDataset(SourceDataset):
return num_rows return num_rows
def ds_fn(dataset): def _iter_fn(dataset, num_samples):
for val in dataset: """
# convert output tensors to ndarrays Generator function wrapper for iterable dataset
yield tuple([np.array(x) for x in val]) """
if num_samples is not None:
ds_iter = iter(dataset)
for _ in range(num_samples):
try:
val = next(ds_iter)
except StopIteration:
return
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
else:
for val in dataset:
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
def sampler_fn(sampler, dataset): def _generator_fn(generator, num_samples):
for i in sampler: """
Generator function wrapper for generator function dataset
"""
if num_samples is not None:
gen_iter = generator()
for _ in range(num_samples):
try:
val = next(gen_iter)
except StopIteration:
return
yield val
else:
gen_iter = generator()
for val in gen_iter:
yield val
def _py_sampler_fn(sampler, num_samples, dataset):
"""
Generator function wrapper for mappable dataset with python sampler
"""
if num_samples is not None:
sampler_iter = iter(sampler)
for _ in range(num_samples):
try:
idx = next(sampler_iter)
except StopIteration:
return
val = dataset[idx]
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
else:
for i in sampler:
val = dataset[i]
# convert output tensors to ndarrays
yield tuple([np.array(x) for x in val])
def _cpp_sampler_fn(sampler, dataset):
"""
Generator function wrapper for mappable dataset with cpp sampler
"""
indices = sampler.get_indices()
for i in indices:
val = dataset[i] val = dataset[i]
# convert output tensors to ndarrays # convert output tensors to ndarrays
yield tuple([np.array(x) for x in val]) yield tuple([np.array(x) for x in val])
@ -1773,49 +1829,122 @@ def sampler_fn(sampler, dataset):
class GeneratorDataset(SourceDataset): class GeneratorDataset(SourceDataset):
""" """
A source dataset that generate data from calling generator function each epoch. A source dataset that generate data from python by invoking python data source each epoch.
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
below shows what input args are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50
:header-rows: 1
* - Parameter 'sampler'
- Parameter 'shuffle'
- Expected Order Behavior
* - None
- None
- random order
* - None
- True
- random order
* - None
- False
- sequential order
* - Sampler object
- None
- order defined by sampler
* - Sampler object
- True
- not allowed
* - Sampler object
- False
- not allowed
Args: Args:
generator_function (callable): source (Callable/Iterable/Random Accessible):
A callable object that returns an Generator object that supports the iter() protocol. A generator callable object, an iterable python object or a random accessible python object.
Generator object is required to return a tuple of numpy array as a row of the dataset on next(). Callable source is required to return a tuple of numpy array as a row of the dataset on source().next().
Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next().
Random accessible source is required to return a tuple of numpy array as a row of the dataset on
source[idx].
column_names (list[str]): List of column names of the dataset. column_names (list[str]): List of column names of the dataset.
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None). column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
If provided, sanity check will be performed on generator output. If provided, sanity check will be performed on generator output.
prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None). schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
sampler (Sampler, optional): Object used to choose samples from the dataset (default=None). If the schema is not provided, the meta data from column_names and column_types is considered the schema.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, all images).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
(default=None, expected order behavior shown in the table).
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
required.
(default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required.
Examples: Examples:
>>> import mindspore.dataset as ds >>> import mindspore.dataengine as de
>>> # 1) generator function that generates multi-dimensional data >>> # 1) Multidimensional generator function as callable input
>>> def generator_md(): >>> def generator_md():
>>> for i in range(64): >>> for i in range(64):
>>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),) >>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
>>> # create multi_dimension_generator_dataset with GeneratorMD() and column name "multi_dimensional_data" >>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
>>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"]) >>> multi_dimension_generator_dataset = de.GeneratorDataset(generator_md, ["multi_dimensional_data"])
>>> # 2) generator function that generates multi-columns data >>> # 2) Multi-column generator function as callable input
>>> def generator_mc(maxid = 64): >>> def generator_mc(maxid = 64):
>>> for i in range(maxid): >>> for i in range(maxid):
>>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) >>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
>>> # create multi_column_generator_dataset with GeneratorMC() and column names "col1" and "col2" >>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
>>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1, col2"]) >>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1, col2"])
>>> # 3) Iterable dataset as iterable input
>>> class MyIterable():
>>> def __iter__(self):
>>> return # User implementation
>>> # create iterable_generator_dataset with MyIterable object
>>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"])
>>> # 4) Random accessible dataset as Random accessible input
>>> class MyRA():
>>> def __getitem__(self, index):
>>> return # User implementation
>>> # create ra_generator_dataset with MyRA object
>>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"])
>>> # List/Dict/Tuple is also random accessible
>>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1)), (np.array(2))], ["col1"])
>>> # 5) Built-in Sampler
>>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
>>>
""" """
@check_generatordataset @check_generatordataset
def __init__(self, generator_function, column_names, column_types=None, prefetch_size=None, sampler=None): def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1,
super().__init__(1) shuffle=None, sampler=None, num_shards=None, shard_id=None):
if sampler is not None: super().__init__(num_parallel_workers)
self.generator_function = (lambda: sampler_fn(sampler, generator_function)) self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
if self.sampler is not None and hasattr(source, "__getitem__"):
if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
samplers.WeightedRandomSampler)):
if num_samples is None:
num_samples = len(source)
sampler_instance = self.sampler.create()
sampler_instance.set_num_rows(len(source))
sampler_instance.set_num_samples(num_samples)
sampler_instance.initialize()
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
else:
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
else: else:
try: try:
# test to see if generator_function is iterable iter(source)
iter(generator_function)
except TypeError: except TypeError:
# generator_function was not iterable, assume it is a function # Use generator function if input callable
self.generator_function = generator_function self.source = (lambda: _generator_fn(source, num_samples))
else: else:
# generator_function was iterable, build a function around it # Use iterator function if input is iterable
self.generator_function = (lambda: ds_fn(generator_function)) # Random accessible input is also iterable
self.source = (lambda: _iter_fn(source, num_samples))
self.column_names = column_names self.column_names = column_names
@ -1823,17 +1952,12 @@ class GeneratorDataset(SourceDataset):
self.column_types = mstypelist_to_detypelist(column_types) self.column_types = mstypelist_to_detypelist(column_types)
else: else:
self.column_types = column_types self.column_types = column_types
self.distribution = ""
self.prefetch_size = prefetch_size
self.sampler = sampler
def get_args(self): def get_args(self):
args = super().get_args() args = super().get_args()
args["generator_function"] = self.generator_function args["source"] = self.source
args["column_names"] = self.column_names args["column_names"] = self.column_names
args["column_types"] = self.column_types args["column_types"] = self.column_types
args["prefetch_size"] = self.prefetch_size
args["sampler"] = self.sampler
return args return args
def get_dataset_size(self): def get_dataset_size(self):

View File

@ -20,7 +20,6 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
class DistributedSampler(): class DistributedSampler():
""" """
Sampler that access a shard of the dataset. Sampler that access a shard of the dataset.

View File

@ -543,28 +543,48 @@ def check_generatordataset(method):
def new_method(*args, **kwargs): def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs) param_dict = make_param_dict(method, args, kwargs)
nreq_param_int = ['prefetch_size']
nreq_param_list = ['column_names', 'column_types']
# check generator_function; required argument # check generator_function; required argument
generator_function = param_dict.get('generator_function') source = param_dict.get('source')
if generator_function is None: if source is None:
raise ValueError("generator_function is not provided.") raise ValueError("source is not provided.")
if not callable(source):
try:
iter(source)
except TypeError:
raise TypeError("source should be callable, iterable or random accessible")
# check column_names; required argument # check column_names; required argument
column_names = param_dict.get('column_names') column_names = param_dict.get('column_names')
if column_names is None: if column_names is None:
raise ValueError("column_names is not provided.") raise ValueError("column_names is not provided.")
# check prefetch_size range # check optional argument
prefetch_size = param_dict.get('prefetch_size') nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
if prefetch_size is not None and (prefetch_size <= 0 or prefetch_size > 1024):
raise ValueError("prefetch_size exceeds the boundary.")
check_param_type(nreq_param_int, param_dict, int) check_param_type(nreq_param_int, param_dict, int)
nreq_param_list = ["column_types"]
check_param_type(nreq_param_list, param_dict, list) check_param_type(nreq_param_list, param_dict, list)
num_shards = param_dict.get("num_shards")
shard_id = param_dict.get("shard_id")
if (num_shards is None) != (shard_id is None):
# These two parameters appear together.
raise ValueError("num_shards and shard_id need to be passed in together")
if num_shards is not None:
if shard_id >= num_shards:
raise ValueError("shard_id should be less than num_shards")
sampler = param_dict.get("sampler")
if sampler is not None:
if isinstance(sampler, samplers.PKSampler):
raise ValueError("PKSampler is not supported by GeneratorDataset")
if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
samplers.WeightedRandomSampler)):
try:
iter(sampler)
except TypeError:
raise TypeError("sampler should be either iterable or from dataset.samplers.py")
return method(*args, **kwargs) return method(*args, **kwargs)
return new_method return new_method

View File

@ -75,7 +75,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true)); std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true));
sampler->Init(&mock); sampler->HandshakeRandomAccessOp(&mock);
sampler->GetNextBuffer(&db); sampler->GetNextBuffer(&db);
db->GetTensor(&tensor, 0, 0); db->GetTensor(&tensor, 0, 0);
MS_LOG(DEBUG) << (*tensor); MS_LOG(DEBUG) << (*tensor);
@ -95,7 +95,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3); std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
sampler->Init(&mock); sampler->HandshakeRandomAccessOp(&mock);
sampler->GetNextBuffer(&db); sampler->GetNextBuffer(&db);
db->GetTensor(&tensor, 0, 0); db->GetTensor(&tensor, 0, 0);
EXPECT_TRUE((*tensor) == (*label1)); EXPECT_TRUE((*tensor) == (*label1));

View File

@ -52,8 +52,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
std::unordered_set<int64_t> in_set(in.begin(), in.end()); std::unordered_set<int64_t> in_set(in.begin(), in.end());
SubsetRandomSampler sampler(in); SubsetRandomSampler sampler(in);
DummyRandomAccessOp dummy_random_access_op(5); DummyRandomAccessOp dummyRandomAccessOp(5);
sampler.Init(&dummy_random_access_op); sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;
@ -80,8 +80,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
std::vector<int64_t> input(total_samples, 1); std::vector<int64_t> input(total_samples, 1);
SubsetRandomSampler sampler(input, samples_per_buffer); SubsetRandomSampler sampler(input, samples_per_buffer);
DummyRandomAccessOp dummy_random_access_op(total_samples); DummyRandomAccessOp dummyRandomAccessOp(total_samples);
sampler.Init(&dummy_random_access_op); sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;
@ -111,8 +111,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
std::unordered_set<int64_t> in_set(in.begin(), in.end()); std::unordered_set<int64_t> in_set(in.begin(), in.end());
SubsetRandomSampler sampler(in); SubsetRandomSampler sampler(in);
DummyRandomAccessOp dummy_random_access_op(5); DummyRandomAccessOp dummyRandomAccessOp(5);
sampler.Init(&dummy_random_access_op); sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;

View File

@ -60,8 +60,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
// create sampler with replacement = true // create sampler with replacement = true
WeightedRandomSampler m_sampler(weights, num_samples, true); WeightedRandomSampler m_sampler(weights, num_samples, true);
DummyRandomAccessOp dummy_random_access_op(total_samples); DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.Init(&dummy_random_access_op); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;
@ -90,8 +90,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
// create sampler with replacement = replacement // create sampler with replacement = replacement
WeightedRandomSampler m_sampler(weights, num_samples, false); WeightedRandomSampler m_sampler(weights, num_samples, false);
DummyRandomAccessOp dummy_random_access_op(total_samples); DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.Init(&dummy_random_access_op); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;
@ -126,8 +126,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
// create sampler with replacement = replacement // create sampler with replacement = replacement
WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer); WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer);
DummyRandomAccessOp dummy_random_access_op(total_samples); DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.Init(&dummy_random_access_op); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;
@ -162,8 +162,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
// create sampler with replacement = replacement // create sampler with replacement = replacement
WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer); WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer);
DummyRandomAccessOp dummy_random_access_op(total_samples); DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.Init(&dummy_random_access_op); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;
@ -203,8 +203,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
// create sampler with replacement = true // create sampler with replacement = true
WeightedRandomSampler m_sampler(weights, num_samples, true); WeightedRandomSampler m_sampler(weights, num_samples, true);
DummyRandomAccessOp dummy_random_access_op(total_samples); DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.Init(&dummy_random_access_op); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;
@ -248,8 +248,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
// create sampler with replacement = true // create sampler with replacement = true
WeightedRandomSampler m_sampler(weights, num_samples, false); WeightedRandomSampler m_sampler(weights, num_samples, false);
DummyRandomAccessOp dummy_random_access_op(total_samples); DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.Init(&dummy_random_access_op); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
std::unique_ptr<DataBuffer> db; std::unique_ptr<DataBuffer> db;
TensorRow row; TensorRow row;

View File

@ -439,6 +439,74 @@ def test_case_error_4():
assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value) assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
def test_sequential_sampler():
source = [(np.array([x]),) for x in range(64)]
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
i = 0
for data in ds1.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(data["data"], golden)
i = i + 1
def test_random_sampler():
source = [(np.array([x]),) for x in range(64)]
ds1 = ds.GeneratorDataset(source, ["data"], shuffle = True)
for data in ds1.create_dict_iterator(): # each data is a dictionary
pass
def test_distributed_sampler():
source = [(np.array([x]),) for x in range(64)]
for sid in range(8):
ds1 = ds.GeneratorDataset(source, ["data"], shuffle = False, num_shards=8, shard_id=sid)
i = sid
for data in ds1.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(data["data"], golden)
i = i + 8
def test_num_samples():
source = [(np.array([x]),) for x in range(64)]
num_samples = 32
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples = num_samples)
ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples = num_samples)
ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples = num_samples)
count = 0
for _ in ds1.create_dict_iterator():
count = count + 1
assert count == num_samples
count = 0
for _ in ds2.create_dict_iterator():
count = count + 1
assert count == num_samples
count = 0
for _ in ds3.create_dict_iterator():
count = count + 1
assert count == num_samples
def test_num_samples_underflow():
source = [(np.array([x]),) for x in range(64)]
num_samples = 256
ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples = num_samples)
ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples = num_samples)
count = 0
for _ in ds2.create_dict_iterator():
count = count + 1
assert count == 64
count = 0
for _ in ds3.create_dict_iterator():
count = count + 1
assert count == 64
if __name__ == "__main__": if __name__ == "__main__":
test_case_0() test_case_0()
test_case_1() test_case_1()
@ -458,3 +526,6 @@ if __name__ == "__main__":
test_case_error_2() test_case_error_2()
test_case_error_3() test_case_error_3()
test_case_error_4() test_case_error_4()
test_sequential_sampler()
test_distributed_sampler()
test_random_sampler()

View File

@ -87,7 +87,28 @@ def test_random_sampler_multi_iter(print_res=False):
test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4, 5]) test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4, 5])
def test_sampler_py_api():
sampler = ds.SequentialSampler().create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
sampler = ds.RandomSampler().create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
sampler = ds.DistributedSampler(8, 4).create()
sampler.set_num_rows(128)
sampler.set_num_samples(64)
sampler.initialize()
sampler.get_indices()
if __name__ == '__main__': if __name__ == '__main__':
test_sequential_sampler(True) test_sequential_sampler(True)
test_random_sampler(True) test_random_sampler(True)
test_random_sampler_multi_iter(True) test_random_sampler_multi_iter(True)
test_sampler_py_api()