!1281 Implementation of SplitOp

Merge pull request !1281 from Peilin/splitOp
This commit is contained in:
mindspore-ci-bot 2020-05-22 09:29:03 +08:00 committed by Gitee
commit 2e3d55ed87
24 changed files with 1507 additions and 46 deletions

View File

@ -364,6 +364,18 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
std::string err_msg = "Error: Shuffle buffer size is missing";
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "reshuffle_each_epoch") {
(void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"]));
}
}
}
std::shared_ptr<ShuffleOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;

View File

@ -51,6 +51,7 @@
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
@ -425,11 +426,14 @@ void bindSamplerOps(py::module *m) {
.def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
.def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
.def("get_indices", [](Sampler &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
});
.def("get_indices",
[](Sampler &self) {
py::array ret;
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
return ret;
})
.def("add_child",
[](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
@ -441,12 +445,16 @@ void bindSamplerOps(py::module *m) {
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
.def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples"))
.def(py::init<bool>(), py::arg("replacement"));
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
py::arg("num_samples"))
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
.def(py::init<>());
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));

View File

@ -8,5 +8,6 @@ add_library(engine-datasetops-source-sampler OBJECT
sampler.cc
sequential_sampler.cc
subset_random_sampler.cc
subset_sampler.cc
weighted_random_sampler.cc
)

View File

@ -55,13 +55,27 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
} else if (cnt_ == samples_per_buffer_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sample_ids;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
while (cnt_ < samples_per_buffer_) {
int64_t next_id = (num_devices_ * (cnt_++) + device_id_) % num_rows_;
*(id_ptr++) = shuffle_ ? shuffle_vec_[static_cast<size_t>(next_id)] : next_id;
int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
if (shuffle_) {
sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
cnt_++;
}
TensorRow row(1, sample_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
@ -72,11 +86,29 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
Status DistributedSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
cnt_ = 0;
rnd_.seed(seed_++);
if (shuffle_ == true) {
rnd_.seed(seed_);
seed_++;
std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
void DistributedSampler::Print(std::ostream &out, bool show_all) const {
out << "(sampler): DistributedSampler\n";
if (show_all) {
out << "seed_: " << seed_ << '\n';
out << "device_id_: " << device_id_ << '\n';
out << "num_devices_: " << num_devices_ << '\n';
out << "shuffle_: " << shuffle_ << '\n';
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -48,6 +48,8 @@ class DistributedSampler : public Sampler {
// @return - The error code return
Status Reset() override;
void Print(std::ostream &out, bool show_all) const override;
private:
int64_t cnt_; // number of samples that have already been filled in to buffer
uint32_t seed_;

View File

@ -38,6 +38,7 @@ Status PKSampler::InitSampler() {
rnd_.seed(seed_++);
num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
num_samples_ = num_pk_samples_;
if (shuffle_ == true) {
std::shuffle(labels_.begin(), labels_.end(), rnd_);
} else {
@ -53,6 +54,10 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
} else if (next_id_ == num_pk_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sample_ids;
int64_t last_id =
@ -63,8 +68,16 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
int64_t cls_id = next_id_++ / samples_per_class_;
const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]];
int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_);
*(id_ptr++) = samples[rnd_ind];
int64_t sampled_id = samples[rnd_ind];
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
}
TensorRow row(1, sample_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
@ -75,6 +88,11 @@ Status PKSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
rnd_.seed(seed_++);
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}

View File

@ -27,6 +27,10 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (need_to_reset_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
std::shared_ptr<Tensor> sample_ids;
{
py::gil_scoped_acquire gil_acquire;
@ -38,6 +42,14 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
py::object py_ret = py_sampler_instance.attr("_get_indices")();
py::array np_sample_ids = py_ret.cast<py::array>();
Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor
if (HasChildSampler()) {
for (auto it = sample_ids->begin<int64_t>(); it != sample_ids->end<int64_t>(); ++it) {
int64_t associated_child_id = 0;
RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, associated_child_id));
*it = associated_child_id;
}
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
} catch (const py::cast_error &e) {
@ -79,6 +91,11 @@ Status PythonSampler::Reset() {
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
} // namespace dataset

View File

@ -14,18 +14,22 @@
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include <algorithm>
#include <limits>
#include <memory>
#include "dataset/util/random.h"
namespace mindspore {
namespace dataset {
RandomSampler::RandomSampler(bool replacement, int64_t num_samples, int64_t samples_per_buffer)
RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
int64_t samples_per_buffer)
: Sampler(samples_per_buffer),
seed_(GetSeed()),
replacement_(replacement),
user_num_samples_(num_samples),
next_id_(0),
reshuffle_each_epoch_(reshuffle_each_epoch),
dist(nullptr) {}
Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
@ -34,13 +38,29 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
} else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampleIds;
int64_t last_id = samples_per_buffer_ + next_id_ > num_samples_ ? num_samples_ : samples_per_buffer_ + next_id_;
int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_);
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_));
int64_t *id_ptr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
for (int64_t i = 0; i < (last_id - next_id_); i++) {
*(id_ptr + i) = replacement_ ? (*dist)(rnd_) : shuffled_ids_[static_cast<size_t>(i + next_id_)];
int64_t sampled_id = 0;
if (replacement_) {
sampled_id = (*dist)(rnd_);
} else {
sampled_id = shuffled_ids_[static_cast<size_t>(i + next_id_)];
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*(id_ptr + i) = sampled_id;
}
next_id_ = last_id;
TensorRow row(1, sampleIds);
@ -53,7 +73,9 @@ Status RandomSampler::InitSampler() {
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
rnd_.seed(seed_++);
rnd_.seed(seed_);
if (replacement_ == false) {
shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {
@ -69,11 +91,33 @@ Status RandomSampler::InitSampler() {
Status RandomSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
rnd_.seed(seed_++);
if (replacement_ == false) {
if (reshuffle_each_epoch_) {
seed_++;
}
rnd_.seed(seed_);
if (replacement_ == false && reshuffle_each_epoch_) {
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
void RandomSampler::Print(std::ostream &out, bool show_all) const {
out << "(sampler): RandomSampler\n";
if (show_all) {
out << "user_num_samples_: " << user_num_samples_ << '\n';
out << "num_samples_: " << num_samples_ << '\n';
out << "next_id_: " << next_id_ << '\n';
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -30,7 +30,8 @@ class RandomSampler : public Sampler {
// @param bool replacement - put he id back / or not after a sample
// @param int64_t numSamples - number samples to draw
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(bool replacement = false, int64_t num_samples = std::numeric_limits<int64_t>::max(),
explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
int64_t num_samples = std::numeric_limits<int64_t>::max(),
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.
@ -49,6 +50,8 @@ class RandomSampler : public Sampler {
// @return - The error code return
Status Reset() override;
virtual void Print(std::ostream &out, bool show_all) const;
private:
uint32_t seed_;
bool replacement_;
@ -57,6 +60,7 @@ class RandomSampler : public Sampler {
int64_t next_id_;
std::mt19937 rnd_;
std::unique_ptr<std::uniform_int_distribution<int64_t>> dist;
bool reshuffle_each_epoch_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -15,18 +15,41 @@
*/
#include "dataset/engine/datasetops/source/sampler/sampler.h"
#include <string>
namespace mindspore {
namespace dataset {
Sampler::Sampler(int64_t samples_per_buffer)
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<Sampler> child_sampler;
if (HasChildSampler()) {
child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
if (!child_sampler) {
std::string err_msg("Cannot handshake, child is not a sampler object.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Handshake and init child first.
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}
}
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
if (HasChildSampler()) {
int64_t child_num_samples = child_sampler->num_samples();
num_rows_ = child_num_samples;
} else {
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
}
// It's up to the derived class to check the validity of the two args
// Because some sampler only needs one of the arg (weighted_random_sampler)
RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback
return Status::OK();
}
@ -44,6 +67,15 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
return Status::OK();
}
void Sampler::Print(std::ostream &out, bool show_all) const {
out << "(sampler): base\n";
if (show_all) {
out << "num_rows_: " << num_rows_ << '\n';
out << "num_samples_: " << num_samples_ << '\n';
}
}
Status Sampler::GetAllIdsThenReset(py::array *data) {
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> sample_ids;
@ -84,5 +116,45 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
num_rows_ = num_rows;
return Status::OK();
}
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
if (child == nullptr) {
return Status::OK();
}
// Only samplers can be added, not any other DatasetOp.
std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
if (!sampler) {
std::string err_msg("Cannot add child, child is not a sampler object.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Samplers can have at most 1 child.
if (!child_.empty()) {
std::string err_msg("Cannot add child sampler, this sampler already has a child.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
child_.push_back(child);
// doesn't work, protected?
// child->AddParent(this);
return Status::OK();
}
bool Sampler::HasChildSampler() { return !child_.empty(); }
Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
if (child_ids_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
}
TensorRow sample_row;
RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -90,6 +90,8 @@ class Sampler : public DatasetOp {
// setter function for num_samples_
Status SetNumSamples(int64_t num_samples);
int64_t num_samples() { return num_samples_; }
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
// @return
@ -114,17 +116,48 @@ class Sampler : public DatasetOp {
// @return - The error code return
Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); }
// Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned.
Status AddChild(std::shared_ptr<DatasetOp> child);
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds
// @param int64_t numElements - must be a non 0 number
// @return
// @return - The error code returned.
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
void Print(std::ostream &out, bool show_all) const override;
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
sampler.Print(out, false);
return out;
}
// Checks if this sampler has a child sampler.
// @return - tre if there is a child sampler, false otherwise.
bool HasChildSampler();
// Uses id as an index for the list of ids generated by the child sampler, and gets the
// associated id.
// @param int64_t* out_associated_id - Out parameter, contains the associated id.
// @param int64_t id - The id used as an index to get the associated child id.
// @return - The error code returned.
Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
protected:
// Number of rows of data from the place this sampler is sampling from. If this sampler
// has a child sampler, num_rows_ is the number of ids the child sampler will
// output. Otherwise, num_rows_ is the number of rows in the dataset.
int64_t num_rows_;
// Number of ids this sampler will return.
int64_t num_samples_;
// The max number of ids a DataBuffer returned by this sampler will contain.
int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_;
std::unique_ptr<DataBuffer> child_ids_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -15,6 +15,7 @@
*/
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include <algorithm>
#include <memory>
namespace mindspore {
@ -27,14 +28,26 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
} else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampleIds;
int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
while (next_id_ < lastId) {
*(idPtr++) = next_id_++;
int64_t sampled_id = next_id_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*idPtr = sampled_id;
next_id_++;
idPtr++;
}
TensorRow row(1, sampleIds);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
@ -43,6 +56,10 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
Status SequentialSampler::InitSampler() {
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set
if (HasChildSampler()) {
num_samples_ = std::min(num_samples_, num_rows_);
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();
@ -51,7 +68,15 @@ Status SequentialSampler::InitSampler() {
Status SequentialSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; }
} // namespace dataset
} // namespace mindspore

View File

@ -45,6 +45,8 @@ class SequentialSampler : public Sampler {
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
void Print(std::ostream &out, bool show_all) const override;
private:
int64_t next_id_;
};

View File

@ -34,6 +34,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
Status SubsetRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
num_samples_ = indices_.size();
// Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed());
@ -56,6 +58,10 @@ Status SubsetRandomSampler::Reset() {
rand_gen_.seed(GetSeed());
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
@ -65,6 +71,10 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
if (sample_id_ == indices_.size()) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> outputIds;
@ -87,7 +97,14 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
RETURN_STATUS_UNEXPECTED(err_msg);
}
*(id_ptr++) = indices_[sample_id_++];
int64_t sampled_id = indices_[sample_id_];
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*id_ptr = sampled_id;
id_ptr++;
sample_id_++;
}
// Create a TensorTable from that single tensor and push into DataBuffer

View File

@ -0,0 +1,85 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include <memory>
#include <string>
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
namespace mindspore {
namespace dataset {
// Constructor.
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
Status SubsetSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
num_samples_ = subset_size_;
return Status::OK();
}
Status SubsetSampler::Reset() {
current_id_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (current_id_ > subset_size_) {
RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
} else if (current_id_ == subset_size_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampled_ids;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));
int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());
while (current_id_ < subset_size_) {
int64_t sampled_id = start_index_ + current_id_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*(sampled_ids_start_addr + current_id_) = sampled_id;
current_id_++;
}
TensorRow sampled_ids_row(1, sampled_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#include <memory>
#include <vector>
#include "dataset/engine/datasetops/source/sampler/sampler.h"
namespace mindspore {
namespace dataset {
class SubsetSampler : public Sampler {
public:
// Constructor.
// @param start_index The index we start sampling from.
explicit SubsetSampler(int64_t start_index, int64_t subset_size);
// Destructor.
~SubsetSampler() = default;
// Initialize the sampler.
// @return Status
Status InitSampler() override;
// Reset the internal variable to the initial state and reshuffle the indices.
// @return Status
Status Reset() override;
// Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @note the sample ids (int64_t) will be placed in one Tensor.
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
private:
int64_t start_index_;
int64_t subset_size_;
int64_t current_id_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_

View File

@ -40,6 +40,8 @@ WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights,
Status WeightedRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
num_samples_ = user_num_samples_;
// Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed());
@ -81,6 +83,11 @@ Status WeightedRandomSampler::Reset() {
} else {
discrete_dist_->reset();
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
@ -98,6 +105,10 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
if (sample_id_ == user_num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> outputIds;
@ -127,7 +138,12 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound).");
}
*(id_ptr++) = genId;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
}
*id_ptr = genId;
id_ptr++;
sample_id_++;
}

View File

@ -44,7 +44,7 @@ from .validators import check, check_batch, check_shuffle, check_map, check_filt
check_rename, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat
check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, check_split
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
@ -581,6 +581,117 @@ class Dataset:
return self
return TakeDataset(self, count)
def _get_absolute_split_sizes(self, sizes):
"""
Internal method called by split to calculate absolute split sizes and to
do some error checking after calculating absolute split sizes.
"""
# call get_dataset_size here and check input here because
# dont want to call this once in check_split and another time in
# here again
dataset_size = self.get_dataset_size()
if(dataset_size is None or dataset_size <= 0):
raise RuntimeError("dataset size unknown, unable to split.")
all_int = all(isinstance(item, int) for item in sizes)
if all_int:
sizes_sum = sum(sizes)
if sizes_sum != dataset_size:
raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
.format(sizes_sum, dataset_size))
return sizes
absolute_sizes = []
for item in sizes:
absolute_size = int(round(item * dataset_size))
if absolute_size == 0:
raise RuntimeError("split percentage {} is too small.".format(item))
absolute_sizes.append(absolute_size)
absolute_sizes_sum = sum(absolute_sizes)
if absolute_sizes_sum != dataset_size:
raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
.format(absolute_sizes_sum, dataset_size))
return absolute_sizes
@check_split
def split(self, sizes, randomize=True):
"""
Splits the dataset into smaller, non-overlapping datasets.
This is a general purpose split function which can be called from any operator in the pipeline.
There is another, optimized split function, which will be called automatically if ds.split is
called where ds is a MappableDataset.
Args:
sizes (list of int or list of float): If a list of integers [s1, s2, , sn] is
provided, the dataset will be split into n datasets of size s1, size s2, , size sn
respectively. If the sum of all sizes does not equal the original dataset size, an
an error will occur.
If a list of floats [f1, f2, , fn] is provided, the dataset will be split into n
Datasets of size f1*K, f2*K, , fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
Note:
1. Dataset cannot be sharded if split is going to be called.
2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
Shuffling the dataset may not be deterministic, which means the data in each split
will be different in each epoch.
Raises:
RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
equal the dataset size.
RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
RuntimeError: If the dataset is sharded prior to calling split.
ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
floats dont sum to 1.
Returns
tuple(Dataset), a tuple of datasets that have been split.
Examples:
>>> import mindspore.dataset as ds
>>>
>>> dataset_dir = "/path/to/text_file.txt"
>>>
>>> # TextFileDataset is not a mappable dataset, so this non optimized split will be called.
>>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
>>> data = ds.TextFileDataset(dataset_dir, shuffle=False)
>>> train, test = data.split([0.9, 0.1])
"""
if self.is_shuffled():
logger.warning("dataset is shuffled before split.")
if self.is_sharded():
raise RuntimeError("dataset should not be sharded before split.")
absolute_sizes = self._get_absolute_split_sizes(sizes)
splits = []
rows_to_skip = 0
for size in absolute_sizes:
ds = copy.deepcopy(self)
if randomize:
# want to shuffle the same way every epoch before split
ds = ds.shuffle()
ds.reshuffle_each_epoch = False
if rows_to_skip > 0:
ds = ds.skip(rows_to_skip)
ds = ds.take(size)
splits.append(ds)
rows_to_skip += size
return tuple(splits)
@check_zip_dataset
def zip(self, datasets):
"""
@ -1053,10 +1164,24 @@ class Dataset:
def reset(self):
"""Reset the dataset for next epoch."""
def is_shuffled(self):
for input_dataset in self.input:
if input_dataset.is_shuffled():
return True
return False
def is_sharded(self):
for input_dataset in self.input:
if input_dataset.is_sharded():
return True
return False
class SourceDataset(Dataset):
"""
Abstract class to represent a source dataset which produces content to the data pipeline.
Abstract class to represent a source dataset which produces content to the data pipeline.
"""
# No need for __init__ since it is the same as the super's init
@ -1093,6 +1218,150 @@ class SourceDataset(Dataset):
return file_list
raise ValueError("The list of path names matching the patterns is empty.")
def is_shuffled(self):
raise NotImplementedError("SourceDataset must implement is_shuffled.")
def is_sharded(self):
raise NotImplementedError("SourceDataset must implement is_sharded.")
class MappableDataset(SourceDataset):
"""
Abstract class to represent a source dataset which supports use of samplers.
"""
def __init__(self, num_parallel_workers=None):
# check if all subclasses use this name
super().__init__(num_parallel_workers)
self.sampler = None
def add_sampler(self, new_sampler):
# note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
# after first passing through the current samplers attached to this dataset.
new_sampler.add_child(self.sampler)
self.sampler = new_sampler
def use_sampler(self, new_sampler):
"""
Will make the current dataset use the new_sampler provided.
Args:
new_sampler (Sampler): the sampler to use for the current dataset.
Returns:
Dataset, that uses new_sampler.
Examples:
>>> import mindspore.dataset as ds
>>>
>>> dataset_dir = "/path/to/imagefolder_directory"
>>> # a SequentialSampler is created by default
>>> data = ds.ImageFolderDatasetV2(dataset_dir)
>>>
>>> # use a DistributedSampler instead of the SequentialSampler
>>> new_sampler = ds.DistributedSampler(10, 2)
>>> data.use_sampler(new_sampler)
"""
self.sampler = self.sampler.child_sampler
self.add_sampler(new_sampler)
def is_shuffled(self):
raise NotImplementedError("MappableDataset must implement is_shuffled.")
def is_sharded(self):
raise NotImplementedError("MappableDataset must implement is_sharded.")
@check_split
def split(self, sizes, randomize=True):
"""
Splits the dataset into smaller, non-overlapping datasets.
There is the optimized split function, which will be called automatically when the dataset
that calls this function is a MappableDataset.
Args:
sizes (list of int or list of float): If a list of integers [s1, s2, , sn] is
provided, the dataset will be split into n datasets of size s1, size s2, , size sn
respectively. If the sum of all sizes does not equal the original dataset size, an
an error will occur.
If a list of floats [f1, f2, , fn] is provided, the dataset will be split into n
Datasets of size f1*K, f2*K, , fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
Note:
1. Dataset should not be sharded if split is going to be called. Instead, create a
DistributedSampler and specify a split to shard after splitting. If dataset is
sharded after a split, it is strongly recommended to set the same seed in each instance
of execution, otherwise each shard may not be part of the same split (see Examples)
2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
Shuffling the dataset may not be deterministic, which means the data in each split
will be different in each epoch. Furthermore, if sharding occurs after split, each
shard may not be part of the same split.
Raises:
RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
equal the dataset size.
RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
RuntimeError: If the dataset is sharded prior to calling split.
ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
floats dont sum to 1.
Returns
tuple(Dataset), a tuple of datasets that have been split.
Examples:
>>> import mindspore.dataset as ds
>>>
>>> dataset_dir = "/path/to/imagefolder_directory"
>>>
>>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
>>> data = ds.ImageFolderDatasetV2(dataset_dir, shuffle=False)
>>>
>>> # sets the seed, and tells split to use this seed when randomizing. This
>>> # is needed because we are sharding later
>>> ds.config.set_seed(58)
>>> train, test = data.split([0.9, 0.1])
>>>
>>> # if we want to shard the train dataset, we can use a DistributedSampler
>>> train_sampler = ds.DistributedSampler(10, 2)
>>> train.use_sampler(train_sampler)
"""
if self.is_shuffled():
logger.warning("dataset is shuffled before split.")
if self.is_sharded():
raise RuntimeError("dataset should not be sharded before split.")
absolute_sizes = self._get_absolute_split_sizes(sizes)
splits = []
current_split_start_index = 0
for size in absolute_sizes:
ds = copy.deepcopy(self)
if randomize:
# want to shuffle the same way every epoch before split, we are assuming
# that the user will call set_seed
random_sampler = samplers.RandomSampler()
random_sampler.reshuffle_each_epoch = False
ds.add_sampler(random_sampler)
subset_sampler = samplers.SubsetSampler(current_split_start_index, size)
ds.add_sampler(subset_sampler)
# add sequential sampler, so that if user calls use_sampler, we will
# get rid of the sequential sampler instead of something we need
ds.add_sampler(samplers.SequentialSampler())
splits.append(ds)
current_split_start_index += size
return tuple(splits)
class DatasetOp(Dataset):
"""
@ -1336,6 +1605,7 @@ class SyncWaitDataset(DatasetOp):
flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
return flag
class ShuffleDataset(DatasetOp):
"""
The result of applying Shuffle operator to the input Dataset.
@ -1352,6 +1622,7 @@ class ShuffleDataset(DatasetOp):
super().__init__()
self.buffer_size = buffer_size
self.input.append(input_dataset)
self.reshuffle_each_epoch = None
input_dataset.output.append(self)
self._input_indexs = input_dataset.input_indexs
if self.is_sync():
@ -1360,8 +1631,14 @@ class ShuffleDataset(DatasetOp):
def get_args(self):
args = super().get_args()
args["buffer_size"] = self.buffer_size
if self.reshuffle_each_epoch is not None:
args["reshuffle_each_epoch"] = self.reshuffle_each_epoch
return args
def is_shuffled(self):
return True
# Pyfunc collection for multiprocess pyfunc
# This global variable will only be used within subprocesses
@ -1991,8 +2268,14 @@ class StorageDataset(SourceDataset):
self._get_pipeline_info()
return self._num_classes
def is_shuffled(self):
return False
class RangeDataset(SourceDataset):
def is_sharded(self):
return False
class RangeDataset(MappableDataset):
"""
A source dataset that reads and parses datasets stored on disk in a range.
@ -2015,6 +2298,12 @@ class RangeDataset(SourceDataset):
args["step"] = self.step
return args
def is_shuffled(self):
return False
def is_sharded(self):
return False
def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
"""
@ -2054,7 +2343,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
return samplers.SequentialSampler()
class ImageFolderDatasetV2(SourceDataset):
class ImageFolderDatasetV2(MappableDataset):
"""
A source dataset that reads images from a tree of directories.
@ -2192,8 +2481,20 @@ class ImageFolderDatasetV2(SourceDataset):
num_samples = self.num_samples
return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1]
def is_shuffled(self):
if self.shuffle_level is None:
return True
class MnistDataset(SourceDataset):
return self.shuffle_level or self.sampler.is_shuffled()
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
class MnistDataset(MappableDataset):
"""
A source dataset for reading and parsing the Mnist dataset.
@ -2296,6 +2597,18 @@ class MnistDataset(SourceDataset):
return get_num_rows(num_rows, self.num_shards)
def is_shuffled(self):
if self.shuffle_level is None:
return True
return self.shuffle_level or self.sampler.is_shuffled()
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
class MindDataset(SourceDataset):
"""
@ -2402,6 +2715,18 @@ class MindDataset(SourceDataset):
num_rows = num_rows // self.partitions[0] + 1
return num_rows
def is_shuffled(self):
if self.global_shuffle is None:
return True
return self.global_shuffle or self.sampler.is_shuffled()
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
def _iter_fn(dataset, num_samples):
"""
@ -2611,7 +2936,7 @@ class _GeneratorWorker(multiprocessing.Process):
self.terminate()
class GeneratorDataset(SourceDataset):
class GeneratorDataset(MappableDataset):
"""
A source dataset that generate data from python by invoking python data source each epoch.
@ -2796,6 +3121,12 @@ class GeneratorDataset(SourceDataset):
return new_op
def is_shuffled(self):
return self.sampler.is_shuffled()
def is_sharded(self):
return self.sampler.is_sharded()
class TFRecordDataset(SourceDataset):
"""
@ -2922,8 +3253,17 @@ class TFRecordDataset(SourceDataset):
else:
raise ValueError('set dataset_size with negative value {}'.format(value))
def is_shuffled(self):
return self.shuffle_files
class ManifestDataset(SourceDataset):
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return False
class ManifestDataset(MappableDataset):
"""
A source dataset that reads images from a manifest file.
@ -3090,8 +3430,20 @@ class ManifestDataset(SourceDataset):
return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage)
def is_shuffled(self):
if self.shuffle_level is None:
return True
class Cifar10Dataset(SourceDataset):
return self.shuffle_level or self.sampler.is_shuffled()
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
class Cifar10Dataset(MappableDataset):
"""
A source dataset that reads cifar10 data.
@ -3199,8 +3551,20 @@ class Cifar10Dataset(SourceDataset):
return get_num_rows(num_rows, self.num_shards)
def is_shuffled(self):
if self.shuffle_level is None:
return True
class Cifar100Dataset(SourceDataset):
return self.shuffle_level or self.sampler.is_shuffled()
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
class Cifar100Dataset(MappableDataset):
"""
A source dataset that reads cifar100 data.
@ -3306,6 +3670,18 @@ class Cifar100Dataset(SourceDataset):
return get_num_rows(num_rows, self.num_shards)
def is_shuffled(self):
if self.shuffle_level is None:
return True
return self.shuffle_level or self.sampler.is_shuffled()
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
class RandomDataset(SourceDataset):
"""
@ -3357,6 +3733,11 @@ class RandomDataset(SourceDataset):
"""
return num_samples
def is_shuffled(self):
return True
def is_sharded(self):
return False
class Schema:
"""
@ -3536,7 +3917,7 @@ class Schema:
return self.to_json()
class VOCDataset(SourceDataset):
class VOCDataset(MappableDataset):
"""
A source dataset for reading and parsing VOC dataset.
@ -3683,8 +4064,20 @@ class VOCDataset(SourceDataset):
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
def is_shuffled(self):
if self.shuffle_level is None:
return True
class CelebADataset(SourceDataset):
return self.shuffle_level or self.sampler.is_shuffled()
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
class CelebADataset(MappableDataset):
"""
A source dataset for reading and parsing CelebA dataset.Only support list_attr_celeba.txt currently.
@ -3737,6 +4130,18 @@ class CelebADataset(SourceDataset):
args["shard_id"] = self.shard_id
return args
def is_shuffled(self):
if self.shuffle_level is None:
return True
return self.shuffle_level or self.sampler.is_shuffled()
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
class TextFileDataset(SourceDataset):
"""
@ -3816,3 +4221,12 @@ class TextFileDataset(SourceDataset):
return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
def is_shuffled(self):
return self.shuffle_files
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return False

View File

@ -47,6 +47,7 @@ class Sampler:
def __init__(self):
self.dataset_size = 0
self.num_samples = 0
self.child_sampler = None
def __iter__(self):
"""
@ -83,7 +84,35 @@ class Sampler:
# Instance fetcher
# Do not override this method!
def create(self):
return cde.PythonSampler(self)
c_sampler = cde.PythonSampler(self)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def add_child(self, sampler):
self.child_sampler = sampler
def get_child(self):
return self.child_sampler
def create_child(self):
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.create()
return c_child_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
class BuiltinSampler:
@ -93,11 +122,30 @@ class BuiltinSampler:
User should not extend this class.
"""
def __init__(self):
pass
self.child_sampler = None
def create(self):
pass
def add_child(self, sampler):
self.child_sampler = sampler
def get_child(self):
return self.child_sampler
def create_child(self):
c_child_sampler = None
if self.child_sampler is not None:
c_child_sampler = self.child_sampler.create()
return c_child_sampler
def is_shuffled(self):
raise NotImplementedError("Sampler must implement is_shuffled.")
def is_sharded(self):
raise NotImplementedError("Sampler must implement is_sharded.")
class DistributedSampler(BuiltinSampler):
"""
@ -142,7 +190,22 @@ class DistributedSampler(BuiltinSampler):
def create(self):
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
self.seed += 1
return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return self.shuffle
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return self.num_shards > 1
return self.child_sampler.is_sharded()
class PKSampler(BuiltinSampler):
@ -186,7 +249,22 @@ class PKSampler(BuiltinSampler):
super().__init__()
def create(self):
return cde.PKSampler(self.num_val, self.shuffle)
c_sampler = cde.PKSampler(self.num_val, self.shuffle)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return self.shuffle
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
def _create_for_minddataset(self):
if not self.class_column or not isinstance(self.class_column, str):
@ -226,15 +304,31 @@ class RandomSampler(BuiltinSampler):
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
self.deterministic = False
self.replacement = replacement
self.num_samples = num_samples
self.reshuffle_each_epoch = True
super().__init__()
def create(self):
# If num_samples is not specified, then call constructor #2
c_sampler = None
if self.num_samples is None:
return cde.RandomSampler(self.replacement)
return cde.RandomSampler(self.replacement, self.num_samples)
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch)
else:
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
class SequentialSampler(BuiltinSampler):
@ -252,7 +346,80 @@ class SequentialSampler(BuiltinSampler):
"""
def create(self):
return cde.SequentialSampler()
c_sampler = cde.SequentialSampler()
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
class SubsetSampler(BuiltinSampler):
"""
Samples a subset of elements consecutively from a given index.
Args:
start_index (int): Index to start sampling at.
subset_size (int): How many samples to include in this subset.
Examples:
>>> import mindspore.dataset as ds
>>>
>>> dataset_dir = "path/to/imagefolder_directory"
>>>
>>> # creates a SubsetSampler, will sample the next 5 images from the 100th image.
>>> sampler = ds.SubsetSampler(100, 5)
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
Raises:
ValueError: If start_index is not a positive int.
ValueError: If subset_size is not a positive int.
"""
def __init__(self, start_index, subset_size):
if not isinstance(start_index, int):
raise ValueError("start_index should be an int.")
if start_index < 0:
raise ValueError("start_index should not be negative.")
if not isinstance(subset_size, int):
raise ValueError("start_index should be an int")
if subset_size < 0:
raise ValueError("subset_size should not be negative.")
self.start_index = start_index
self.subset_size = subset_size
super().__init__()
def create(self):
c_sampler = cde.SubsetSampler(self.start_index, self.subset_size)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
class SubsetRandomSampler(BuiltinSampler):
@ -282,7 +449,19 @@ class SubsetRandomSampler(BuiltinSampler):
super().__init__()
def create(self):
return cde.SubsetRandomSampler(self.indices)
c_sampler = cde.SubsetRandomSampler(self.indices)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
def _create_for_minddataset(self):
return cde.MindrecordSubsetRandomSampler(self.indices)
@ -330,4 +509,16 @@ class WeightedRandomSampler(BuiltinSampler):
super().__init__()
def create(self):
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
return True
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()

View File

@ -1031,3 +1031,44 @@ def check_textfiledataset(method):
return method(*args, **kwargs)
return new_method
def check_split(method):
"""check the input arguments of split."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
nreq_param_list = ['sizes']
nreq_param_bool = ['randomize']
check_param_type(nreq_param_list, param_dict, list)
check_param_type(nreq_param_bool, param_dict, bool)
# check sizes: must be list of float or list of int
sizes = param_dict.get('sizes')
if not sizes:
raise ValueError("sizes cannot be empty.")
all_int = all(isinstance(item, int) for item in sizes)
all_float = all(isinstance(item, float) for item in sizes)
if not (all_int or all_float):
raise ValueError("sizes should be list of int or list of float.")
if all_int:
all_positive = all(item > 0 for item in sizes)
if not all_positive:
raise ValueError("sizes is a list of int, but there should be no negative numbers.")
if all_float:
all_valid_percentages = all(0 < item <= 1 for item in sizes)
if not all_valid_percentages:
raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].")
epsilon = 0.00001
if not abs(sum(sizes) - 1) < epsilon:
raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
return method(*args, **kwargs)
return new_method

View File

@ -92,7 +92,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
uint32_t original_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
tree->Prepare();

View File

@ -138,7 +138,7 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
int32_t original_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});

View File

@ -164,9 +164,36 @@ def test_python_sampler():
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
def test_sampler_chain():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_shards, shard_id):
sampler = ds.DistributedSampler(num_shards, shard_id, False)
child_sampler = ds.SequentialSampler()
sampler.add_child(child_sampler)
data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)
res = []
for item in data1.create_dict_iterator():
logger.info("item[image].shape[0]: {}, item[label].item(): {}"
.format(item["image"].shape[0], item["label"].item()))
res.append(map[(item["image"].shape[0], item["label"].item())])
return res
assert test_config(2, 0) == [0, 2, 4]
assert test_config(2, 1) == [1, 3, 0]
assert test_config(5, 0) == [0]
assert test_config(5, 1) == [1]
assert test_config(5, 2) == [2]
assert test_config(5, 3) == [3]
assert test_config(5, 4) == [4]
if __name__ == '__main__':
test_sequential_sampler(True)
test_random_sampler(True)
test_random_sampler_multi_iter(True)
test_sampler_py_api()
test_python_sampler()
test_sampler_chain()

View File

@ -0,0 +1,342 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
# the label of each image is [0,0,0,1,1] each image can be uniquely identified
# via the following lookup table (dict){(83554, 0): 0, (54214, 0): 1, (54214, 1): 2, (65512, 0): 3, (64631, 1): 4}
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def split_with_invalid_inputs(d):
with pytest.raises(ValueError) as info:
s1, s2 = d.split([])
assert "sizes cannot be empty" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([5, 0.6])
assert "sizes should be list of int or list of float" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([-1, 6])
assert "there should be no negative numbers" in str(info.value)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([3, 1])
assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([5, 1])
assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([-0.5, 0.5])
assert "there should be no numbers outside the range [0, 1]" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([1.5, 0.5])
assert "there should be no numbers outside the range [0, 1]" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([0.5, 0.6])
assert "percentages do not sum up to 1" in str(info.value)
with pytest.raises(ValueError) as info:
s1, s2 = d.split([0.3, 0.6])
assert "percentages do not sum up to 1" in str(info.value)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([0.05, 0.95])
assert "percentage 0.05 is too small" in str(info.value)
def test_unmappable_invalid_input():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
d = ds.TextFileDataset(text_file_dataset_path)
split_with_invalid_inputs(d)
d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
def test_unmappable_split():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
ds.config.set_num_parallel_workers(4)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s1_output == text_file_data[0:4]
assert s2_output == text_file_data[4:]
# exact percentages
s1, s2 = d.split([0.8, 0.2], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s1_output == text_file_data[0:4]
assert s2_output == text_file_data[4:]
# fuzzy percentages
s1, s2 = d.split([0.33, 0.67], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s1_output == text_file_data[0:2]
assert s2_output == text_file_data[2:]
def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file)
split_with_invalid_inputs(d)
d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info:
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
def test_mappable_split_general():
d = ds.ManifestDataset(manifest_file, shuffle=False)
d = d.take(5)
# absolute rows
s1, s2 = d.split([4, 1], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1, 2, 3]
assert s2_output == [4]
# exact percentages
s1, s2 = d.split([0.8, 0.2], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1, 2, 3]
assert s2_output == [4]
# fuzzy percentages
s1, s2 = d.split([0.33, 0.67], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]
def test_mappable_split_optimized():
d = ds.ManifestDataset(manifest_file, shuffle=False)
# absolute rows
s1, s2 = d.split([4, 1], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1, 2, 3]
assert s2_output == [4]
# exact percentages
s1, s2 = d.split([0.8, 0.2], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1, 2, 3]
assert s2_output == [4]
# fuzzy percentages
s1, s2 = d.split([0.33, 0.67], randomize=False)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]
def test_mappable_randomize_deterministic():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
for _ in range(10):
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
# note no overlap
assert s1_output == [0, 1, 3, 4]
assert s2_output == [2]
def test_mappable_randomize_repeatable():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
num_epochs = 5
s1 = s1.repeat(num_epochs)
s2 = s2.repeat(num_epochs)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
# note no overlap
assert s1_output == [0, 1, 3, 4] * num_epochs
assert s2_output == [2] * num_epochs
def test_mappable_sharding():
# set arbitrary seed for repeatability for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
ds.config.set_seed(53)
num_epochs = 5
first_split_num_rows = 4
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([first_split_num_rows, 1])
distributed_sampler = ds.DistributedSampler(2, 0)
s1.use_sampler(distributed_sampler)
s1 = s1.repeat(num_epochs)
# testing sharding, second dataset to simulate another instance
d2 = ds.ManifestDataset(manifest_file, shuffle=False)
d2s1, d2s2 = d2.split([first_split_num_rows, 1])
distributed_sampler = ds.DistributedSampler(2, 1)
d2s1.use_sampler(distributed_sampler)
d2s1 = d2s1.repeat(num_epochs)
# shard 0
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
# shard 1
d2s1_output = []
for item in d2s1.create_dict_iterator():
d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
rows_per_shard_per_epoch = 2
assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs
# verify each epoch that
# 1. shards contain no common elements
# 2. the data was split the same way, and that the union of shards equal the split
correct_sorted_split_result = [0, 1, 3, 4]
for i in range(num_epochs):
combined_data = []
for j in range(rows_per_shard_per_epoch):
combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])
assert sorted(combined_data) == correct_sorted_split_result
# test other split
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
d2s2_output = []
for item in d2s2.create_dict_iterator():
d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s2_output == [2]
assert d2s2_output == [2]
if __name__ == '__main__':
test_unmappable_invalid_input()
test_unmappable_split()
test_mappable_invalid_input()
test_mappable_split_general()
test_mappable_split_optimized()
test_mappable_randomize_deterministic()
test_mappable_randomize_repeatable()
test_mappable_sharding()