forked from mindspore-Ecosystem/mindspore
general split case done, chaining sampler (basic case) is working
Implementation complete and tested, except for repeatable shuffling; the most basic/typical split use cases are covered. Squashed follow-ups: more tests and bug fixes, cleanup, a warning/error for shard/shuffle before split, code review comments addressed, and CI fixes.
parent bcfaff97f9
commit 71e8bb1960
@@ -364,6 +364,18 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
    std::string err_msg = "Error: Shuffle buffer size is missing";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
+
+  // Optional arguments
+  for (auto arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "reshuffle_each_epoch") {
+        (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"]));
+      }
+    }
+  }

  std::shared_ptr<ShuffleOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
@@ -49,6 +49,7 @@
 #include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/python_sampler.h"

@@ -411,11 +412,14 @@ void bindSamplerOps(py::module *m) {
    .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
    .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
    .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
-   .def("get_indices", [](Sampler &self) {
-     py::array ret;
-     THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
-     return ret;
-   });
+   .def("get_indices",
+        [](Sampler &self) {
+          py::array ret;
+          THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
+          return ret;
+        })
+   .def("add_child",
+        [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });

  (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");

@@ -427,12 +431,16 @@ void bindSamplerOps(py::module *m) {
    .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));

  (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
-   .def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples"))
-   .def(py::init<bool>(), py::arg("replacement"));
+   .def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
+        py::arg("num_samples"))
+   .def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));

  (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
    .def(py::init<>());

+ (void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
+   .def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
+
  (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
    .def(py::init<std::vector<int64_t>>(), py::arg("indices"));
@@ -8,5 +8,6 @@ add_library(engine-datasetops-source-sampler OBJECT
        sampler.cc
        sequential_sampler.cc
        subset_random_sampler.cc
+       subset_sampler.cc
        weighted_random_sampler.cc
        )
@@ -55,13 +55,27 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
  } else if (cnt_ == samples_per_buffer_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+   }
+
    (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> sample_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
    int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
    while (cnt_ < samples_per_buffer_) {
-     int64_t next_id = (num_devices_ * (cnt_++) + device_id_) % num_rows_;
-     *(id_ptr++) = shuffle_ ? shuffle_vec_[static_cast<size_t>(next_id)] : next_id;
+     int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
+     if (shuffle_) {
+       sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
+     }
+
+     if (HasChildSampler()) {
+       RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+     }
+
+     *id_ptr = sampled_id;
+     id_ptr++;
+     cnt_++;
    }
    TensorRow row(1, sample_ids);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
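As a sanity check, the round-robin id assignment in the loop above can be modeled in a few lines of Python (illustrative values only, not part of the commit):

```python
# Model of DistributedSampler's id assignment (assumed semantics).
num_devices, num_rows = 4, 10
for device_id in range(num_devices):
    ids = [(num_devices * cnt + device_id) % num_rows for cnt in range(3)]
    print(device_id, ids)  # 0 -> [0, 4, 8], 1 -> [1, 5, 9], 2 -> [2, 6, 0], 3 -> [3, 7, 1]
```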
@@ -72,11 +86,29 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
Status DistributedSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
  cnt_ = 0;
- rnd_.seed(seed_++);

  if (shuffle_ == true) {
+   rnd_.seed(seed_);
+   seed_++;
    std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
  }

+ if (HasChildSampler()) {
+   RETURN_IF_NOT_OK(child_[0]->Reset());
+ }
+
  return Status::OK();
}

+void DistributedSampler::Print(std::ostream &out, bool show_all) const {
+  out << "(sampler): DistributedSampler\n";
+  if (show_all) {
+    out << "seed_: " << seed_ << '\n';
+    out << "device_id_: " << device_id_ << '\n';
+    out << "num_devices_: " << num_devices_ << '\n';
+    out << "shuffle_: " << shuffle_ << '\n';
+  }
+}

}  // namespace dataset
}  // namespace mindspore
@@ -48,6 +48,8 @@ class DistributedSampler : public Sampler {
  // @return - The error code return
  Status Reset() override;

+ void Print(std::ostream &out, bool show_all) const override;
+
 private:
  int64_t cnt_;  // number of samples that have already been filled in to buffer
  uint32_t seed_;
@@ -38,6 +38,7 @@ Status PKSampler::InitSampler() {
  rnd_.seed(seed_++);
  num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
  samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
+ num_samples_ = num_pk_samples_;
  if (shuffle_ == true) {
    std::shuffle(labels_.begin(), labels_.end(), rnd_);
  } else {

@@ -53,6 +54,10 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  } else if (next_id_ == num_pk_samples_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+   }
+
    (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> sample_ids;
    int64_t last_id =

@@ -63,8 +68,16 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
      int64_t cls_id = next_id_++ / samples_per_class_;
      const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]];
      int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_);
-     *(id_ptr++) = samples[rnd_ind];
+     int64_t sampled_id = samples[rnd_ind];
+
+     if (HasChildSampler()) {
+       RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+     }
+
+     *id_ptr = sampled_id;
+     id_ptr++;
    }

    TensorRow row(1, sample_ids);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
  }

@@ -75,6 +88,11 @@ Status PKSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
  next_id_ = 0;
  rnd_.seed(seed_++);
+
+ if (HasChildSampler()) {
+   RETURN_IF_NOT_OK(child_[0]->Reset());
+ }
+
  return Status::OK();
}
@@ -27,6 +27,10 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  if (need_to_reset_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+   }
+
    std::shared_ptr<Tensor> sample_ids;
    {
      py::gil_scoped_acquire gil_acquire;

@@ -38,6 +42,14 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
        py::object py_ret = py_sampler_instance.attr("_get_indices")();
        py::array np_sample_ids = py_ret.cast<py::array>();
        Tensor::CreateTensor(&sample_ids, np_sample_ids);  // copy numpy to tensor
+
+       if (HasChildSampler()) {
+         for (auto it = sample_ids->begin<int64_t>(); it != sample_ids->end<int64_t>(); ++it) {
+           int64_t associated_child_id = 0;
+           RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, *it));
+           *it = associated_child_id;
+         }
+       }
      } catch (const py::error_already_set &e) {
        return Status(StatusCode::kPyFuncException, e.what());
      } catch (const py::cast_error &e) {

@@ -79,6 +91,11 @@ Status PythonSampler::Reset() {
  } catch (const py::error_already_set &e) {
    return Status(StatusCode::kPyFuncException, e.what());
  }
+
+ if (HasChildSampler()) {
+   RETURN_IF_NOT_OK(child_[0]->Reset());
+ }
+
  return Status::OK();
}
}  // namespace dataset
@@ -14,18 +14,22 @@
 * limitations under the License.
 */
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"

#include <algorithm>
#include <limits>
#include <memory>
#include "dataset/util/random.h"

namespace mindspore {
namespace dataset {
-RandomSampler::RandomSampler(bool replacement, int64_t num_samples, int64_t samples_per_buffer)
+RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
+                             int64_t samples_per_buffer)
    : Sampler(samples_per_buffer),
      seed_(GetSeed()),
      replacement_(replacement),
      user_num_samples_(num_samples),
      next_id_(0),
+     reshuffle_each_epoch_(reshuffle_each_epoch),
      dist(nullptr) {}

Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {

@@ -34,13 +38,29 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  } else if (next_id_ == num_samples_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+   }
    (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);

    std::shared_ptr<Tensor> sampleIds;
-   int64_t last_id = samples_per_buffer_ + next_id_ > num_samples_ ? num_samples_ : samples_per_buffer_ + next_id_;
+   int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_);
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_));
    int64_t *id_ptr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());

    for (int64_t i = 0; i < (last_id - next_id_); i++) {
-     *(id_ptr + i) = replacement_ ? (*dist)(rnd_) : shuffled_ids_[static_cast<size_t>(i + next_id_)];
+     int64_t sampled_id = 0;
+     if (replacement_) {
+       sampled_id = (*dist)(rnd_);
+     } else {
+       sampled_id = shuffled_ids_[static_cast<size_t>(i + next_id_)];
+     }
+
+     if (HasChildSampler()) {
+       RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+     }
+
+     *(id_ptr + i) = sampled_id;
    }
    next_id_ = last_id;
    TensorRow row(1, sampleIds);

@@ -53,7 +73,9 @@ Status RandomSampler::InitSampler() {
  num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
- rnd_.seed(seed_++);
+
+ rnd_.seed(seed_);
+
  if (replacement_ == false) {
    shuffled_ids_.reserve(num_rows_);
    for (int64_t i = 0; i < num_rows_; i++) {

@@ -69,11 +91,33 @@ Status RandomSampler::InitSampler() {
Status RandomSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
  next_id_ = 0;
- rnd_.seed(seed_++);
- if (replacement_ == false) {
+
+ if (reshuffle_each_epoch_) {
+   seed_++;
+ }
+
+ rnd_.seed(seed_);
+
+ if (replacement_ == false && reshuffle_each_epoch_) {
    std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
  }

+ if (HasChildSampler()) {
+   RETURN_IF_NOT_OK(child_[0]->Reset());
+ }
+
  return Status::OK();
}

void RandomSampler::Print(std::ostream &out, bool show_all) const {
  out << "(sampler): RandomSampler\n";

  if (show_all) {
    out << "user_num_samples_: " << user_num_samples_ << '\n';
    out << "num_samples_: " << num_samples_ << '\n';
    out << "next_id_: " << next_id_ << '\n';
  }
}

}  // namespace dataset
}  // namespace mindspore
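The seeding policy above is what makes `reshuffle_each_epoch` work: the seed is only advanced when reshuffling is enabled, so a disabled sampler regenerates the same order every epoch. A minimal sketch of the idea (standard-library Python, not MindSpore code):

```python
import random

ids = list(range(6))
seed = 0
epoch1 = random.Random(seed).sample(ids, len(ids))
epoch2 = random.Random(seed).sample(ids, len(ids))      # reshuffle_each_epoch=False: seed unchanged
epoch3 = random.Random(seed + 1).sample(ids, len(ids))  # reshuffle_each_epoch=True: seed advanced
assert epoch1 == epoch2  # same order every epoch when reshuffling is off
```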
@@ -30,7 +30,8 @@ class RandomSampler : public Sampler {
  // @param bool replacement - put the id back / or not after a sample
  // @param int64_t numSamples - number samples to draw
  // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
- explicit RandomSampler(bool replacement = false, int64_t num_samples = std::numeric_limits<int64_t>::max(),
+ explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
+                        int64_t num_samples = std::numeric_limits<int64_t>::max(),
                         int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

  // Destructor.

@@ -49,6 +50,8 @@ class RandomSampler : public Sampler {
  // @return - The error code return
  Status Reset() override;

+ virtual void Print(std::ostream &out, bool show_all) const;
+
 private:
  uint32_t seed_;
  bool replacement_;

@@ -57,6 +60,7 @@ class RandomSampler : public Sampler {
  int64_t next_id_;
  std::mt19937 rnd_;
  std::unique_ptr<std::uniform_int_distribution<int64_t>> dist;
+ bool reshuffle_each_epoch_;
};
}  // namespace dataset
}  // namespace mindspore
@@ -15,18 +15,41 @@
 */
#include "dataset/engine/datasetops/source/sampler/sampler.h"

+#include <string>
+
namespace mindspore {
namespace dataset {
Sampler::Sampler(int64_t samples_per_buffer)
    : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}

Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
+ std::shared_ptr<Sampler> child_sampler;
+ if (HasChildSampler()) {
+   child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
+   if (!child_sampler) {
+     std::string err_msg("Cannot handshake, child is not a sampler object.");
+     RETURN_STATUS_UNEXPECTED(err_msg);
+   }
+
+   // Handshake and init child first.
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
+   }
+ }
+
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
  RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
- RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
+ if (HasChildSampler()) {
+   int64_t child_num_samples = child_sampler->num_samples();
+   num_rows_ = child_num_samples;
+ } else {
+   RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
+ }

  // It's up to the derived class to check the validity of the two args
  // Because some sampler only needs one of the arg (weighted_random_sampler)
  RETURN_IF_NOT_OK(InitSampler());  // init sampler after callback

  return Status::OK();
}

@@ -44,6 +67,15 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
  return Status::OK();
}

+void Sampler::Print(std::ostream &out, bool show_all) const {
+  out << "(sampler): base\n";
+
+  if (show_all) {
+    out << "num_rows_: " << num_rows_ << '\n';
+    out << "num_samples_: " << num_samples_ << '\n';
+  }
+}
+
Status Sampler::GetAllIdsThenReset(py::array *data) {
  std::unique_ptr<DataBuffer> db;
  std::shared_ptr<Tensor> sample_ids;

@@ -84,5 +116,45 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
  num_rows_ = num_rows;
  return Status::OK();
}
+
+Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
+  if (child == nullptr) {
+    return Status::OK();
+  }
+
+  // Only samplers can be added, not any other DatasetOp.
+  std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
+  if (!sampler) {
+    std::string err_msg("Cannot add child, child is not a sampler object.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // Samplers can have at most 1 child.
+  if (!child_.empty()) {
+    std::string err_msg("Cannot add child sampler, this sampler already has a child.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  child_.push_back(child);
+
+  // doesn't work, protected?
+  // child->AddParent(this);
+  return Status::OK();
+}
+
+bool Sampler::HasChildSampler() { return !child_.empty(); }
+
+Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
+  if (child_ids_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
+  }
+
+  TensorRow sample_row;
+  RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row));
+  std::shared_ptr<Tensor> sample_ids = sample_row[0];
+  RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
+  return Status::OK();
+}

}  // namespace dataset
}  // namespace mindspore
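`GetAssociatedChildId` is the heart of sampler chaining: the parent's sampled id is treated as an index into the id list the child sampler produced. A minimal sketch of that mapping (illustrative, not the actual Tensor-based implementation):

```python
child_ids = [7, 3, 9, 1]   # ids emitted by the child sampler
parent_ids = [2, 0]        # ids drawn by the parent; used as indices into child_ids
final_ids = [child_ids[i] for i in parent_ids]
assert final_ids == [9, 7]
```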
@@ -90,6 +90,8 @@ class Sampler : public DatasetOp {
  // setter function for num_samples_
  Status SetNumSamples(int64_t num_samples);

+ int64_t num_samples() { return num_samples_; }
+
  // first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
  // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
  // @return

@@ -114,17 +116,48 @@ class Sampler : public DatasetOp {
  // @return - The error code return
  Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); }

+ // Adds a sampler to become our child.
+ // @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
+ // @return - The error code returned.
+ Status AddChild(std::shared_ptr<DatasetOp> child);
+
  // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
  // @param std::shared_ptr<Tensor>* sampleIds
  // @param int64_t numElements - must be a non 0 number
- // @return
+ // @return - The error code returned.
  Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);

+ void Print(std::ostream &out, bool show_all) const override;
+
+ friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
+   sampler.Print(out, false);
+   return out;
+ }
+
+ // Checks if this sampler has a child sampler.
+ // @return - true if there is a child sampler, false otherwise.
+ bool HasChildSampler();
+
+ // Uses id as an index for the list of ids generated by the child sampler, and gets the
+ // associated id.
+ // @param int64_t* out_associated_id - Out parameter, contains the associated id.
+ // @param int64_t id - The id used as an index to get the associated child id.
+ // @return - The error code returned.
+ Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
+
 protected:
+ // Number of rows of data from the place this sampler is sampling from. If this sampler
+ // has a child sampler, num_rows_ is the number of ids the child sampler will
+ // output. Otherwise, num_rows_ is the number of rows in the dataset.
  int64_t num_rows_;
+
+ // Number of ids this sampler will return.
  int64_t num_samples_;
+
+ // The max number of ids a DataBuffer returned by this sampler will contain.
  int64_t samples_per_buffer_;
  std::unique_ptr<ColDescriptor> col_desc_;
  std::unique_ptr<DataBuffer> child_ids_;
};
}  // namespace dataset
}  // namespace mindspore
@@ -15,6 +15,7 @@
 */
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"

+#include <algorithm>
#include <memory>

namespace mindspore {

@@ -27,14 +28,26 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
  } else if (next_id_ == num_samples_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+   }
+
    (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> sampleIds;
    int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
    int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
    while (next_id_ < lastId) {
-     *(idPtr++) = next_id_++;
+     int64_t sampled_id = next_id_;
+     if (HasChildSampler()) {
+       RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+     }
+
+     *idPtr = sampled_id;
+     next_id_++;
+     idPtr++;
    }
+
    TensorRow row(1, sampleIds);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
  }

@@ -43,6 +56,10 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)

Status SequentialSampler::InitSampler() {
  num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_;  // if num_samples < 0, try if num_rows is set
+ if (HasChildSampler()) {
+   num_samples_ = std::min(num_samples_, num_rows_);
+ }
+
  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
  return Status::OK();

@@ -51,7 +68,15 @@ Status SequentialSampler::InitSampler() {
Status SequentialSampler::Reset() {
  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
  next_id_ = 0;
+
+ if (HasChildSampler()) {
+   RETURN_IF_NOT_OK(child_[0]->Reset());
+ }
+
  return Status::OK();
}
+
+void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; }

}  // namespace dataset
}  // namespace mindspore
@@ -45,6 +45,8 @@ class SequentialSampler : public Sampler {
  // @return - The error code return
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

+ void Print(std::ostream &out, bool show_all) const override;
+
 private:
  int64_t next_id_;
};
@@ -34,6 +34,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
Status SubsetRandomSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");

+ num_samples_ = indices_.size();
+
  // Initialize random generator with seed from config manager
  rand_gen_.seed(GetSeed());

@@ -56,6 +58,10 @@ Status SubsetRandomSampler::Reset() {
  rand_gen_.seed(GetSeed());
  std::shuffle(indices_.begin(), indices_.end(), rand_gen_);

+ if (HasChildSampler()) {
+   RETURN_IF_NOT_OK(child_[0]->Reset());
+ }
+
  return Status::OK();
}

@@ -65,6 +71,10 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
  if (sample_id_ == indices_.size()) {
    (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
  } else {
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+   }
+
    (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> outputIds;

@@ -87,7 +97,14 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
      RETURN_STATUS_UNEXPECTED(err_msg);
    }

-   *(id_ptr++) = indices_[sample_id_++];
+   int64_t sampled_id = indices_[sample_id_];
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+   }
+
+   *id_ptr = sampled_id;
+   id_ptr++;
+   sample_id_++;
  }

  // Create a TensorTable from that single tensor and push into DataBuffer
@@ -0,0 +1,85 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"

#include <memory>
#include <string>

#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"

namespace mindspore {
namespace dataset {
// Constructor.
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
    : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}

Status SubsetSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");

  num_samples_ = subset_size_;

  return Status::OK();
}

Status SubsetSampler::Reset() {
  current_id_ = 0;

  if (HasChildSampler()) {
    RETURN_IF_NOT_OK(child_[0]->Reset());
  }

  return Status::OK();
}

Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
  if (current_id_ > subset_size_) {
    RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
  } else if (current_id_ == subset_size_) {
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
    if (HasChildSampler()) {
      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
    }

    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> sampled_ids;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));

    int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());

    while (current_id_ < subset_size_) {
      int64_t sampled_id = start_index_ + current_id_;
      if (HasChildSampler()) {
        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
      }

      *(sampled_ids_start_addr + current_id_) = sampled_id;
      current_id_++;
    }

    TensorRow sampled_ids_row(1, sampled_ids);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
  }

  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,58 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_

#include <memory>
#include <vector>

#include "dataset/engine/datasetops/source/sampler/sampler.h"

namespace mindspore {
namespace dataset {

class SubsetSampler : public Sampler {
 public:
  // Constructor.
  // @param start_index The index we start sampling from.
  // @param subset_size The number of consecutive ids to sample.
  explicit SubsetSampler(int64_t start_index, int64_t subset_size);

  // Destructor.
  ~SubsetSampler() = default;

  // Initialize the sampler.
  // @return Status
  Status InitSampler() override;

  // Reset the internal variable to the initial state.
  // @return Status
  Status Reset() override;

  // Get the sample ids.
  // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
  // @note the sample ids (int64_t) will be placed in one Tensor.
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

 private:
  int64_t start_index_;
  int64_t subset_size_;
  int64_t current_id_;
};

}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
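Without a child, SubsetSampler simply emits a contiguous range of ids; a one-line Python equivalent of that behavior (an assumption read off the GetNextBuffer loop above):

```python
def subset_ids(start_index, subset_size):
    # ids produced by SubsetSampler when it has no child sampler
    return list(range(start_index, start_index + subset_size))

assert subset_ids(3, 4) == [3, 4, 5, 6]
```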
@@ -40,6 +40,8 @@ WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights,
Status WeightedRandomSampler::InitSampler() {
  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
  CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
+ num_samples_ = user_num_samples_;
+
  // Initialize random generator with seed from config manager
  rand_gen_.seed(GetSeed());

@@ -81,6 +83,11 @@ Status WeightedRandomSampler::Reset() {
  } else {
    discrete_dist_->reset();
  }
+
+ if (HasChildSampler()) {
+   RETURN_IF_NOT_OK(child_[0]->Reset());
+ }
+
  return Status::OK();
}

@@ -98,6 +105,10 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
  if (sample_id_ == user_num_samples_) {
    (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
  } else {
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
+   }
+
    (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
    std::shared_ptr<Tensor> outputIds;

@@ -127,7 +138,12 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
      RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound).");
    }

-   *(id_ptr++) = genId;
+   if (HasChildSampler()) {
+     RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
+   }
+
+   *id_ptr = genId;
+   id_ptr++;
    sample_id_++;
  }
@@ -44,7 +44,7 @@ from .validators import check, check_batch, check_shuffle, check_map, check_filt
    check_rename, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
-   check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat
+   check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, check_split
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:
@@ -581,6 +581,117 @@ class Dataset:
            return self
        return TakeDataset(self, count)

    def _get_absolute_split_sizes(self, sizes):
        """
        Internal method called by split to calculate absolute split sizes and to
        do some error checking after calculating absolute split sizes.
        """
        # call get_dataset_size and check the input here because we don't want
        # to do it once in check_split and then again here
        dataset_size = self.get_dataset_size()

        if dataset_size is None or dataset_size <= 0:
            raise RuntimeError("dataset size unknown, unable to split.")

        all_int = all(isinstance(item, int) for item in sizes)
        if all_int:
            sizes_sum = sum(sizes)
            if sizes_sum != dataset_size:
                raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
                                   .format(sizes_sum, dataset_size))
            return sizes

        absolute_sizes = []
        for item in sizes:
            absolute_size = int(round(item * dataset_size))
            if absolute_size == 0:
                raise RuntimeError("split percentage {} is too small.".format(item))
            absolute_sizes.append(absolute_size)

        absolute_sizes_sum = sum(absolute_sizes)
        if absolute_sizes_sum != dataset_size:
            raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
                               .format(absolute_sizes_sum, dataset_size))

        return absolute_sizes

    @check_split
    def split(self, sizes, randomize=True):
        """
        Splits the dataset into smaller, non-overlapping datasets.

        This is a general purpose split function which can be called from any operator in the pipeline.
        There is another, optimized split function, which will be called automatically if ds.split is
        called where ds is a MappableDataset.

        Args:
            sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is
                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
                respectively. If the sum of all sizes does not equal the original dataset size, an
                error will occur.
                If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
                Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
                of the original dataset. If after rounding, any size equals 0, an error will occur.
                All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
            randomize (bool): determines whether or not to split the data randomly. If true, the data
                will be randomly split. Otherwise, each split will be created with consecutive rows
                from the dataset.

        Note:
            1. Dataset cannot be sharded if split is going to be called.
            2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
               Shuffling the dataset may not be deterministic, which means the data in each split
               will be different in each epoch.

        Raises:
            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
            RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
                equal the dataset size.
            RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
            RuntimeError: If the dataset is sharded prior to calling split.
            ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
                floats don't sum to 1.

        Returns:
            tuple(Dataset), a tuple of datasets that have been split.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> dataset_dir = "/path/to/text_file.txt"
            >>>
            >>> # TextFileDataset is not a mappable dataset, so this non-optimized split will be called.
            >>> # Many datasets have shuffle on by default; set shuffle to False if split will be called!
            >>> data = ds.TextFileDataset(dataset_dir, shuffle=False)
            >>> train, test = data.split([0.9, 0.1])
        """
        if self.is_shuffled():
            logger.warning("dataset is shuffled before split.")

        if self.is_sharded():
            raise RuntimeError("dataset should not be sharded before split.")

        absolute_sizes = self._get_absolute_split_sizes(sizes)
        splits = []
        rows_to_skip = 0
        for size in absolute_sizes:
            ds = copy.deepcopy(self)
            if randomize:
                # want to shuffle the same way every epoch before split
                ds = ds.shuffle()
                ds.reshuffle_each_epoch = False

            if rows_to_skip > 0:
                ds = ds.skip(rows_to_skip)

            ds = ds.take(size)
            splits.append(ds)

            rows_to_skip += size

        return tuple(splits)
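The float branch of `_get_absolute_split_sizes` rounds each fraction against the dataset size and then re-checks the total; a quick worked example with hypothetical numbers:

```python
dataset_size = 10
sizes = [0.9, 0.1]
absolute_sizes = [int(round(f * dataset_size)) for f in sizes]  # [9, 1]
assert sum(absolute_sizes) == dataset_size
# The generic split above then builds, for split k:
#   shuffle(reshuffle_each_epoch=False) -> skip(sum(sizes[:k])) -> take(sizes[k])
```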

    @check_zip_dataset
    def zip(self, datasets):
        """
@@ -1053,10 +1164,24 @@ class Dataset:
    def reset(self):
        """Reset the dataset for next epoch."""

    def is_shuffled(self):
        for input_dataset in self.input:
            if input_dataset.is_shuffled():
                return True

        return False

    def is_sharded(self):
        for input_dataset in self.input:
            if input_dataset.is_sharded():
                return True

        return False


class SourceDataset(Dataset):
    """
    Abstract class to represent a source dataset which produces content to the data pipeline.
    """

    # No need for __init__ since it is the same as the super's init
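The `is_shuffled`/`is_sharded` checks recurse up the pipeline tree through `self.input`; a small model of that traversal (the `Node` class is illustrative, not MindSpore code):

```python
class Node:
    def __init__(self, inputs=(), shuffled=False, sharded=False):
        self.input, self._shuffled, self._sharded = list(inputs), shuffled, sharded

    def is_shuffled(self):
        return self._shuffled or any(n.is_shuffled() for n in self.input)

    def is_sharded(self):
        return self._sharded or any(n.is_sharded() for n in self.input)

leaf = Node(shuffled=True)
root = Node(inputs=[leaf])
assert root.is_shuffled() and not root.is_sharded()  # split would warn, but not raise
```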
@@ -1093,6 +1218,150 @@ class SourceDataset(Dataset):
            return file_list
        raise ValueError("The list of path names matching the patterns is empty.")

    def is_shuffled(self):
        raise NotImplementedError("SourceDataset must implement is_shuffled.")

    def is_sharded(self):
        raise NotImplementedError("SourceDataset must implement is_sharded.")


class MappableDataset(SourceDataset):
    """
    Abstract class to represent a source dataset which supports use of samplers.
    """

    def __init__(self, num_parallel_workers=None):
        # check if all subclasses use this name
        super().__init__(num_parallel_workers)
        self.sampler = None

    def add_sampler(self, new_sampler):
        # note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
        # after first passing through the current samplers attached to this dataset.
        new_sampler.add_child(self.sampler)
        self.sampler = new_sampler

    def use_sampler(self, new_sampler):
        """
        Will make the current dataset use the new_sampler provided.

        Args:
            new_sampler (Sampler): the sampler to use for the current dataset.

        Returns:
            Dataset, that uses new_sampler.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> dataset_dir = "/path/to/imagefolder_directory"
            >>> # a SequentialSampler is created by default
            >>> data = ds.ImageFolderDatasetV2(dataset_dir)
            >>>
            >>> # use a DistributedSampler instead of the SequentialSampler
            >>> new_sampler = ds.DistributedSampler(10, 2)
            >>> data.use_sampler(new_sampler)
        """
        self.sampler = self.sampler.child_sampler
        self.add_sampler(new_sampler)

    def is_shuffled(self):
        raise NotImplementedError("MappableDataset must implement is_shuffled.")

    def is_sharded(self):
        raise NotImplementedError("MappableDataset must implement is_sharded.")

    @check_split
    def split(self, sizes, randomize=True):
        """
        Splits the dataset into smaller, non-overlapping datasets.

        This is the optimized split function, which is called automatically when the dataset
        being split is a MappableDataset.

        Args:
            sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is
                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
                respectively. If the sum of all sizes does not equal the original dataset size, an
                error will occur.
                If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
                Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
                of the original dataset. If after rounding, any size equals 0, an error will occur.
                All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
            randomize (bool): determines whether or not to split the data randomly. If true, the data
                will be randomly split. Otherwise, each split will be created with consecutive rows
                from the dataset.

        Note:
            1. Dataset should not be sharded if split is going to be called. Instead, create a
               DistributedSampler and specify a split to shard after splitting. If the dataset is
               sharded after a split, it is strongly recommended to set the same seed in each
               instance of execution, otherwise each shard may not be part of the same split
               (see Examples).
            2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
               Shuffling the dataset may not be deterministic, which means the data in each split
               will be different in each epoch. Furthermore, if sharding occurs after split, each
               shard may not be part of the same split.

        Raises:
            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
            RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
                equal the dataset size.
            RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
            RuntimeError: If the dataset is sharded prior to calling split.
            ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
                floats don't sum to 1.

        Returns:
            tuple(Dataset), a tuple of datasets that have been split.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> dataset_dir = "/path/to/imagefolder_directory"
            >>>
            >>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
            >>> data = ds.ImageFolderDatasetV2(dataset_dir, shuffle=False)
            >>>
            >>> # sets the seed, and tells split to use this seed when randomizing. This
            >>> # is needed because we are sharding later
            >>> ds.config.set_seed(58)
            >>> train, test = data.split([0.9, 0.1])
            >>>
            >>> # if we want to shard the train dataset, we can use a DistributedSampler
            >>> train_sampler = ds.DistributedSampler(10, 2)
            >>> train.use_sampler(train_sampler)
        """
        if self.is_shuffled():
            logger.warning("dataset is shuffled before split.")

        if self.is_sharded():
            raise RuntimeError("dataset should not be sharded before split.")

        absolute_sizes = self._get_absolute_split_sizes(sizes)
        splits = []
        current_split_start_index = 0
        for size in absolute_sizes:
            ds = copy.deepcopy(self)
            if randomize:
                # want to shuffle the same way every epoch before split, we are assuming
                # that the user will call set_seed
                random_sampler = samplers.RandomSampler()
                random_sampler.reshuffle_each_epoch = False
                ds.add_sampler(random_sampler)

            subset_sampler = samplers.SubsetSampler(current_split_start_index, size)
            ds.add_sampler(subset_sampler)

            # add a sequential sampler, so that if the user calls use_sampler, we will
            # get rid of the sequential sampler instead of something we need
            ds.add_sampler(samplers.SequentialSampler())

            splits.append(ds)

            current_split_start_index += size

        return tuple(splits)
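The optimized split therefore leaves each returned dataset with a three-level sampler chain, ids flowing child-to-parent. A minimal model of that chain (the `SimpleSampler` class is illustrative only):

```python
class SimpleSampler:
    def __init__(self, ids, child=None):
        self.ids, self.child = ids, child

    def sample(self):
        if self.child is None:
            return self.ids
        child_ids = self.child.sample()
        return [child_ids[i] for i in self.ids]  # same mapping as GetAssociatedChildId

# RandomSampler with a fixed shuffle of the whole (4-row) dataset:
full = SimpleSampler(ids=[3, 0, 2, 1])
# SubsetSampler(start_index=0, subset_size=2) layered on top for the first split:
split0 = SimpleSampler(ids=[0, 1], child=full)
assert split0.sample() == [3, 0]
```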


class DatasetOp(Dataset):
    """
@@ -1334,6 +1603,7 @@ class SyncWaitDataset(DatasetOp):
            flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
        return flag


class ShuffleDataset(DatasetOp):
    """
    The result of applying Shuffle operator to the input Dataset.

@@ -1350,6 +1620,7 @@ class ShuffleDataset(DatasetOp):
        super().__init__()
        self.buffer_size = buffer_size
        self.input.append(input_dataset)
+       self.reshuffle_each_epoch = None
        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs
        if self.is_sync():

@@ -1358,8 +1629,14 @@ class ShuffleDataset(DatasetOp):
    def get_args(self):
        args = super().get_args()
        args["buffer_size"] = self.buffer_size
+       if self.reshuffle_each_epoch is not None:
+           args["reshuffle_each_epoch"] = self.reshuffle_each_epoch
+
        return args

+   def is_shuffled(self):
+       return True
+

# Pyfunc collection for multiprocess pyfunc
# This global variable will only be used within subprocesses
@@ -1989,8 +2266,14 @@ class StorageDataset(SourceDataset):
        self._get_pipeline_info()
        return self._num_classes

+    def is_shuffled(self):
+        return False
+
-class RangeDataset(SourceDataset):
+    def is_sharded(self):
+        return False
+
+
+class RangeDataset(MappableDataset):
    """
    A source dataset that reads and parses datasets stored on disk in a range.

@@ -2013,6 +2296,12 @@ class RangeDataset(SourceDataset):
        args["step"] = self.step
        return args

    def is_shuffled(self):
        return False

    def is_sharded(self):
        return False


def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    """
@@ -2052,7 +2341,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
    return samplers.SequentialSampler()


-class ImageFolderDatasetV2(SourceDataset):
+class ImageFolderDatasetV2(MappableDataset):
    """
    A source dataset that reads images from a tree of directories.

@@ -2190,8 +2479,20 @@ class ImageFolderDatasetV2(SourceDataset):
            num_samples = self.num_samples
        return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1]

+    def is_shuffled(self):
+        if self.shuffle_level is None:
+            return True
+
-class MnistDataset(SourceDataset):
+        return self.shuffle_level or self.sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.num_shards is not None:
+            return self.num_shards > 1
+
+        return self.sampler.is_sharded()
+
+
+class MnistDataset(MappableDataset):
    """
    A source dataset for reading and parsing the Mnist dataset.
@@ -2294,6 +2595,18 @@ class MnistDataset(SourceDataset):

        return get_num_rows(num_rows, self.num_shards)

    def is_shuffled(self):
        if self.shuffle_level is None:
            return True

        return self.shuffle_level or self.sampler.is_shuffled()

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1

        return self.sampler.is_sharded()


class MindDataset(SourceDataset):
    """
@@ -2400,6 +2713,18 @@ class MindDataset(SourceDataset):
            num_rows = num_rows // self.partitions[0] + 1
        return num_rows

    def is_shuffled(self):
        if self.global_shuffle is None:
            return True

        return self.global_shuffle or self.sampler.is_shuffled()

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1

        return self.sampler.is_sharded()


def _iter_fn(dataset, num_samples):
    """
@@ -2609,7 +2934,7 @@ class _GeneratorWorker(multiprocessing.Process):
        self.terminate()


-class GeneratorDataset(SourceDataset):
+class GeneratorDataset(MappableDataset):
    """
    A source dataset that generates data from Python by invoking a Python data source each epoch.

@@ -2794,6 +3119,12 @@ class GeneratorDataset(SourceDataset):

        return new_op

    def is_shuffled(self):
        return self.sampler.is_shuffled()

    def is_sharded(self):
        return self.sampler.is_sharded()


class TFRecordDataset(SourceDataset):
    """
@@ -2920,8 +3251,17 @@ class TFRecordDataset(SourceDataset):
        else:
            raise ValueError('set dataset_size with negative value {}'.format(value))

+    def is_shuffled(self):
+        return self.shuffle_files
+
-class ManifestDataset(SourceDataset):
+    def is_sharded(self):
+        if self.num_shards is not None:
+            return self.num_shards > 1
+
+        return False
+
+
+class ManifestDataset(MappableDataset):
    """
    A source dataset that reads images from a manifest file.
@@ -3088,8 +3428,20 @@ class ManifestDataset(SourceDataset):

        return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage)

+    def is_shuffled(self):
+        if self.shuffle_level is None:
+            return True
+
-class Cifar10Dataset(SourceDataset):
+        return self.shuffle_level or self.sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.num_shards is not None:
+            return self.num_shards > 1
+
+        return self.sampler.is_sharded()
+
+
+class Cifar10Dataset(MappableDataset):
    """
    A source dataset that reads cifar10 data.
@@ -3197,8 +3549,20 @@ class Cifar10Dataset(SourceDataset):

        return get_num_rows(num_rows, self.num_shards)

+    def is_shuffled(self):
+        if self.shuffle_level is None:
+            return True
+
-class Cifar100Dataset(SourceDataset):
+        return self.shuffle_level or self.sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.num_shards is not None:
+            return self.num_shards > 1
+
+        return self.sampler.is_sharded()
+
+
+class Cifar100Dataset(MappableDataset):
    """
    A source dataset that reads cifar100 data.
@@ -3304,6 +3668,18 @@ class Cifar100Dataset(SourceDataset):

        return get_num_rows(num_rows, self.num_shards)

    def is_shuffled(self):
        if self.shuffle_level is None:
            return True

        return self.shuffle_level or self.sampler.is_shuffled()

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1

        return self.sampler.is_sharded()


class RandomDataset(SourceDataset):
    """
@@ -3355,6 +3731,11 @@ class RandomDataset(SourceDataset):
        """
        return num_samples

    def is_shuffled(self):
        return True

    def is_sharded(self):
        return False


class Schema:
    """
@@ -3534,7 +3915,7 @@ class Schema:
        return self.to_json()


-class VOCDataset(SourceDataset):
+class VOCDataset(MappableDataset):
    """
    A source dataset for reading and parsing VOC dataset.

@@ -3681,8 +4062,20 @@ class VOCDataset(SourceDataset):

        return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)

+    def is_shuffled(self):
+        if self.shuffle_level is None:
+            return True
+
-class CelebADataset(SourceDataset):
+        return self.shuffle_level or self.sampler.is_shuffled()
+
+    def is_sharded(self):
+        if self.num_shards is not None:
+            return self.num_shards > 1
+
+        return self.sampler.is_sharded()
+
+
+class CelebADataset(MappableDataset):
    """
    A source dataset for reading and parsing CelebA dataset. Only support list_attr_celeba.txt currently.
@@ -3735,6 +4128,18 @@ class CelebADataset(SourceDataset):
        args["shard_id"] = self.shard_id
        return args

    def is_shuffled(self):
        if self.shuffle_level is None:
            return True

        return self.shuffle_level or self.sampler.is_shuffled()

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1

        return self.sampler.is_sharded()


class TextFileDataset(SourceDataset):
    """
@@ -3814,3 +4219,12 @@ class TextFileDataset(SourceDataset):
                return num_rows
            return min(self.num_samples, num_rows)
        return self._dataset_size

    def is_shuffled(self):
        return self.shuffle_files

    def is_sharded(self):
        if self.num_shards is not None:
            return self.num_shards > 1

        return False
@@ -47,6 +47,7 @@ class Sampler:
    def __init__(self):
        self.dataset_size = 0
        self.num_samples = 0
        self.child_sampler = None

    def __iter__(self):
        """
@ -83,7 +84,35 @@ class Sampler:
    # Instance fetcher
    # Do not override this method!
    def create(self):
-        return cde.PythonSampler(self)
+        c_sampler = cde.PythonSampler(self)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler

    def add_child(self, sampler):
        self.child_sampler = sampler

    def get_child(self):
        return self.child_sampler

    def create_child(self):
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.create()

        return c_child_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()


class BuiltinSampler:
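A brief usage sketch of the chaining hooks above: add_child() links samplers on the Python side, and create()/create_child() flatten the chain into the C layer. The manifest path is the one used by the sampler-chain test later in this commit:

import mindspore.dataset as ds

# The parent sampler consumes the ids produced by its child: here the
# sequential order 0..4 is sharded across two shards.
sampler = ds.DistributedSampler(2, 0, False)
sampler.add_child(ds.SequentialSampler())
data = ds.ManifestDataset("../data/dataset/testManifestData/test5trainimgs.json",
                          num_samples=5, sampler=sampler)
for item in data.create_dict_iterator():
    pass  # shard 0 sees rows 0, 2, 4 per the sampler-chain test below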
@ -93,11 +122,30 @@ class BuiltinSampler:
    User should not extend this class.
    """
    def __init__(self):
-        pass
+        self.child_sampler = None

    def create(self):
        pass

    def add_child(self, sampler):
        self.child_sampler = sampler

    def get_child(self):
        return self.child_sampler

    def create_child(self):
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.create()

        return c_child_sampler

    def is_shuffled(self):
        raise NotImplementedError("Sampler must implement is_shuffled.")

    def is_sharded(self):
        raise NotImplementedError("Sampler must implement is_sharded.")


class DistributedSampler(BuiltinSampler):
    """
@ -142,7 +190,22 @@ class DistributedSampler(BuiltinSampler):
    def create(self):
        # each time the user calls create_dict_iterator() (to do repeat), the sampler gets a different seed to shuffle
        self.seed += 1
-        return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return self.shuffle

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return self.num_shards > 1

        return self.child_sampler.is_sharded()


class PKSampler(BuiltinSampler):
@ -186,7 +249,22 @@ class PKSampler(BuiltinSampler):
        super().__init__()

    def create(self):
-        return cde.PKSampler(self.num_val, self.shuffle)
+        c_sampler = cde.PKSampler(self.num_val, self.shuffle)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return self.shuffle

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def _create_for_minddataset(self):
        if not self.class_column or not isinstance(self.class_column, str):
@ -226,15 +304,31 @@ class RandomSampler(BuiltinSampler):
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))

        self.deterministic = False
        self.replacement = replacement
        self.num_samples = num_samples
        self.reshuffle_each_epoch = True
        super().__init__()

    def create(self):
        # If num_samples is not specified, then call constructor #2
        c_sampler = None
        if self.num_samples is None:
-            return cde.RandomSampler(self.replacement)
-        return cde.RandomSampler(self.replacement, self.num_samples)
+            c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch)
+        else:
+            c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples)

        c_child_sampler = self.create_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()


class SequentialSampler(BuiltinSampler):
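For reference, the two call shapes create() selects between above, written as a usage sketch. The keyword names follow the validation messages in this hunk and are an assumption:

import mindspore.dataset as ds

# Without num_samples the two-argument C++ constructor is chosen; with it,
# the three-argument one. reshuffle_each_epoch (Python-side default True)
# is forwarded in both cases.
sample_all = ds.RandomSampler(replacement=False)
sample_12 = ds.RandomSampler(replacement=True, num_samples=12)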
@ -252,7 +346,80 @@ class SequentialSampler(BuiltinSampler):
    """

    def create(self):
-        return cde.SequentialSampler()
+        c_sampler = cde.SequentialSampler()
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()


class SubsetSampler(BuiltinSampler):
    """
    Samples a subset of elements consecutively from a given index.

    Args:
        start_index (int): Index to start sampling at.
        subset_size (int): How many samples to include in this subset.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> dataset_dir = "path/to/imagefolder_directory"
        >>>
        >>> # creates a SubsetSampler that samples the 5 images starting at index 100
        >>> sampler = ds.SubsetSampler(100, 5)
        >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)

    Raises:
        ValueError: If start_index is not an int or is negative.
        ValueError: If subset_size is not an int or is negative.
    """

    def __init__(self, start_index, subset_size):
        if not isinstance(start_index, int):
            raise ValueError("start_index should be an int.")

        if start_index < 0:
            raise ValueError("start_index should not be negative.")

        if not isinstance(subset_size, int):
            raise ValueError("subset_size should be an int.")

        if subset_size < 0:
            raise ValueError("subset_size should not be negative.")

        self.start_index = start_index
        self.subset_size = subset_size
        super().__init__()

    def create(self):
        c_sampler = cde.SubsetSampler(self.start_index, self.subset_size)
        c_child_sampler = self.create_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()


class SubsetRandomSampler(BuiltinSampler):
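Since SubsetSampler participates in the same chaining mechanism, a consecutive window can be taken over another sampler's output. A small sketch; that start_index indexes the child's output order rather than the raw dataset is an assumption here:

import mindspore.dataset as ds

# Take 2 consecutive samples starting at index 0 of what the child yields.
sampler = ds.SubsetSampler(0, 2)
sampler.add_child(ds.SequentialSampler())
data = ds.ImageFolderDatasetV2("path/to/imagefolder_directory", sampler=sampler)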
@ -282,7 +449,19 @@ class SubsetRandomSampler(BuiltinSampler):
        super().__init__()

    def create(self):
-        return cde.SubsetRandomSampler(self.indices)
+        c_sampler = cde.SubsetRandomSampler(self.indices)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

    def _create_for_minddataset(self):
        return cde.MindrecordSubsetRandomSampler(self.indices)
@ -330,4 +509,16 @@ class WeightedRandomSampler(BuiltinSampler):
        super().__init__()

    def create(self):
-        return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
+        c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
+        c_child_sampler = self.create_child()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()
@ -1031,3 +1031,44 @@ def check_textfiledataset(method):
        return method(*args, **kwargs)

    return new_method


def check_split(method):
    """check the input arguments of split."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        nreq_param_list = ['sizes']
        nreq_param_bool = ['randomize']
        check_param_type(nreq_param_list, param_dict, list)
        check_param_type(nreq_param_bool, param_dict, bool)

        # check sizes: must be list of float or list of int
        sizes = param_dict.get('sizes')

        if not sizes:
            raise ValueError("sizes cannot be empty.")
        all_int = all(isinstance(item, int) for item in sizes)
        all_float = all(isinstance(item, float) for item in sizes)

        if not (all_int or all_float):
            raise ValueError("sizes should be list of int or list of float.")

        if all_int:
            all_positive = all(item > 0 for item in sizes)
            if not all_positive:
                raise ValueError("sizes is a list of int, but there should be no negative numbers.")

        if all_float:
            all_valid_percentages = all(0 < item <= 1 for item in sizes)
            if not all_valid_percentages:
                raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].")

            epsilon = 0.00001
            if not abs(sum(sizes) - 1) < epsilon:
                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")

        return method(*args, **kwargs)

    return new_method
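To make the float branch above concrete: percentages are converted to row counts downstream of this validator (in the C++ layer, which is not part of this hunk). A sketch assuming simple rounding, which matches both the fuzzy-percentage expectations and the "sum of calculated split sizes" error asserted by the tests below:

# Hypothetical helper, not part of the patch: approximate row counts a
# percentage split would produce for a dataset of num_rows rows.
def approx_split_sizes(sizes, num_rows):
    return [round(p * num_rows) for p in sizes]

assert approx_split_sizes([0.8, 0.2], 5) == [4, 1]            # exact percentages
assert approx_split_sizes([0.33, 0.67], 5) == [2, 3]          # fuzzy percentages
assert sum(approx_split_sizes([0.15] * 5 + [0.25], 5)) == 6   # rejected: != 5 rows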
@ -92,7 +92,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  GlobalContext::config_manager()->set_seed(0);
-  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
+  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
  std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
  auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
  tree->Prepare();
@ -138,7 +138,7 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
  int32_t original_seed = GlobalContext::config_manager()->seed();
  GlobalContext::config_manager()->set_seed(0);
-  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
+  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
  int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1};  // ground truth label
  std::string folder_path = datasets_root_path_ + "/testPK/data";
  auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
@ -163,9 +163,36 @@ def test_python_sampler():
    assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]


def test_sampler_chain():
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_shards, shard_id):
        sampler = ds.DistributedSampler(num_shards, shard_id, False)
        child_sampler = ds.SequentialSampler()
        sampler.add_child(child_sampler)

        data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)

        res = []
        for item in data1.create_dict_iterator():
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(item["image"].shape[0], item["label"].item()))
            res.append(map[(item["image"].shape[0], item["label"].item())])
        return res

    assert test_config(2, 0) == [0, 2, 4]
    assert test_config(2, 1) == [1, 3, 0]
    assert test_config(5, 0) == [0]
    assert test_config(5, 1) == [1]
    assert test_config(5, 2) == [2]
    assert test_config(5, 3) == [3]
    assert test_config(5, 4) == [4]


if __name__ == '__main__':
    test_sequential_sampler(True)
    test_random_sampler(True)
    test_random_sampler_multi_iter(True)
    test_sampler_py_api()
    test_python_sampler()
    test_sampler_chain()
@ -0,0 +1,342 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds

# test5trainimgs.json contains 5 images whose un-decoded shapes are [172876, 54214, 54214, 173673, 64631]
# and whose labels are [0, 0, 1, 0, 1]; each image can be uniquely identified
# via the (shape, label) lookup table below
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

def split_with_invalid_inputs(d):
    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([])
    assert "sizes cannot be empty" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([5, 0.6])
    assert "sizes should be list of int or list of float" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([-1, 6])
    assert "there should be no negative numbers" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([3, 1])
    assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([5, 1])
    assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
    assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([-0.5, 0.5])
    assert "there should be no numbers outside the range [0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([1.5, 0.5])
    assert "there should be no numbers outside the range [0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([0.5, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(ValueError) as info:
        s1, s2 = d.split([0.3, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([0.05, 0.95])
    assert "percentage 0.05 is too small" in str(info.value)

def test_unmappable_invalid_input():
    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
    d = ds.TextFileDataset(text_file_dataset_path)
    split_with_invalid_inputs(d)

    d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([4, 1])
    assert "dataset should not be sharded before split" in str(info.value)

def test_unmappable_split():
    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
    text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
                      "End of file.", "Good luck to everyone."]
    ds.config.set_num_parallel_workers(4)
    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:2]
    assert s2_output == text_file_data[2:]

def test_mappable_invalid_input():
    d = ds.ManifestDataset(manifest_file)
    split_with_invalid_inputs(d)

    d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        s1, s2 = d.split([4, 1])
    assert "dataset should not be sharded before split" in str(info.value)

def test_mappable_split_general():
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    d = d.take(5)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]

def test_mappable_split_optimized():
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]

def test_mappable_randomize_deterministic():
    # set an arbitrary seed for repeatable randomized split results
    # the labels output by ManifestDataset for seed 53 are [0, 1, 3, 4]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator():
            s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        s2_output = []
        for item in s2.create_dict_iterator():
            s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        # note: no overlap between the splits
        assert s1_output == [0, 1, 3, 4]
        assert s2_output == [2]

def test_mappable_randomize_repeatable():
    # set an arbitrary seed for repeatable randomized split results
    # the labels output by ManifestDataset for seed 53 are [0, 1, 3, 4]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # note: no overlap between the splits
    assert s1_output == [0, 1, 3, 4] * num_epochs
    assert s2_output == [2] * num_epochs

def test_mappable_sharding():
    # set an arbitrary seed so that sharding after split is repeatable
    # the labels output by ManifestDataset for seed 53 are [0, 1, 3, 4]
    ds.config.set_seed(53)

    num_epochs = 5
    first_split_num_rows = 4

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([first_split_num_rows, 1])

    distributed_sampler = ds.DistributedSampler(2, 0)
    s1.use_sampler(distributed_sampler)

    s1 = s1.repeat(num_epochs)

    # testing sharding; a second dataset simulates another training instance
    d2 = ds.ManifestDataset(manifest_file, shuffle=False)
    d2s1, d2s2 = d2.split([first_split_num_rows, 1])

    distributed_sampler = ds.DistributedSampler(2, 1)
    d2s1.use_sampler(distributed_sampler)

    d2s1 = d2s1.repeat(num_epochs)

    # shard 0
    s1_output = []
    for item in s1.create_dict_iterator():
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # shard 1
    d2s1_output = []
    for item in d2s1.create_dict_iterator():
        d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    rows_per_shard_per_epoch = 2
    assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
    assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs

    # verify for each epoch that
    # 1. the shards contain no common elements
    # 2. the data was split the same way, and the union of the shards equals the split
    correct_sorted_split_result = [0, 1, 3, 4]
    for i in range(num_epochs):
        combined_data = []
        for j in range(rows_per_shard_per_epoch):
            combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
            combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])

        assert sorted(combined_data) == correct_sorted_split_result

    # test the other split
    s2_output = []
    for item in s2.create_dict_iterator():
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    d2s2_output = []
    for item in d2s2.create_dict_iterator():
        d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s2_output == [2]
    assert d2s2_output == [2]

if __name__ == '__main__':
    test_unmappable_invalid_input()
    test_unmappable_split()
    test_mappable_invalid_input()
    test_mappable_split_general()
    test_mappable_split_optimized()
    test_mappable_randomize_deterministic()
    test_mappable_randomize_repeatable()
    test_mappable_sharding()