diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index b8302554c00..372863d765c 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -364,6 +364,18 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { std::string err_msg = "Error: Shuffle buffer size is missing"; RETURN_STATUS_UNEXPECTED(err_msg); } + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "reshuffle_each_epoch") { + (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"])); + } + } + } + std::shared_ptr<ShuffleOp> op; RETURN_IF_NOT_OK(builder->Build(&op)); *ptr = op; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index ffedc8570e3..72030d498e7 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -49,6 +49,7 @@ #include "dataset/engine/datasetops/source/sampler/pk_sampler.h" #include "dataset/engine/datasetops/source/sampler/random_sampler.h" #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "dataset/engine/datasetops/source/sampler/subset_sampler.h" #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" #include "dataset/engine/datasetops/source/sampler/python_sampler.h" @@ -411,11 +412,14 @@ void bindSamplerOps(py::module *m) { .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) - .def("get_indices", [](Sampler &self) { - py::array ret; - THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); - return ret; - }); + .def("get_indices", + [](Sampler &self) { + py::array ret; + THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); + return ret; + }) + .def("add_child", + [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); }); (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator"); @@ -427,12 +431,16 @@ void bindSamplerOps(py::module *m) { .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle")); (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler") - .def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples")) - .def(py::init<bool>(), py::arg("replacement")); + .def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"), + py::arg("num_samples")) + .def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch")); (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler") .def(py::init<>()); + (void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler") + .def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size")); + (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler") .def(py::init<std::vector<int64_t>>(), py::arg("indices")); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt index 5209d9ba4ad..152b887ef44 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -8,5 +8,6 @@ add_library(engine-datasetops-source-sampler OBJECT sampler.cc sequential_sampler.cc subset_random_sampler.cc + subset_sampler.cc weighted_random_sampler.cc ) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 92dfbf594d8..d4e5a732db7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++
b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -55,13 +55,27 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { } else if (cnt_ == samples_per_buffer_) { (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + } + (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone); std::shared_ptr<Tensor> sample_ids; RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_)); int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer()); while (cnt_ < samples_per_buffer_) { - int64_t next_id = (num_devices_ * (cnt_++) + device_id_) % num_rows_; - *(id_ptr++) = shuffle_ ? shuffle_vec_[static_cast<size_t>(next_id)] : next_id; + int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_; + if (shuffle_) { + sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)]; + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *id_ptr = sampled_id; + id_ptr++; + cnt_++; } TensorRow row(1, sample_ids); (*out_buffer)->set_tensor_table(std::make_unique<TensorTable>(1, row)); @@ -72,11 +86,29 @@ Status DistributedSampler::Reset() { CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); cnt_ = 0; - rnd_.seed(seed_++); + if (shuffle_ == true) { + rnd_.seed(seed_); + seed_++; std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_); } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->Reset()); + } + return Status::OK(); } + +void DistributedSampler::Print(std::ostream &out, bool show_all) const { + out << "(sampler): DistributedSampler\n"; + if (show_all) { + out << "seed_: " << seed_ << '\n'; + out << "device_id_: " << device_id_ << '\n'; + out << "num_devices_: " << num_devices_ << '\n'; + out << "shuffle_: " << shuffle_ << '\n'; + } +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h index 58b469dcc8f..29b5cda0da6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -48,6 +48,8 @@ class DistributedSampler : public Sampler { // @return - The error code return Status Reset() override; + void Print(std::ostream &out, bool show_all) const override; + private: int64_t cnt_; // number of samples that have already been filled in to buffer uint32_t seed_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc index f4c1189b8c9..72c2cc18746 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -38,6 +38,7 @@ Status PKSampler::InitSampler() { rnd_.seed(seed_++); num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size()); samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ?
num_pk_samples_ : samples_per_buffer_; + num_samples_ = num_pk_samples_; if (shuffle_ == true) { std::shuffle(labels_.begin(), labels_.end(), rnd_); } else { @@ -53,6 +54,10 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { } else if (next_id_ == num_pk_samples_) { (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + } + (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); std::shared_ptr<Tensor> sample_ids; int64_t last_id = @@ -63,8 +68,16 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { int64_t cls_id = next_id_++ / samples_per_class_; const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]]; int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_); - *(id_ptr++) = samples[rnd_ind]; + int64_t sampled_id = samples[rnd_ind]; + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *id_ptr = sampled_id; + id_ptr++; } + TensorRow row(1, sample_ids); (*out_buffer)->set_tensor_table(std::make_unique<TensorTable>(1, row)); } @@ -75,6 +88,11 @@ Status PKSampler::Reset() { CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late"); next_id_ = 0; rnd_.seed(seed_++); + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->Reset()); + } + return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc index 1747040141b..ca999e31a53 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -27,6 +27,10 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { if (need_to_reset_) { (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + } + std::shared_ptr<Tensor> sample_ids; { py::gil_scoped_acquire gil_acquire; @@ -38,6 +42,14 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { py::object py_ret = py_sampler_instance.attr("_get_indices")(); py::array np_sample_ids = py_ret.cast<py::array>(); Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor + + if (HasChildSampler()) { + for (auto it = sample_ids->begin<int64_t>(); it != sample_ids->end<int64_t>(); ++it) { + int64_t associated_child_id = 0; + RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, *it)); + *it = associated_child_id; + } + } } catch (const py::error_already_set &e) { return Status(StatusCode::kPyFuncException, e.what()); } catch (const py::cast_error &e) { @@ -79,6 +91,11 @@ Status PythonSampler::Reset() { } catch (const py::error_already_set &e) { return Status(StatusCode::kPyFuncException, e.what()); } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->Reset()); + } + return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc index 967632a5d9f..96c83c3114d 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -14,18 +14,22 @@ * limitations under the License.
*/ #include "dataset/engine/datasetops/source/sampler/random_sampler.h" + +#include <algorithm> #include <limits> #include <memory> #include "dataset/util/random.h" namespace mindspore { namespace dataset { -RandomSampler::RandomSampler(bool replacement, int64_t num_samples, int64_t samples_per_buffer) +RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples, + int64_t samples_per_buffer) : Sampler(samples_per_buffer), seed_(GetSeed()), replacement_(replacement), user_num_samples_(num_samples), next_id_(0), + reshuffle_each_epoch_(reshuffle_each_epoch), dist(nullptr) {} Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { @@ -34,13 +38,29 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { } else if (next_id_ == num_samples_) { (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + } (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); + std::shared_ptr<Tensor> sampleIds; - int64_t last_id = samples_per_buffer_ + next_id_ > num_samples_ ? num_samples_ : samples_per_buffer_ + next_id_; + int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_); RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_)); int64_t *id_ptr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer()); + for (int64_t i = 0; i < (last_id - next_id_); i++) { - *(id_ptr + i) = replacement_ ? (*dist)(rnd_) : shuffled_ids_[static_cast<size_t>(i + next_id_)]; + int64_t sampled_id = 0; + if (replacement_) { + sampled_id = (*dist)(rnd_); + } else { + sampled_id = shuffled_ids_[static_cast<size_t>(i + next_id_)]; + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *(id_ptr + i) = sampled_id; } next_id_ = last_id; TensorRow row(1, sampleIds); @@ -53,7 +73,9 @@ Status RandomSampler::InitSampler() { num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ?
num_samples_ : samples_per_buffer_; - rnd_.seed(seed_++); + + rnd_.seed(seed_); + if (replacement_ == false) { shuffled_ids_.reserve(num_rows_); for (int64_t i = 0; i < num_rows_; i++) { @@ -69,11 +91,33 @@ Status RandomSampler::Reset() { CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); next_id_ = 0; - rnd_.seed(seed_++); - if (replacement_ == false) { + + if (reshuffle_each_epoch_) { + seed_++; + } + + rnd_.seed(seed_); + + if (replacement_ == false && reshuffle_each_epoch_) { std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->Reset()); + } + return Status::OK(); } + +void RandomSampler::Print(std::ostream &out, bool show_all) const { + out << "(sampler): RandomSampler\n"; + + if (show_all) { + out << "user_num_samples_: " << user_num_samples_ << '\n'; + out << "num_samples_: " << num_samples_ << '\n'; + out << "next_id_: " << next_id_ << '\n'; + } +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h index 84a07e9fc6b..352751dbb8b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h @@ -30,7 +30,8 @@ class RandomSampler : public Sampler { // @param bool replacement - put the id back or not after a sample // @param int64_t numSamples - number of samples to draw // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit RandomSampler(bool replacement = false, int64_t num_samples = std::numeric_limits<int64_t>::max(), + explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true, + int64_t num_samples = std::numeric_limits<int64_t>::max(), int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); // Destructor. @@ -49,6 +50,8 @@ // @return - The error code return Status Reset() override; + virtual void Print(std::ostream &out, bool show_all) const; + private: uint32_t seed_; bool replacement_; @@ -57,6 +60,7 @@ int64_t next_id_; std::mt19937 rnd_; std::unique_ptr<std::uniform_int_distribution<int64_t>> dist; + bool reshuffle_each_epoch_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 93c8c305bc3..90e950fceef 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -15,18 +15,41 @@ */ #include "dataset/engine/datasetops/source/sampler/sampler.h" +#include <string> + namespace mindspore { namespace dataset { Sampler::Sampler(int64_t samples_per_buffer) : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { + std::shared_ptr<Sampler> child_sampler; + if (HasChildSampler()) { + child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]); + if (!child_sampler) { + std::string err_msg("Cannot handshake, child is not a sampler object."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Handshake and init child first.
+ if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); + } + } + CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_)); - RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); + if (HasChildSampler()) { + int64_t child_num_samples = child_sampler->num_samples(); + num_rows_ = child_num_samples; + } else { + RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); + } + // It's up to the derived class to check the validity of the two args // Because some sampler only needs one of the arg (weighted_random_sampler) RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback + return Status::OK(); } @@ -44,6 +67,15 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t return Status::OK(); } +void Sampler::Print(std::ostream &out, bool show_all) const { + out << "(sampler): base\n"; + + if (show_all) { + out << "num_rows_: " << num_rows_ << '\n'; + out << "num_samples_: " << num_samples_ << '\n'; + } +} + Status Sampler::GetAllIdsThenReset(py::array *data) { std::unique_ptr<DataBuffer> db; std::shared_ptr<Tensor> sample_ids; @@ -84,5 +116,45 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) { num_rows_ = num_rows; return Status::OK(); } + +Status Sampler::AddChild(std::shared_ptr<Sampler> child) { + if (child == nullptr) { + return Status::OK(); + } + + // Only samplers can be added, not any other DatasetOp. + std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child); + if (!sampler) { + std::string err_msg("Cannot add child, child is not a sampler object."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Samplers can have at most 1 child. + if (!child_.empty()) { + std::string err_msg("Cannot add child sampler, this sampler already has a child."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + child_.push_back(child); + + // TODO: also call child->AddParent(this) here once AddParent() is accessible (it is currently protected). + return Status::OK(); +} + +bool Sampler::HasChildSampler() { return !child_.empty(); } + +Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { + if (child_ids_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); + } + + TensorRow sample_row; + RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row)); + std::shared_ptr<Tensor> sample_ids = sample_row[0]; + RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id})); + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h index 13570323f13..936a80bb381 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h @@ -90,6 +90,8 @@ class Sampler : public DatasetOp { // setter function for num_samples_ Status SetNumSamples(int64_t num_samples); + int64_t num_samples() { return num_samples_; } + // first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds() // @return @@ -114,17 +116,48 @@ // @return - The error code return Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); } + // Adds a sampler to become our child. + // @param std::shared_ptr<Sampler> child - The sampler to add as a child. + // @return - The error code returned.
+ Status AddChild(std::shared_ptr<Sampler> child); + // A helper function to create an int64_t 1-D Tensor specifically used to hold sampleIds for Sampler // @param std::shared_ptr<Tensor>* sampleIds // @param int64_t numElements - must be a non-zero number - // @return + // @return - The error code returned. Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); + void Print(std::ostream &out, bool show_all) const override; + + friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { + sampler.Print(out, false); + return out; + } + + // Checks if this sampler has a child sampler. + // @return - true if there is a child sampler, false otherwise. + bool HasChildSampler(); + + // Uses id as an index for the list of ids generated by the child sampler, and gets the + // associated id. + // @param int64_t* out_associated_id - Out parameter, contains the associated id. + // @param int64_t id - The id used as an index to get the associated child id. + // @return - The error code returned. + Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); + protected: + // Number of rows of data from the place this sampler is sampling from. If this sampler + // has a child sampler, num_rows_ is the number of ids the child sampler will + // output. Otherwise, num_rows_ is the number of rows in the dataset. int64_t num_rows_; + + // Number of ids this sampler will return. int64_t num_samples_; + + // The max number of ids a DataBuffer returned by this sampler will contain. int64_t samples_per_buffer_; std::unique_ptr<ColDescriptor> col_desc_; + std::unique_ptr<DataBuffer> child_ids_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index e405479360f..789f232e1e0 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -15,6 +15,7 @@ */ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include <algorithm> #include <memory> namespace mindspore { namespace dataset { @@ -27,14 +28,26 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) } else if (next_id_ == num_samples_) { (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + } + (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); std::shared_ptr<Tensor> sampleIds; int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_)); int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer()); while (next_id_ < lastId) { - *(idPtr++) = next_id_++; + int64_t sampled_id = next_id_; + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *idPtr = sampled_id; + next_id_++; + idPtr++; } + TensorRow row(1, sampleIds); (*out_buffer)->set_tensor_table(std::make_unique<TensorTable>(1, row)); } @@ -43,6 +56,10 @@ Status SequentialSampler::InitSampler() { num_samples_ = (num_samples_ <= 0) ?
num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set + if (HasChildSampler()) { + num_samples_ = std::min(num_samples_, num_rows_); + } + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; return Status::OK(); @@ -51,7 +68,15 @@ Status SequentialSampler::InitSampler() { Status SequentialSampler::Reset() { CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); next_id_ = 0; + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->Reset()); + } + return Status::OK(); } + +void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; } + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h index c38a9ed2f9b..4e195d75dbb 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -45,6 +45,8 @@ class SequentialSampler : public Sampler { // @return - The error code return Status GetNextBuffer(std::unique_ptr *out_buffer) override; + void Print(std::ostream &out, bool show_all) const override; + private: int64_t next_id_; }; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc index 698edf5e681..ca1160299a2 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -34,6 +34,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector &indices, in Status SubsetRandomSampler::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); + num_samples_ = indices_.size(); + // Initialize random generator with seed from config manager rand_gen_.seed(GetSeed()); @@ -56,6 +58,10 @@ Status SubsetRandomSampler::Reset() { rand_gen_.seed(GetSeed()); std::shuffle(indices_.begin(), indices_.end(), rand_gen_); + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->Reset()); + } + return Status::OK(); } @@ -65,6 +71,10 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr *out_buffe if (sample_id_ == indices_.size()) { (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + } + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); std::shared_ptr outputIds; @@ -87,7 +97,14 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr *out_buffe RETURN_STATUS_UNEXPECTED(err_msg); } - *(id_ptr++) = indices_[sample_id_++]; + int64_t sampled_id = indices_[sample_id_]; + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *id_ptr = sampled_id; + id_ptr++; + sample_id_++; } // Create a TensorTable from that single tensor and push into DataBuffer diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc new file mode 100644 index 00000000000..320bc601b91 --- /dev/null +++ 
b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/source/sampler/subset_sampler.h" + +#include <memory> +#include <string> + +#include "dataset/core/config_manager.h" +#include "dataset/core/global_context.h" + +namespace mindspore { +namespace dataset { +// Constructor. +SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size) + : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {} + +Status SubsetSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n"); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n"); + + num_samples_ = subset_size_; + + return Status::OK(); +} + +Status SubsetSampler::Reset() { + current_id_ = 0; + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->Reset()); + } + + return Status::OK(); +} + +Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { + if (current_id_ > subset_size_) { + RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error"); + } else if (current_id_ == subset_size_) { + (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + } + + (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone); + std::shared_ptr<Tensor> sampled_ids; + RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_)); + + int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer()); + + while (current_id_ < subset_size_) { + int64_t sampled_id = start_index_ + current_id_; + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *(sampled_ids_start_addr + current_id_) = sampled_id; + current_id_++; + } + + TensorRow sampled_ids_row(1, sampled_ids); + (*out_buffer)->set_tensor_table(std::make_unique<TensorTable>(1, sampled_ids_row)); + } + + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h new file mode 100644 index 00000000000..70ee80b0a47 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ + +#include +#include + +#include "dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { + +class SubsetSampler : public Sampler { + public: + // Constructor. + // @param start_index The index we start sampling from. + explicit SubsetSampler(int64_t start_index, int64_t subset_size); + + // Destructor. + ~SubsetSampler() = default; + + // Initialize the sampler. + // @return Status + Status InitSampler() override; + + // Reset the internal variable to the initial state and reshuffle the indices. + // @return Status + Status Reset() override; + + // Get the sample ids. + // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. + // @note the sample ids (int64_t) will be placed in one Tensor. + Status GetNextBuffer(std::unique_ptr *out_buffer) override; + + private: + int64_t start_index_; + int64_t subset_size_; + int64_t current_id_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index 91fc7f7d816..5027dcdd67b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -40,6 +40,8 @@ WeightedRandomSampler::WeightedRandomSampler(const std::vector &weights, Status WeightedRandomSampler::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive"); CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); + num_samples_ = user_num_samples_; + // Initialize random generator with seed from config manager rand_gen_.seed(GetSeed()); @@ -81,6 +83,11 @@ Status WeightedRandomSampler::Reset() { } else { discrete_dist_->reset(); } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->Reset()); + } + return Status::OK(); } @@ -98,6 +105,10 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr *out_buf if (sample_id_ == user_num_samples_) { (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + } + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); std::shared_ptr outputIds; @@ -127,7 +138,12 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr *out_buf RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound)."); } - *(id_ptr++) = genId; + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId)); + } + + *id_ptr = genId; + id_ptr++; sample_id_++; } diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 20a40d5fb0d..9379ba1fa46 100644 --- 
a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -44,7 +44,7 @@ from .validators import check, check_batch, check_shuffle, check_map, check_filt check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ - check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat + check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, check_split from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -581,6 +581,117 @@ class Dataset: return self return TakeDataset(self, count) + def _get_absolute_split_sizes(self, sizes): + """ + Internal method called by split to calculate absolute split sizes and to + validate them once they have been calculated. + """ + # call get_dataset_size and check the input here, because we don't + # want to fetch the size once in check_split and then a second time + # here + dataset_size = self.get_dataset_size() + + if dataset_size is None or dataset_size <= 0: + raise RuntimeError("dataset size unknown, unable to split.") + + all_int = all(isinstance(item, int) for item in sizes) + if all_int: + sizes_sum = sum(sizes) + if sizes_sum != dataset_size: + raise RuntimeError("sum of split sizes {} is not equal to dataset size {}." + .format(sizes_sum, dataset_size)) + return sizes + + absolute_sizes = [] + for item in sizes: + absolute_size = int(round(item * dataset_size)) + if absolute_size == 0: + raise RuntimeError("split percentage {} is too small.".format(item)) + absolute_sizes.append(absolute_size) + + absolute_sizes_sum = sum(absolute_sizes) + if absolute_sizes_sum != dataset_size: + raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}." + .format(absolute_sizes_sum, dataset_size)) + + return absolute_sizes + + @check_split + def split(self, sizes, randomize=True): + """ + Splits the dataset into smaller, non-overlapping datasets. + + This is a general purpose split function which can be called from any operator in the pipeline. + There is another, optimized split function, which will be called automatically if ds.split is + called where ds is a MappableDataset. + + Args: + sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is + provided, the dataset will be split into n datasets of size s1, size s2, …, size sn + respectively. If the sum of all sizes does not equal the original dataset size, an + error will occur. + If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n + Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size + of the original dataset. If after rounding, any size equals 0, an error will occur. + All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. + randomize (bool): determines whether or not to split the data randomly. If True, the data + will be randomly split. Otherwise, each split will be created with consecutive rows + from the dataset. + + Note: + 1. Dataset cannot be sharded if split is going to be called. + 2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead. + Shuffling the dataset may not be deterministic, which means the data in each split + will be different in each epoch.
+ + Raises: + RuntimeError: If get_dataset_size returns None or is not supported for this dataset. + RuntimeError: If sizes is list of integers and sum of all elements in sizes does not + equal the dataset size. + RuntimeError: If sizes is list of float and there is a split with size 0 after calculations. + RuntimeError: If the dataset is sharded prior to calling split. + ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the + floats don’t sum to 1. + + Returns: + tuple(Dataset), a tuple of datasets that have been split. + + Examples: + >>> import mindspore.dataset as ds + >>> + >>> dataset_dir = "/path/to/text_file.txt" + >>> + >>> # TextFileDataset is not a mappable dataset, so this non optimized split will be called. + >>> # many datasets have shuffle on by default, set shuffle to False if split will be called! + >>> data = ds.TextFileDataset(dataset_dir, shuffle=False) + >>> train, test = data.split([0.9, 0.1]) + """ + if self.is_shuffled(): + logger.warning("dataset is shuffled before split.") + + if self.is_sharded(): + raise RuntimeError("dataset should not be sharded before split.") + + absolute_sizes = self._get_absolute_split_sizes(sizes) + splits = [] + rows_to_skip = 0 + for size in absolute_sizes: + ds = copy.deepcopy(self) + if randomize: + # want to shuffle the same way every epoch before split + ds = ds.shuffle() + ds.reshuffle_each_epoch = False + + if rows_to_skip > 0: + ds = ds.skip(rows_to_skip) + + ds = ds.take(size) + splits.append(ds) + + rows_to_skip += size + + return tuple(splits) + @check_zip_dataset def zip(self, datasets): """ @@ -1053,10 +1164,24 @@ class Dataset: def reset(self): """Reset the dataset for next epoch.""" + def is_shuffled(self): + for input_dataset in self.input: + if input_dataset.is_shuffled(): + return True + + return False + + def is_sharded(self): + for input_dataset in self.input: + if input_dataset.is_sharded(): + return True + + return False + class SourceDataset(Dataset): """ - Abstract class to represent a source dataset which produces content to the data pipeline. + Abstract class to represent a source dataset which produces content to the data pipeline. """ # No need for __init__ since it is the same as the super's init @@ -1093,6 +1218,150 @@ class SourceDataset(Dataset): return file_list raise ValueError("The list of path names matching the patterns is empty.") + def is_shuffled(self): + raise NotImplementedError("SourceDataset must implement is_shuffled.") + + def is_sharded(self): + raise NotImplementedError("SourceDataset must implement is_sharded.") + +class MappableDataset(SourceDataset): + """ + Abstract class to represent a source dataset which supports use of samplers. + """ + + def __init__(self, num_parallel_workers=None): + # TODO: check that all subclasses use this attribute name + super().__init__(num_parallel_workers) + self.sampler = None + + def add_sampler(self, new_sampler): + # note: by adding a sampler, we mean that the sampled ids will flow to new_sampler + # after first passing through the current samplers attached to this dataset. + new_sampler.add_child(self.sampler) + self.sampler = new_sampler + + def use_sampler(self, new_sampler): + """ + Make the current dataset use the new_sampler provided. + + Args: + new_sampler (Sampler): the sampler to use for the current dataset. + + Returns: + Dataset, the dataset that uses new_sampler.
+ + Examples: + >>> import mindspore.dataset as ds + >>> + >>> dataset_dir = "/path/to/imagefolder_directory" + >>> # a SequentialSampler is created by default + >>> data = ds.ImageFolderDatasetV2(dataset_dir) + >>> + >>> # use a DistributedSampler instead of the SequentialSampler + >>> new_sampler = ds.DistributedSampler(10, 2) + >>> data.use_sampler(new_sampler) + """ + self.sampler = self.sampler.child_sampler + self.add_sampler(new_sampler) + + def is_shuffled(self): + raise NotImplementedError("MappableDataset must implement is_shuffled.") + + def is_sharded(self): + raise NotImplementedError("MappableDataset must implement is_sharded.") + + + @check_split + def split(self, sizes, randomize=True): + """ + Splits the dataset into smaller, non-overlapping datasets. + + This is the optimized split function, which is called automatically when split is + invoked on a MappableDataset. + + Args: + sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is + provided, the dataset will be split into n datasets of size s1, size s2, …, size sn + respectively. If the sum of all sizes does not equal the original dataset size, an + error will occur. + If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n + Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size + of the original dataset. If after rounding, any size equals 0, an error will occur. + All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur. + randomize (bool): determines whether or not to split the data randomly. If True, the data + will be randomly split. Otherwise, each split will be created with consecutive rows + from the dataset. + + Note: + 1. Dataset should not be sharded if split is going to be called. Instead, create a + DistributedSampler and specify a split to shard after splitting. If dataset is + sharded after a split, it is strongly recommended to set the same seed in each instance + of execution, otherwise each shard may not be part of the same split (see Examples). + 2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead. + Shuffling the dataset may not be deterministic, which means the data in each split + will be different in each epoch. Furthermore, if sharding occurs after split, each + shard may not be part of the same split. + + Raises: + RuntimeError: If get_dataset_size returns None or is not supported for this dataset. + RuntimeError: If sizes is list of integers and sum of all elements in sizes does not + equal the dataset size. + RuntimeError: If sizes is list of float and there is a split with size 0 after calculations. + RuntimeError: If the dataset is sharded prior to calling split. + ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the + floats don’t sum to 1. + + Returns: + tuple(Dataset), a tuple of datasets that have been split. + + Examples: + >>> import mindspore.dataset as ds + >>> + >>> dataset_dir = "/path/to/imagefolder_directory" + >>> + >>> # many datasets have shuffle on by default, set shuffle to False if split will be called! + >>> data = ds.ImageFolderDatasetV2(dataset_dir, shuffle=False) + >>> + >>> # sets the seed, and tells split to use this seed when randomizing.
This + >>> # is needed because we are sharding later + >>> ds.config.set_seed(58) + >>> train, test = data.split([0.9, 0.1]) + >>> + >>> # if we want to shard the train dataset, we can use a DistributedSampler + >>> train_sampler = ds.DistributedSampler(10, 2) + >>> train.use_sampler(train_sampler) + """ + if self.is_shuffled(): + logger.warning("dataset is shuffled before split.") + + if self.is_sharded(): + raise RuntimeError("dataset should not be sharded before split.") + + absolute_sizes = self._get_absolute_split_sizes(sizes) + splits = [] + current_split_start_index = 0 + for size in absolute_sizes: + ds = copy.deepcopy(self) + if randomize: + # want to shuffle the same way every epoch before split, we are assuming + # that the user will call set_seed + random_sampler = samplers.RandomSampler() + random_sampler.reshuffle_each_epoch = False + ds.add_sampler(random_sampler) + + subset_sampler = samplers.SubsetSampler(current_split_start_index, size) + ds.add_sampler(subset_sampler) + + # add sequential sampler, so that if user calls use_sampler, we will + # get rid of the sequential sampler instead of something we need + ds.add_sampler(samplers.SequentialSampler()) + + splits.append(ds) + + current_split_start_index += size + + return tuple(splits) + class DatasetOp(Dataset): """ @@ -1334,6 +1603,7 @@ class SyncWaitDataset(DatasetOp): flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) return flag + class ShuffleDataset(DatasetOp): """ The result of applying Shuffle operator to the input Dataset. @@ -1350,6 +1620,7 @@ class ShuffleDataset(DatasetOp): super().__init__() self.buffer_size = buffer_size self.input.append(input_dataset) + self.reshuffle_each_epoch = None input_dataset.output.append(self) self._input_indexs = input_dataset.input_indexs if self.is_sync(): @@ -1358,8 +1629,14 @@ class ShuffleDataset(DatasetOp): def get_args(self): args = super().get_args() args["buffer_size"] = self.buffer_size + if self.reshuffle_each_epoch is not None: + args["reshuffle_each_epoch"] = self.reshuffle_each_epoch + return args + def is_shuffled(self): + return True + # Pyfunc collection for multiprocess pyfunc # This global variable will only be used within subprocesses @@ -1989,8 +2266,14 @@ class StorageDataset(SourceDataset): self._get_pipeline_info() return self._num_classes + def is_shuffled(self): + return False -class RangeDataset(SourceDataset): + def is_sharded(self): + return False + + +class RangeDataset(MappableDataset): """ A source dataset that reads and parses datasets stored on disk in a range. @@ -2013,6 +2296,12 @@ class RangeDataset(SourceDataset): args["step"] = self.step return args + def is_shuffled(self): + return False + + def is_sharded(self): + return False + def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): """ @@ -2052,7 +2341,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): return samplers.SequentialSampler() -class ImageFolderDatasetV2(SourceDataset): +class ImageFolderDatasetV2(MappableDataset): """ A source dataset that reads images from a tree of directories. 
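Both split() overloads above funnel their sizes argument through Dataset._get_absolute_split_sizes: integer sizes must cover the dataset exactly, while float sizes are rounded against the dataset size and the rounded total is re-checked. A standalone sketch of that arithmetic (the helper name here is ours, for illustration only, not part of the patch):

# Standalone sketch of the arithmetic in Dataset._get_absolute_split_sizes.
def absolute_split_sizes(sizes, dataset_size):
    # Integer sizes must account for every row exactly.
    if all(isinstance(s, int) for s in sizes):
        if sum(sizes) != dataset_size:
            raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
                               .format(sum(sizes), dataset_size))
        return sizes
    # Float sizes: round each fraction, reject empty splits, then re-check the total.
    absolute = [int(round(f * dataset_size)) for f in sizes]
    if any(size == 0 for size in absolute):
        raise RuntimeError("a split percentage is too small.")
    if sum(absolute) != dataset_size:
        raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
                           .format(sum(absolute), dataset_size))
    return absolute

print(absolute_split_sizes([0.9, 0.1], 44))  # [40, 4]: 39.6 and 4.4 round to a full cover of 44 rows
print(absolute_split_sizes([8, 2], 10))      # [8, 2]: integer sizes pass through unchanged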
@@ -2190,8 +2479,20 @@ class ImageFolderDatasetV2(SourceDataset): num_samples = self.num_samples return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1] + def is_shuffled(self): + if self.shuffle_level is None: + return True -class MnistDataset(SourceDataset): + return self.shuffle_level or self.sampler.is_shuffled() + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + + +class MnistDataset(MappableDataset): """ A source dataset for reading and parsing the Mnist dataset. @@ -2294,6 +2595,18 @@ class MnistDataset(SourceDataset): return get_num_rows(num_rows, self.num_shards) + def is_shuffled(self): + if self.shuffle_level is None: + return True + + return self.shuffle_level or self.sampler.is_shuffled() + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + class MindDataset(SourceDataset): """ @@ -2400,6 +2713,18 @@ class MindDataset(SourceDataset): num_rows = num_rows // self.partitions[0] + 1 return num_rows + def is_shuffled(self): + if self.global_shuffle is None: + return True + + return self.global_shuffle or self.sampler.is_shuffled() + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + def _iter_fn(dataset, num_samples): """ @@ -2609,7 +2934,7 @@ class _GeneratorWorker(multiprocessing.Process): self.terminate() -class GeneratorDataset(SourceDataset): +class GeneratorDataset(MappableDataset): """ A source dataset that generate data from python by invoking python data source each epoch. @@ -2794,6 +3119,12 @@ class GeneratorDataset(SourceDataset): return new_op + def is_shuffled(self): + return self.sampler.is_shuffled() + + def is_sharded(self): + return self.sampler.is_sharded() + class TFRecordDataset(SourceDataset): """ @@ -2920,8 +3251,17 @@ class TFRecordDataset(SourceDataset): else: raise ValueError('set dataset_size with negative value {}'.format(value)) + def is_shuffled(self): + return self.shuffle_files -class ManifestDataset(SourceDataset): + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return False + + +class ManifestDataset(MappableDataset): """ A source dataset that reads images from a manifest file. @@ -3088,8 +3428,20 @@ class ManifestDataset(SourceDataset): return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage) + def is_shuffled(self): + if self.shuffle_level is None: + return True -class Cifar10Dataset(SourceDataset): + return self.shuffle_level or self.sampler.is_shuffled() + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + + +class Cifar10Dataset(MappableDataset): """ A source dataset that reads cifar10 data. @@ -3197,8 +3549,20 @@ class Cifar10Dataset(SourceDataset): return get_num_rows(num_rows, self.num_shards) + def is_shuffled(self): + if self.shuffle_level is None: + return True -class Cifar100Dataset(SourceDataset): + return self.shuffle_level or self.sampler.is_shuffled() + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + + +class Cifar100Dataset(MappableDataset): """ A source dataset that reads cifar100 data. 
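Every is_shuffled()/is_sharded() pair added in these hunks follows the same recursion: a source dataset answers from its own shuffle flag or shard count first, and otherwise defers to its sampler chain. A hedged sketch of the checks split() runs against them (the directory path is a placeholder):

import mindspore.dataset as ds

# Placeholder path; any of the MappableDataset subclasses above behaves the same way.
data = ds.Cifar10Dataset("/path/to/cifar10_directory", shuffle=False)

# shuffle=False plus the default sequential sampler: neither predicate trips,
# so split() would proceed without the shuffle warning or the shard error.
assert not data.is_shuffled()
assert not data.is_sharded()

# A dataset sharded at construction time is rejected by split() with a RuntimeError.
sharded = ds.Cifar10Dataset("/path/to/cifar10_directory", num_shards=4, shard_id=0)
assert sharded.is_sharded()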
@@ -3304,6 +3668,18 @@ class Cifar100Dataset(SourceDataset): return get_num_rows(num_rows, self.num_shards) + def is_shuffled(self): + if self.shuffle_level is None: + return True + + return self.shuffle_level or self.sampler.is_shuffled() + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + class RandomDataset(SourceDataset): """ @@ -3355,6 +3731,11 @@ class RandomDataset(SourceDataset): """ return num_samples + def is_shuffled(self): + return True + + def is_sharded(self): + return False class Schema: """ @@ -3534,7 +3915,7 @@ class Schema: return self.to_json() -class VOCDataset(SourceDataset): +class VOCDataset(MappableDataset): """ A source dataset for reading and parsing VOC dataset. @@ -3681,8 +4062,20 @@ class VOCDataset(SourceDataset): return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples) + def is_shuffled(self): + if self.shuffle_level is None: + return True -class CelebADataset(SourceDataset): + return self.shuffle_level or self.sampler.is_shuffled() + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + + +class CelebADataset(MappableDataset): """ A source dataset for reading and parsing CelebA dataset.Only support list_attr_celeba.txt currently. @@ -3735,6 +4128,18 @@ class CelebADataset(SourceDataset): args["shard_id"] = self.shard_id return args + def is_shuffled(self): + if self.shuffle_level is None: + return True + + return self.shuffle_level or self.sampler.is_shuffled() + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() + class TextFileDataset(SourceDataset): """ @@ -3814,3 +4219,12 @@ class TextFileDataset(SourceDataset): return num_rows return min(self.num_samples, num_rows) return self._dataset_size + + def is_shuffled(self): + return self.shuffle_files + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return False diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 972f0af1914..8bf223251a2 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -47,6 +47,7 @@ class Sampler: def __init__(self): self.dataset_size = 0 self.num_samples = 0 + self.child_sampler = None def __iter__(self): """ @@ -83,7 +84,35 @@ class Sampler: # Instance fetcher # Do not override this method! def create(self): - return cde.PythonSampler(self) + c_sampler = cde.PythonSampler(self) + c_child_sampler = self.create_child() + c_sampler.add_child(c_child_sampler) + return c_sampler + + def add_child(self, sampler): + self.child_sampler = sampler + + def get_child(self): + return self.child_sampler + + def create_child(self): + c_child_sampler = None + if self.child_sampler is not None: + c_child_sampler = self.child_sampler.create() + + return c_child_sampler + + def is_shuffled(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_shuffled() + + def is_sharded(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_sharded() class BuiltinSampler: @@ -93,11 +122,30 @@ class BuiltinSampler: User should not extend this class. 
""" def __init__(self): - pass + self.child_sampler = None def create(self): pass + def add_child(self, sampler): + self.child_sampler = sampler + + def get_child(self): + return self.child_sampler + + def create_child(self): + c_child_sampler = None + if self.child_sampler is not None: + c_child_sampler = self.child_sampler.create() + + return c_child_sampler + + def is_shuffled(self): + raise NotImplementedError("Sampler must implement is_shuffled.") + + def is_sharded(self): + raise NotImplementedError("Sampler must implement is_sharded.") + class DistributedSampler(BuiltinSampler): """ @@ -142,7 +190,22 @@ class DistributedSampler(BuiltinSampler): def create(self): # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle self.seed += 1 - return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) + c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) + c_child_sampler = self.create_child() + c_sampler.add_child(c_child_sampler) + return c_sampler + + def is_shuffled(self): + if self.child_sampler is None: + return self.shuffle + + return self.child_sampler.is_shuffled() + + def is_sharded(self): + if self.child_sampler is None: + return self.num_shards > 1 + + return self.child_sampler.is_sharded() class PKSampler(BuiltinSampler): @@ -186,7 +249,22 @@ class PKSampler(BuiltinSampler): super().__init__() def create(self): - return cde.PKSampler(self.num_val, self.shuffle) + c_sampler = cde.PKSampler(self.num_val, self.shuffle) + c_child_sampler = self.create_child() + c_sampler.add_child(c_child_sampler) + return c_sampler + + def is_shuffled(self): + if self.child_sampler is None: + return self.shuffle + + return self.child_sampler.is_shuffled() + + def is_sharded(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_sharded() def _create_for_minddataset(self): if not self.class_column or not isinstance(self.class_column, str): @@ -226,15 +304,31 @@ class RandomSampler(BuiltinSampler): raise ValueError("num_samples should be a positive integer " "value, but got num_samples={}".format(num_samples)) + self.deterministic = False self.replacement = replacement self.num_samples = num_samples + self.reshuffle_each_epoch = True super().__init__() def create(self): - # If num_samples is not specified, then call constructor #2 + c_sampler = None if self.num_samples is None: - return cde.RandomSampler(self.replacement) - return cde.RandomSampler(self.replacement, self.num_samples) + c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch) + else: + c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples) + + c_child_sampler = self.create_child() + c_sampler.add_child(c_child_sampler) + return c_sampler + + def is_shuffled(self): + return True + + def is_sharded(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_sharded() class SequentialSampler(BuiltinSampler): @@ -252,7 +346,80 @@ class SequentialSampler(BuiltinSampler): """ def create(self): - return cde.SequentialSampler() + c_sampler = cde.SequentialSampler() + c_child_sampler = self.create_child() + c_sampler.add_child(c_child_sampler) + return c_sampler + + def is_shuffled(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_shuffled() + + def is_sharded(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_sharded() + + +class 
SubsetSampler(BuiltinSampler): + """ + Samples a subset of elements consecutively from a given index. + + Args: + start_index (int): Index to start sampling at. + subset_size (int): How many samples to include in this subset. + + Examples: + >>> import mindspore.dataset as ds + >>> + >>> dataset_dir = "path/to/imagefolder_directory" + >>> + >>> # creates a SubsetSampler that samples the 5 images starting at index 100. + >>> sampler = ds.SubsetSampler(100, 5) + >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler) + + Raises: + ValueError: If start_index is not an int or is negative. + ValueError: If subset_size is not an int or is negative. + """ + + def __init__(self, start_index, subset_size): + if not isinstance(start_index, int): + raise ValueError("start_index should be an int.") + + if start_index < 0: + raise ValueError("start_index should not be negative.") + + if not isinstance(subset_size, int): + raise ValueError("subset_size should be an int.") + + if subset_size < 0: + raise ValueError("subset_size should not be negative.") + + self.start_index = start_index + self.subset_size = subset_size + super().__init__() + + def create(self): + c_sampler = cde.SubsetSampler(self.start_index, self.subset_size) + c_child_sampler = self.create_child() + c_sampler.add_child(c_child_sampler) + return c_sampler + + def is_shuffled(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_shuffled() + + def is_sharded(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_sharded() class SubsetRandomSampler(BuiltinSampler): @@ -282,7 +449,19 @@ super().__init__() def create(self): - return cde.SubsetRandomSampler(self.indices) + c_sampler = cde.SubsetRandomSampler(self.indices) + c_child_sampler = self.create_child() + c_sampler.add_child(c_child_sampler) + return c_sampler + + def is_shuffled(self): + return True + + def is_sharded(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_sharded() def _create_for_minddataset(self): return cde.MindrecordSubsetRandomSampler(self.indices) @@ -330,4 +509,16 @@ super().__init__() def create(self): - return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) + c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) + c_child_sampler = self.create_child() + c_sampler.add_child(c_child_sampler) + return c_sampler + + def is_shuffled(self): + return True + + def is_sharded(self): + if self.child_sampler is None: + return False + + return self.child_sampler.is_sharded() diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index f868e3e1efe..de3987c8af3 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1031,3 +1031,44 @@ def check_textfiledataset(method): return method(*args, **kwargs) return new_method + +def check_split(method): + """check the input arguments of split.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_list = ['sizes'] + nreq_param_bool = ['randomize'] + check_param_type(nreq_param_list, param_dict, list) + check_param_type(nreq_param_bool, param_dict, bool) + + # check sizes: must be list of float or list of int + sizes = param_dict.get('sizes') + + if not sizes: + raise ValueError("sizes cannot be empty.") +
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index f868e3e1efe..de3987c8af3 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -1031,3 +1031,44 @@ def check_textfiledataset(method):
         return method(*args, **kwargs)

     return new_method
+
+def check_split(method):
+    """check the input arguments of split."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        nreq_param_list = ['sizes']
+        nreq_param_bool = ['randomize']
+        check_param_type(nreq_param_list, param_dict, list)
+        check_param_type(nreq_param_bool, param_dict, bool)
+
+        # check sizes: must be list of float or list of int
+        sizes = param_dict.get('sizes')
+
+        if not sizes:
+            raise ValueError("sizes cannot be empty.")
+
+        all_int = all(isinstance(item, int) for item in sizes)
+        all_float = all(isinstance(item, float) for item in sizes)
+
+        if not (all_int or all_float):
+            raise ValueError("sizes should be list of int or list of float.")
+
+        if all_int:
+            all_positive = all(item > 0 for item in sizes)
+            if not all_positive:
+                raise ValueError("sizes is a list of int, but there should be no negative numbers.")
+
+        if all_float:
+            all_valid_percentages = all(0 < item <= 1 for item in sizes)
+            if not all_valid_percentages:
+                raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].")
+
+            epsilon = 0.00001
+            if not abs(sum(sizes) - 1) < epsilon:
+                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
+
+        return method(*args, **kwargs)
+
+    return new_method
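For quick reference, the sizes rules enforced by check_split can be restated as a standalone predicate. The helper below is illustrative only, the decorator above being the actual validator:

    def sizes_ok(sizes):
        """Mirror of check_split's sizes checks (True = would pass)."""
        if not sizes:
            return False
        all_int = all(isinstance(s, int) for s in sizes)
        all_float = all(isinstance(s, float) for s in sizes)
        if not (all_int or all_float):
            return False  # mixed int/float lists are rejected
        if all_int:
            return all(s > 0 for s in sizes)
        # floats must be percentages in (0, 1] that sum to 1 (tolerance 1e-5)
        return all(0 < s <= 1 for s in sizes) and abs(sum(sizes) - 1) < 0.00001

    assert sizes_ok([4, 1]) and sizes_ok([0.8, 0.2])
    assert not sizes_ok([]) and not sizes_ok([5, 0.6])
    assert not sizes_ok([-1, 6]) and not sizes_ok([1.5, 0.5]) and not sizes_ok([0.5, 0.6])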
diff --git a/tests/ut/cpp/dataset/cifar_op_test.cc b/tests/ut/cpp/dataset/cifar_op_test.cc
index dcbea83df4c..8eeeba76afe 100644
--- a/tests/ut/cpp/dataset/cifar_op_test.cc
+++ b/tests/ut/cpp/dataset/cifar_op_test.cc
@@ -92,7 +92,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
 TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
   uint32_t original_seed = GlobalContext::config_manager()->seed();
   GlobalContext::config_manager()->set_seed(0);
-  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
+  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
   std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
   auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
   tree->Prepare();
diff --git a/tests/ut/cpp/dataset/image_folder_op_test.cc b/tests/ut/cpp/dataset/image_folder_op_test.cc
index dbe43ab3557..380b7cd02b5 100644
--- a/tests/ut/cpp/dataset/image_folder_op_test.cc
+++ b/tests/ut/cpp/dataset/image_folder_op_test.cc
@@ -138,7 +138,7 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
 TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
   int32_t original_seed = GlobalContext::config_manager()->seed();
   GlobalContext::config_manager()->set_seed(0);
-  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
+  std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
   int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1};  // ground truth label
   std::string folder_path = datasets_root_path_ + "/testPK/data";
   auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py
index 93518b3f9cb..efd5ed8803e 100644
--- a/tests/ut/python/dataset/test_sampler.py
+++ b/tests/ut/python/dataset/test_sampler.py
@@ -163,9 +163,36 @@ def test_python_sampler():
     assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]


+def test_sampler_chain():
+    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
+    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+
+    def test_config(num_shards, shard_id):
+        sampler = ds.DistributedSampler(num_shards, shard_id, False)
+        child_sampler = ds.SequentialSampler()
+        sampler.add_child(child_sampler)
+
+        data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)
+
+        res = []
+        for item in data1.create_dict_iterator():
+            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
+                        .format(item["image"].shape[0], item["label"].item()))
+            res.append(map_[(item["image"].shape[0], item["label"].item())])
+        return res
+
+    assert test_config(2, 0) == [0, 2, 4]
+    assert test_config(2, 1) == [1, 3, 0]
+    assert test_config(5, 0) == [0]
+    assert test_config(5, 1) == [1]
+    assert test_config(5, 2) == [2]
+    assert test_config(5, 3) == [3]
+    assert test_config(5, 4) == [4]
+

 if __name__ == '__main__':
     test_sequential_sampler(True)
     test_random_sampler(True)
     test_random_sampler_multi_iter(True)
     test_sampler_py_api()
     test_python_sampler()
+    test_sampler_chain()
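The shard contents asserted in test_sampler_chain are consistent with round-robin index assignment that wraps past the end of the dataset so every shard yields the same number of rows. The sketch below reproduces the asserted values; the formula is inferred from those asserts, not taken from the backend code:

    def shard_ids(num_rows, num_shards, shard_id):
        # ceil(num_rows / num_shards) rows per shard, wrapping modulo num_rows
        per_shard = (num_rows + num_shards - 1) // num_shards
        return [(num_shards * i + shard_id) % num_rows for i in range(per_shard)]

    assert shard_ids(5, 2, 0) == [0, 2, 4]
    assert shard_ids(5, 2, 1) == [1, 3, 0]  # wraps around to row 0
    assert shard_ids(5, 5, 2) == [2]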
diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py
new file mode 100644
index 00000000000..f2ff8b64971
--- /dev/null
+++ b/tests/ut/python/dataset/test_split.py
@@ -0,0 +1,342 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import pytest
+import mindspore.dataset as ds
+
+# test5trainimgs.json contains 5 images; each image can be uniquely identified
+# by its (un-decoded shape, label) pair via the manifest_map lookup table below
+manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
+manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
+
+def split_with_invalid_inputs(d):
+    with pytest.raises(ValueError) as info:
+        s1, s2 = d.split([])
+    assert "sizes cannot be empty" in str(info.value)
+
+    with pytest.raises(ValueError) as info:
+        s1, s2 = d.split([5, 0.6])
+    assert "sizes should be list of int or list of float" in str(info.value)
+
+    with pytest.raises(ValueError) as info:
+        s1, s2 = d.split([-1, 6])
+    assert "there should be no negative numbers" in str(info.value)
+
+    with pytest.raises(RuntimeError) as info:
+        s1, s2 = d.split([3, 1])
+    assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
+
+    with pytest.raises(RuntimeError) as info:
+        s1, s2 = d.split([5, 1])
+    assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
+
+    with pytest.raises(RuntimeError) as info:
+        s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
+    assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
+
+    with pytest.raises(ValueError) as info:
+        s1, s2 = d.split([-0.5, 0.5])
+    assert "there should be no numbers outside the range [0, 1]" in str(info.value)
+
+    with pytest.raises(ValueError) as info:
+        s1, s2 = d.split([1.5, 0.5])
+    assert "there should be no numbers outside the range [0, 1]" in str(info.value)
+
+    with pytest.raises(ValueError) as info:
+        s1, s2 = d.split([0.5, 0.6])
+    assert "percentages do not sum up to 1" in str(info.value)
+
+    with pytest.raises(ValueError) as info:
+        s1, s2 = d.split([0.3, 0.6])
+    assert "percentages do not sum up to 1" in str(info.value)
+
+    with pytest.raises(RuntimeError) as info:
+        s1, s2 = d.split([0.05, 0.95])
+    assert "percentage 0.05 is too small" in str(info.value)
+
+def test_unmappable_invalid_input():
+    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
+    d = ds.TextFileDataset(text_file_dataset_path)
+    split_with_invalid_inputs(d)
+
+    d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
+    with pytest.raises(RuntimeError) as info:
+        s1, s2 = d.split([4, 1])
+    assert "dataset should not be sharded before split" in str(info.value)
+
+def test_unmappable_split():
+    text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
+    text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
+                      "End of file.", "Good luck to everyone."]
+    ds.config.set_num_parallel_workers(4)
+    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
+    s1, s2 = d.split([4, 1], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(item["text"].item().decode("utf8"))
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(item["text"].item().decode("utf8"))
+
+    assert s1_output == text_file_data[0:4]
+    assert s2_output == text_file_data[4:]
+
+    # exact percentages
+    s1, s2 = d.split([0.8, 0.2], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(item["text"].item().decode("utf8"))
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(item["text"].item().decode("utf8"))
+
+    assert s1_output == text_file_data[0:4]
+    assert s2_output == text_file_data[4:]
+
+    # fuzzy percentages
+    s1, s2 = d.split([0.33, 0.67], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(item["text"].item().decode("utf8"))
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(item["text"].item().decode("utf8"))
+
+    assert s1_output == text_file_data[0:2]
+    assert s2_output == text_file_data[2:]
+
+def test_mappable_invalid_input():
+    d = ds.ManifestDataset(manifest_file)
+    split_with_invalid_inputs(d)
+
+    d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
+    with pytest.raises(RuntimeError) as info:
+        s1, s2 = d.split([4, 1])
+    assert "dataset should not be sharded before split" in str(info.value)
+
+def test_mappable_split_general():
+    d = ds.ManifestDataset(manifest_file, shuffle=False)
+    d = d.take(5)
+
+    # absolute rows
+    s1, s2 = d.split([4, 1], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1_output == [0, 1, 2, 3]
+    assert s2_output == [4]
+
+    # exact percentages
+    s1, s2 = d.split([0.8, 0.2], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1_output == [0, 1, 2, 3]
+    assert s2_output == [4]
+
+    # fuzzy percentages
+    s1, s2 = d.split([0.33, 0.67], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1_output == [0, 1]
+    assert s2_output == [2, 3, 4]
+
+def test_mappable_split_optimized():
+    d = ds.ManifestDataset(manifest_file, shuffle=False)
+
+    # absolute rows
+    s1, s2 = d.split([4, 1], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1_output == [0, 1, 2, 3]
+    assert s2_output == [4]
+
+    # exact percentages
+    s1, s2 = d.split([0.8, 0.2], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1_output == [0, 1, 2, 3]
+    assert s2_output == [4]
+
+    # fuzzy percentages
+    s1, s2 = d.split([0.33, 0.67], randomize=False)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s1_output == [0, 1]
+    assert s2_output == [2, 3, 4]
+
+def test_mappable_randomize_deterministic():
+    # set an arbitrary seed for the random split
+    # the labels output by ManifestDataset for seed 53 are [0, 1, 3, 4]
+    ds.config.set_seed(53)
+
+    d = ds.ManifestDataset(manifest_file, shuffle=False)
+    s1, s2 = d.split([0.8, 0.2])
+
+    for _ in range(10):
+        s1_output = []
+        for item in s1.create_dict_iterator():
+            s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+        s2_output = []
+        for item in s2.create_dict_iterator():
+            s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+        # note: no overlap between the two splits
+        assert s1_output == [0, 1, 3, 4]
+        assert s2_output == [2]
+
+def test_mappable_randomize_repeatable():
+    # set an arbitrary seed for the random split
+    # the labels output by ManifestDataset for seed 53 are [0, 1, 3, 4]
+    ds.config.set_seed(53)
+
+    d = ds.ManifestDataset(manifest_file, shuffle=False)
+    s1, s2 = d.split([0.8, 0.2])
+
+    num_epochs = 5
+    s1 = s1.repeat(num_epochs)
+    s2 = s2.repeat(num_epochs)
+
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    # note: no overlap between the two splits
+    assert s1_output == [0, 1, 3, 4] * num_epochs
+    assert s2_output == [2] * num_epochs
+
+def test_mappable_sharding():
+    # set an arbitrary seed for repeatability of the shard after split
+    # the labels output by ManifestDataset for seed 53 are [0, 1, 3, 4]
+    ds.config.set_seed(53)
+
+    num_epochs = 5
+    first_split_num_rows = 4
+
+    d = ds.ManifestDataset(manifest_file, shuffle=False)
+    s1, s2 = d.split([first_split_num_rows, 1])
+
+    distributed_sampler = ds.DistributedSampler(2, 0)
+    s1.use_sampler(distributed_sampler)
+
+    s1 = s1.repeat(num_epochs)
+
+    # testing sharding: a second dataset simulates another training instance
+    d2 = ds.ManifestDataset(manifest_file, shuffle=False)
+    d2s1, d2s2 = d2.split([first_split_num_rows, 1])
+
+    distributed_sampler = ds.DistributedSampler(2, 1)
+    d2s1.use_sampler(distributed_sampler)
+
+    d2s1 = d2s1.repeat(num_epochs)
+
+    # shard 0
+    s1_output = []
+    for item in s1.create_dict_iterator():
+        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    # shard 1
+    d2s1_output = []
+    for item in d2s1.create_dict_iterator():
+        d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    rows_per_shard_per_epoch = 2
+    assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
+    assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs
+
+    # verify for each epoch that
+    # 1. the shards contain no common elements
+    # 2. the data was split the same way, and the union of the shards equals the split
+    correct_sorted_split_result = [0, 1, 3, 4]
+    for i in range(num_epochs):
+        combined_data = []
+        for j in range(rows_per_shard_per_epoch):
+            combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
+            combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])
+
+        assert sorted(combined_data) == correct_sorted_split_result
+
+    # test the other split
+    s2_output = []
+    for item in s2.create_dict_iterator():
+        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    d2s2_output = []
+    for item in d2s2.create_dict_iterator():
+        d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
+
+    assert s2_output == [2]
+    assert d2s2_output == [2]
+
+if __name__ == '__main__':
+    test_unmappable_invalid_input()
+    test_unmappable_split()
+    test_mappable_invalid_input()
+    test_mappable_split_general()
+    test_mappable_split_optimized()
+    test_mappable_randomize_deterministic()
+    test_mappable_randomize_repeatable()
+    test_mappable_sharding()
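The fuzzy-percentage expectations in these tests ([0.33, 0.67] over 5 rows yielding 2 and 3 rows, and [0.15, 0.15, 0.15, 0.15, 0.15, 0.25] being rejected with 6 calculated rows) are consistent with rounding each size * num_rows to the nearest integer and then checking the total. The exact backend rule is not shown in this diff, so the helper below is an assumption that merely reproduces the tested cases:

    def rows_per_split(sizes, num_rows):
        # assumed rule: round each percentage to a row count; the backend
        # may resolve rounding differently in cases not covered by the tests
        return [round(s * num_rows) for s in sizes]

    assert rows_per_split([0.8, 0.2], 5) == [4, 1]    # exact percentages
    assert rows_per_split([0.33, 0.67], 5) == [2, 3]  # fuzzy percentages
    # rejected case: calculated sizes sum to 6, not the dataset size 5
    assert sum(rows_per_split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25], 5)) == 6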