!1281 Implementation of SplitOp
Merge pull request !1281 from Peilin/splitOp
This commit is contained in:
commit
2e3d55ed87
|
@ -364,6 +364,18 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
|
|||
std::string err_msg = "Error: Shuffle buffer size is missing";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
// Optional arguments
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "reshuffle_each_epoch") {
|
||||
(void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ShuffleOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*ptr = op;
|
||||
|
|
|
@ -51,6 +51,7 @@
|
|||
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
|
@ -425,11 +426,14 @@ void bindSamplerOps(py::module *m) {
|
|||
.def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
|
||||
.def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
|
||||
.def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
|
||||
.def("get_indices", [](Sampler &self) {
|
||||
.def("get_indices",
|
||||
[](Sampler &self) {
|
||||
py::array ret;
|
||||
THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
|
||||
return ret;
|
||||
});
|
||||
})
|
||||
.def("add_child",
|
||||
[](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
|
||||
|
||||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
|
||||
|
||||
|
@ -441,12 +445,16 @@ void bindSamplerOps(py::module *m) {
|
|||
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
|
||||
|
||||
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
|
||||
.def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples"))
|
||||
.def(py::init<bool>(), py::arg("replacement"));
|
||||
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
|
||||
py::arg("num_samples"))
|
||||
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
|
||||
|
||||
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
|
||||
.def(py::init<>());
|
||||
|
||||
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
|
||||
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
|
||||
|
||||
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
|
||||
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
|
||||
|
||||
|
|
|
@ -8,5 +8,6 @@ add_library(engine-datasetops-source-sampler OBJECT
|
|||
sampler.cc
|
||||
sequential_sampler.cc
|
||||
subset_random_sampler.cc
|
||||
subset_sampler.cc
|
||||
weighted_random_sampler.cc
|
||||
)
|
||||
|
|
|
@ -55,13 +55,27 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
|
|||
} else if (cnt_ == samples_per_buffer_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
|
||||
int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
|
||||
while (cnt_ < samples_per_buffer_) {
|
||||
int64_t next_id = (num_devices_ * (cnt_++) + device_id_) % num_rows_;
|
||||
*(id_ptr++) = shuffle_ ? shuffle_vec_[static_cast<size_t>(next_id)] : next_id;
|
||||
int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
|
||||
if (shuffle_) {
|
||||
sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
|
||||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*id_ptr = sampled_id;
|
||||
id_ptr++;
|
||||
cnt_++;
|
||||
}
|
||||
TensorRow row(1, sample_ids);
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
|
||||
|
@ -72,11 +86,29 @@ Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer
|
|||
Status DistributedSampler::Reset() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
|
||||
cnt_ = 0;
|
||||
rnd_.seed(seed_++);
|
||||
|
||||
if (shuffle_ == true) {
|
||||
rnd_.seed(seed_);
|
||||
seed_++;
|
||||
std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
|
||||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DistributedSampler::Print(std::ostream &out, bool show_all) const {
|
||||
out << "(sampler): DistributedSampler\n";
|
||||
if (show_all) {
|
||||
out << "seed_: " << seed_ << '\n';
|
||||
out << "device_id_: " << device_id_ << '\n';
|
||||
out << "num_devices_: " << num_devices_ << '\n';
|
||||
out << "shuffle_: " << shuffle_ << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,6 +48,8 @@ class DistributedSampler : public Sampler {
|
|||
// @return - The error code return
|
||||
Status Reset() override;
|
||||
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
private:
|
||||
int64_t cnt_; // number of samples that have already been filled in to buffer
|
||||
uint32_t seed_;
|
||||
|
|
|
@ -38,6 +38,7 @@ Status PKSampler::InitSampler() {
|
|||
rnd_.seed(seed_++);
|
||||
num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
|
||||
samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
|
||||
num_samples_ = num_pk_samples_;
|
||||
if (shuffle_ == true) {
|
||||
std::shuffle(labels_.begin(), labels_.end(), rnd_);
|
||||
} else {
|
||||
|
@ -53,6 +54,10 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
} else if (next_id_ == num_pk_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
int64_t last_id =
|
||||
|
@ -63,8 +68,16 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
int64_t cls_id = next_id_++ / samples_per_class_;
|
||||
const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]];
|
||||
int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_);
|
||||
*(id_ptr++) = samples[rnd_ind];
|
||||
int64_t sampled_id = samples[rnd_ind];
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*id_ptr = sampled_id;
|
||||
id_ptr++;
|
||||
}
|
||||
|
||||
TensorRow row(1, sample_ids);
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
|
||||
}
|
||||
|
@ -75,6 +88,11 @@ Status PKSampler::Reset() {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
|
||||
next_id_ = 0;
|
||||
rnd_.seed(seed_++);
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,10 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
if (need_to_reset_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
{
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
|
@ -38,6 +42,14 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
py::object py_ret = py_sampler_instance.attr("_get_indices")();
|
||||
py::array np_sample_ids = py_ret.cast<py::array>();
|
||||
Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor
|
||||
|
||||
if (HasChildSampler()) {
|
||||
for (auto it = sample_ids->begin<int64_t>(); it != sample_ids->end<int64_t>(); ++it) {
|
||||
int64_t associated_child_id = 0;
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, associated_child_id));
|
||||
*it = associated_child_id;
|
||||
}
|
||||
}
|
||||
} catch (const py::error_already_set &e) {
|
||||
return Status(StatusCode::kPyFuncException, e.what());
|
||||
} catch (const py::cast_error &e) {
|
||||
|
@ -79,6 +91,11 @@ Status PythonSampler::Reset() {
|
|||
} catch (const py::error_already_set &e) {
|
||||
return Status(StatusCode::kPyFuncException, e.what());
|
||||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -14,18 +14,22 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include "dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
RandomSampler::RandomSampler(bool replacement, int64_t num_samples, int64_t samples_per_buffer)
|
||||
RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
|
||||
int64_t samples_per_buffer)
|
||||
: Sampler(samples_per_buffer),
|
||||
seed_(GetSeed()),
|
||||
replacement_(replacement),
|
||||
user_num_samples_(num_samples),
|
||||
next_id_(0),
|
||||
reshuffle_each_epoch_(reshuffle_each_epoch),
|
||||
dist(nullptr) {}
|
||||
|
||||
Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
|
@ -34,13 +38,29 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
} else if (next_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
|
||||
|
||||
std::shared_ptr<Tensor> sampleIds;
|
||||
int64_t last_id = samples_per_buffer_ + next_id_ > num_samples_ ? num_samples_ : samples_per_buffer_ + next_id_;
|
||||
int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_);
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_));
|
||||
int64_t *id_ptr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
|
||||
|
||||
for (int64_t i = 0; i < (last_id - next_id_); i++) {
|
||||
*(id_ptr + i) = replacement_ ? (*dist)(rnd_) : shuffled_ids_[static_cast<size_t>(i + next_id_)];
|
||||
int64_t sampled_id = 0;
|
||||
if (replacement_) {
|
||||
sampled_id = (*dist)(rnd_);
|
||||
} else {
|
||||
sampled_id = shuffled_ids_[static_cast<size_t>(i + next_id_)];
|
||||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*(id_ptr + i) = sampled_id;
|
||||
}
|
||||
next_id_ = last_id;
|
||||
TensorRow row(1, sampleIds);
|
||||
|
@ -53,7 +73,9 @@ Status RandomSampler::InitSampler() {
|
|||
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
rnd_.seed(seed_++);
|
||||
|
||||
rnd_.seed(seed_);
|
||||
|
||||
if (replacement_ == false) {
|
||||
shuffled_ids_.reserve(num_rows_);
|
||||
for (int64_t i = 0; i < num_rows_; i++) {
|
||||
|
@ -69,11 +91,33 @@ Status RandomSampler::InitSampler() {
|
|||
Status RandomSampler::Reset() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
|
||||
next_id_ = 0;
|
||||
rnd_.seed(seed_++);
|
||||
if (replacement_ == false) {
|
||||
|
||||
if (reshuffle_each_epoch_) {
|
||||
seed_++;
|
||||
}
|
||||
|
||||
rnd_.seed(seed_);
|
||||
|
||||
if (replacement_ == false && reshuffle_each_epoch_) {
|
||||
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
|
||||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RandomSampler::Print(std::ostream &out, bool show_all) const {
|
||||
out << "(sampler): RandomSampler\n";
|
||||
|
||||
if (show_all) {
|
||||
out << "user_num_samples_: " << user_num_samples_ << '\n';
|
||||
out << "num_samples_: " << num_samples_ << '\n';
|
||||
out << "next_id_: " << next_id_ << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,7 +30,8 @@ class RandomSampler : public Sampler {
|
|||
// @param bool replacement - put he id back / or not after a sample
|
||||
// @param int64_t numSamples - number samples to draw
|
||||
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit RandomSampler(bool replacement = false, int64_t num_samples = std::numeric_limits<int64_t>::max(),
|
||||
explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
|
||||
int64_t num_samples = std::numeric_limits<int64_t>::max(),
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
|
@ -49,6 +50,8 @@ class RandomSampler : public Sampler {
|
|||
// @return - The error code return
|
||||
Status Reset() override;
|
||||
|
||||
virtual void Print(std::ostream &out, bool show_all) const;
|
||||
|
||||
private:
|
||||
uint32_t seed_;
|
||||
bool replacement_;
|
||||
|
@ -57,6 +60,7 @@ class RandomSampler : public Sampler {
|
|||
int64_t next_id_;
|
||||
std::mt19937 rnd_;
|
||||
std::unique_ptr<std::uniform_int_distribution<int64_t>> dist;
|
||||
bool reshuffle_each_epoch_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,18 +15,41 @@
|
|||
*/
|
||||
#include "dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Sampler::Sampler(int64_t samples_per_buffer)
|
||||
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
|
||||
|
||||
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
||||
std::shared_ptr<Sampler> child_sampler;
|
||||
if (HasChildSampler()) {
|
||||
child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
|
||||
if (!child_sampler) {
|
||||
std::string err_msg("Cannot handshake, child is not a sampler object.");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
// Handshake and init child first.
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
|
||||
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
|
||||
if (HasChildSampler()) {
|
||||
int64_t child_num_samples = child_sampler->num_samples();
|
||||
num_rows_ = child_num_samples;
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
|
||||
}
|
||||
|
||||
// It's up to the derived class to check the validity of the two args
|
||||
// Because some sampler only needs one of the arg (weighted_random_sampler)
|
||||
RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -44,6 +67,15 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void Sampler::Print(std::ostream &out, bool show_all) const {
|
||||
out << "(sampler): base\n";
|
||||
|
||||
if (show_all) {
|
||||
out << "num_rows_: " << num_rows_ << '\n';
|
||||
out << "num_samples_: " << num_samples_ << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
Status Sampler::GetAllIdsThenReset(py::array *data) {
|
||||
std::unique_ptr<DataBuffer> db;
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
|
@ -84,5 +116,45 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
|
|||
num_rows_ = num_rows;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
|
||||
if (child == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Only samplers can be added, not any other DatasetOp.
|
||||
std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
|
||||
if (!sampler) {
|
||||
std::string err_msg("Cannot add child, child is not a sampler object.");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
// Samplers can have at most 1 child.
|
||||
if (!child_.empty()) {
|
||||
std::string err_msg("Cannot add child sampler, this sampler already has a child.");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
child_.push_back(child);
|
||||
|
||||
// doesn't work, protected?
|
||||
// child->AddParent(this);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool Sampler::HasChildSampler() { return !child_.empty(); }
|
||||
|
||||
Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
|
||||
if (child_ids_ == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
|
||||
}
|
||||
|
||||
TensorRow sample_row;
|
||||
RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row));
|
||||
std::shared_ptr<Tensor> sample_ids = sample_row[0];
|
||||
RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -90,6 +90,8 @@ class Sampler : public DatasetOp {
|
|||
// setter function for num_samples_
|
||||
Status SetNumSamples(int64_t num_samples);
|
||||
|
||||
int64_t num_samples() { return num_samples_; }
|
||||
|
||||
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
|
||||
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
|
||||
// @return
|
||||
|
@ -114,17 +116,48 @@ class Sampler : public DatasetOp {
|
|||
// @return - The error code return
|
||||
Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); }
|
||||
|
||||
// Adds a sampler to become our child.
|
||||
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
|
||||
// @return - The error code returned.
|
||||
Status AddChild(std::shared_ptr<DatasetOp> child);
|
||||
|
||||
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
|
||||
// @param std::shared_ptr<Tensor>* sampleIds
|
||||
// @param int64_t numElements - must be a non 0 number
|
||||
// @return
|
||||
// @return - The error code returned.
|
||||
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
|
||||
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
|
||||
sampler.Print(out, false);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Checks if this sampler has a child sampler.
|
||||
// @return - tre if there is a child sampler, false otherwise.
|
||||
bool HasChildSampler();
|
||||
|
||||
// Uses id as an index for the list of ids generated by the child sampler, and gets the
|
||||
// associated id.
|
||||
// @param int64_t* out_associated_id - Out parameter, contains the associated id.
|
||||
// @param int64_t id - The id used as an index to get the associated child id.
|
||||
// @return - The error code returned.
|
||||
Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
|
||||
|
||||
protected:
|
||||
// Number of rows of data from the place this sampler is sampling from. If this sampler
|
||||
// has a child sampler, num_rows_ is the number of ids the child sampler will
|
||||
// output. Otherwise, num_rows_ is the number of rows in the dataset.
|
||||
int64_t num_rows_;
|
||||
|
||||
// Number of ids this sampler will return.
|
||||
int64_t num_samples_;
|
||||
|
||||
// The max number of ids a DataBuffer returned by this sampler will contain.
|
||||
int64_t samples_per_buffer_;
|
||||
std::unique_ptr<ColDescriptor> col_desc_;
|
||||
std::unique_ptr<DataBuffer> child_ids_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -27,14 +28,26 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
|
|||
} else if (next_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> sampleIds;
|
||||
int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
|
||||
int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
|
||||
while (next_id_ < lastId) {
|
||||
*(idPtr++) = next_id_++;
|
||||
int64_t sampled_id = next_id_;
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*idPtr = sampled_id;
|
||||
next_id_++;
|
||||
idPtr++;
|
||||
}
|
||||
|
||||
TensorRow row(1, sampleIds);
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
|
||||
}
|
||||
|
@ -43,6 +56,10 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
|
|||
|
||||
Status SequentialSampler::InitSampler() {
|
||||
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set
|
||||
if (HasChildSampler()) {
|
||||
num_samples_ = std::min(num_samples_, num_rows_);
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
return Status::OK();
|
||||
|
@ -51,7 +68,15 @@ Status SequentialSampler::InitSampler() {
|
|||
Status SequentialSampler::Reset() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
|
||||
next_id_ = 0;
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; }
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -45,6 +45,8 @@ class SequentialSampler : public Sampler {
|
|||
// @return - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
private:
|
||||
int64_t next_id_;
|
||||
};
|
||||
|
|
|
@ -34,6 +34,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
|
|||
Status SubsetRandomSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
|
||||
|
||||
num_samples_ = indices_.size();
|
||||
|
||||
// Initialize random generator with seed from config manager
|
||||
rand_gen_.seed(GetSeed());
|
||||
|
||||
|
@ -56,6 +58,10 @@ Status SubsetRandomSampler::Reset() {
|
|||
rand_gen_.seed(GetSeed());
|
||||
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -65,6 +71,10 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
|
|||
if (sample_id_ == indices_.size()) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> outputIds;
|
||||
|
||||
|
@ -87,7 +97,14 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
*(id_ptr++) = indices_[sample_id_++];
|
||||
int64_t sampled_id = indices_[sample_id_];
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*id_ptr = sampled_id;
|
||||
id_ptr++;
|
||||
sample_id_++;
|
||||
}
|
||||
|
||||
// Create a TensorTable from that single tensor and push into DataBuffer
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor.
|
||||
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
|
||||
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
|
||||
|
||||
Status SubsetSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
|
||||
|
||||
num_samples_ = subset_size_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SubsetSampler::Reset() {
|
||||
current_id_ = 0;
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (current_id_ > subset_size_) {
|
||||
RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
|
||||
} else if (current_id_ == subset_size_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> sampled_ids;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));
|
||||
|
||||
int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());
|
||||
|
||||
while (current_id_ < subset_size_) {
|
||||
int64_t sampled_id = start_index_ + current_id_;
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*(sampled_ids_start_addr + current_id_) = sampled_id;
|
||||
current_id_++;
|
||||
}
|
||||
|
||||
TensorRow sampled_ids_row(1, sampled_ids);
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class SubsetSampler : public Sampler {
|
||||
public:
|
||||
// Constructor.
|
||||
// @param start_index The index we start sampling from.
|
||||
explicit SubsetSampler(int64_t start_index, int64_t subset_size);
|
||||
|
||||
// Destructor.
|
||||
~SubsetSampler() = default;
|
||||
|
||||
// Initialize the sampler.
|
||||
// @return Status
|
||||
Status InitSampler() override;
|
||||
|
||||
// Reset the internal variable to the initial state and reshuffle the indices.
|
||||
// @return Status
|
||||
Status Reset() override;
|
||||
|
||||
// Get the sample ids.
|
||||
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
|
||||
// @note the sample ids (int64_t) will be placed in one Tensor.
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
private:
|
||||
int64_t start_index_;
|
||||
int64_t subset_size_;
|
||||
int64_t current_id_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
|
@ -40,6 +40,8 @@ WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights,
|
|||
Status WeightedRandomSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
|
||||
num_samples_ = user_num_samples_;
|
||||
|
||||
// Initialize random generator with seed from config manager
|
||||
rand_gen_.seed(GetSeed());
|
||||
|
||||
|
@ -81,6 +83,11 @@ Status WeightedRandomSampler::Reset() {
|
|||
} else {
|
||||
discrete_dist_->reset();
|
||||
}
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -98,6 +105,10 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
|
|||
if (sample_id_ == user_num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> outputIds;
|
||||
|
||||
|
@ -127,7 +138,12 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
|
|||
RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound).");
|
||||
}
|
||||
|
||||
*(id_ptr++) = genId;
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
|
||||
}
|
||||
|
||||
*id_ptr = genId;
|
||||
id_ptr++;
|
||||
sample_id_++;
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ from .validators import check, check_batch, check_shuffle, check_map, check_filt
|
|||
check_rename, \
|
||||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
|
||||
check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat
|
||||
check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, check_split
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
||||
try:
|
||||
|
@ -581,6 +581,117 @@ class Dataset:
|
|||
return self
|
||||
return TakeDataset(self, count)
|
||||
|
||||
def _get_absolute_split_sizes(self, sizes):
|
||||
"""
|
||||
Internal method called by split to calculate absolute split sizes and to
|
||||
do some error checking after calculating absolute split sizes.
|
||||
"""
|
||||
# call get_dataset_size here and check input here because
|
||||
# dont want to call this once in check_split and another time in
|
||||
# here again
|
||||
dataset_size = self.get_dataset_size()
|
||||
|
||||
if(dataset_size is None or dataset_size <= 0):
|
||||
raise RuntimeError("dataset size unknown, unable to split.")
|
||||
|
||||
all_int = all(isinstance(item, int) for item in sizes)
|
||||
if all_int:
|
||||
sizes_sum = sum(sizes)
|
||||
if sizes_sum != dataset_size:
|
||||
raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
|
||||
.format(sizes_sum, dataset_size))
|
||||
return sizes
|
||||
|
||||
absolute_sizes = []
|
||||
for item in sizes:
|
||||
absolute_size = int(round(item * dataset_size))
|
||||
if absolute_size == 0:
|
||||
raise RuntimeError("split percentage {} is too small.".format(item))
|
||||
absolute_sizes.append(absolute_size)
|
||||
|
||||
absolute_sizes_sum = sum(absolute_sizes)
|
||||
if absolute_sizes_sum != dataset_size:
|
||||
raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
|
||||
.format(absolute_sizes_sum, dataset_size))
|
||||
|
||||
return absolute_sizes
|
||||
|
||||
@check_split
|
||||
def split(self, sizes, randomize=True):
|
||||
"""
|
||||
Splits the dataset into smaller, non-overlapping datasets.
|
||||
|
||||
This is a general purpose split function which can be called from any operator in the pipeline.
|
||||
There is another, optimized split function, which will be called automatically if ds.split is
|
||||
called where ds is a MappableDataset.
|
||||
|
||||
Args:
|
||||
sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is
|
||||
provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
|
||||
respectively. If the sum of all sizes does not equal the original dataset size, an
|
||||
an error will occur.
|
||||
If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
|
||||
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
|
||||
of the original dataset. If after rounding, any size equals 0, an error will occur.
|
||||
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
|
||||
randomize (bool): determines whether or not to split the data randomly. If true, the data
|
||||
will be randomly split. Otherwise, each split will be created with consecutive rows
|
||||
from the dataset.
|
||||
|
||||
Note:
|
||||
1. Dataset cannot be sharded if split is going to be called.
|
||||
2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
|
||||
Shuffling the dataset may not be deterministic, which means the data in each split
|
||||
will be different in each epoch.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
|
||||
RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
|
||||
equal the dataset size.
|
||||
RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
|
||||
RuntimeError: If the dataset is sharded prior to calling split.
|
||||
ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
|
||||
floats don’t sum to 1.
|
||||
|
||||
Returns
|
||||
tuple(Dataset), a tuple of datasets that have been split.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
>>> dataset_dir = "/path/to/text_file.txt"
|
||||
>>>
|
||||
>>> # TextFileDataset is not a mappable dataset, so this non optimized split will be called.
|
||||
>>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
|
||||
>>> data = ds.TextFileDataset(dataset_dir, shuffle=False)
|
||||
>>> train, test = data.split([0.9, 0.1])
|
||||
"""
|
||||
if self.is_shuffled():
|
||||
logger.warning("dataset is shuffled before split.")
|
||||
|
||||
if self.is_sharded():
|
||||
raise RuntimeError("dataset should not be sharded before split.")
|
||||
|
||||
absolute_sizes = self._get_absolute_split_sizes(sizes)
|
||||
splits = []
|
||||
rows_to_skip = 0
|
||||
for size in absolute_sizes:
|
||||
ds = copy.deepcopy(self)
|
||||
if randomize:
|
||||
# want to shuffle the same way every epoch before split
|
||||
ds = ds.shuffle()
|
||||
ds.reshuffle_each_epoch = False
|
||||
|
||||
if rows_to_skip > 0:
|
||||
ds = ds.skip(rows_to_skip)
|
||||
|
||||
ds = ds.take(size)
|
||||
splits.append(ds)
|
||||
|
||||
rows_to_skip += size
|
||||
|
||||
return tuple(splits)
|
||||
|
||||
@check_zip_dataset
|
||||
def zip(self, datasets):
|
||||
"""
|
||||
|
@ -1053,6 +1164,20 @@ class Dataset:
|
|||
def reset(self):
|
||||
"""Reset the dataset for next epoch."""
|
||||
|
||||
def is_shuffled(self):
|
||||
for input_dataset in self.input:
|
||||
if input_dataset.is_shuffled():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def is_sharded(self):
|
||||
for input_dataset in self.input:
|
||||
if input_dataset.is_sharded():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class SourceDataset(Dataset):
|
||||
"""
|
||||
|
@ -1093,6 +1218,150 @@ class SourceDataset(Dataset):
|
|||
return file_list
|
||||
raise ValueError("The list of path names matching the patterns is empty.")
|
||||
|
||||
def is_shuffled(self):
|
||||
raise NotImplementedError("SourceDataset must implement is_shuffled.")
|
||||
|
||||
def is_sharded(self):
|
||||
raise NotImplementedError("SourceDataset must implement is_sharded.")
|
||||
|
||||
class MappableDataset(SourceDataset):
|
||||
"""
|
||||
Abstract class to represent a source dataset which supports use of samplers.
|
||||
"""
|
||||
|
||||
def __init__(self, num_parallel_workers=None):
|
||||
# check if all subclasses use this name
|
||||
super().__init__(num_parallel_workers)
|
||||
self.sampler = None
|
||||
|
||||
def add_sampler(self, new_sampler):
|
||||
# note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
|
||||
# after first passing through the current samplers attached to this dataset.
|
||||
new_sampler.add_child(self.sampler)
|
||||
self.sampler = new_sampler
|
||||
|
||||
def use_sampler(self, new_sampler):
|
||||
"""
|
||||
Will make the current dataset use the new_sampler provided.
|
||||
|
||||
Args:
|
||||
new_sampler (Sampler): the sampler to use for the current dataset.
|
||||
|
||||
Returns:
|
||||
Dataset, that uses new_sampler.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
>>> dataset_dir = "/path/to/imagefolder_directory"
|
||||
>>> # a SequentialSampler is created by default
|
||||
>>> data = ds.ImageFolderDatasetV2(dataset_dir)
|
||||
>>>
|
||||
>>> # use a DistributedSampler instead of the SequentialSampler
|
||||
>>> new_sampler = ds.DistributedSampler(10, 2)
|
||||
>>> data.use_sampler(new_sampler)
|
||||
"""
|
||||
self.sampler = self.sampler.child_sampler
|
||||
self.add_sampler(new_sampler)
|
||||
|
||||
def is_shuffled(self):
|
||||
raise NotImplementedError("MappableDataset must implement is_shuffled.")
|
||||
|
||||
def is_sharded(self):
|
||||
raise NotImplementedError("MappableDataset must implement is_sharded.")
|
||||
|
||||
|
||||
@check_split
|
||||
def split(self, sizes, randomize=True):
|
||||
"""
|
||||
Splits the dataset into smaller, non-overlapping datasets.
|
||||
|
||||
There is the optimized split function, which will be called automatically when the dataset
|
||||
that calls this function is a MappableDataset.
|
||||
|
||||
Args:
|
||||
sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is
|
||||
provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
|
||||
respectively. If the sum of all sizes does not equal the original dataset size, an
|
||||
an error will occur.
|
||||
If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
|
||||
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
|
||||
of the original dataset. If after rounding, any size equals 0, an error will occur.
|
||||
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
|
||||
randomize (bool): determines whether or not to split the data randomly. If true, the data
|
||||
will be randomly split. Otherwise, each split will be created with consecutive rows
|
||||
from the dataset.
|
||||
|
||||
Note:
|
||||
1. Dataset should not be sharded if split is going to be called. Instead, create a
|
||||
DistributedSampler and specify a split to shard after splitting. If dataset is
|
||||
sharded after a split, it is strongly recommended to set the same seed in each instance
|
||||
of execution, otherwise each shard may not be part of the same split (see Examples)
|
||||
2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
|
||||
Shuffling the dataset may not be deterministic, which means the data in each split
|
||||
will be different in each epoch. Furthermore, if sharding occurs after split, each
|
||||
shard may not be part of the same split.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
|
||||
RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
|
||||
equal the dataset size.
|
||||
RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
|
||||
RuntimeError: If the dataset is sharded prior to calling split.
|
||||
ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
|
||||
floats don’t sum to 1.
|
||||
|
||||
Returns
|
||||
tuple(Dataset), a tuple of datasets that have been split.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
>>> dataset_dir = "/path/to/imagefolder_directory"
|
||||
>>>
|
||||
>>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
|
||||
>>> data = ds.ImageFolderDatasetV2(dataset_dir, shuffle=False)
|
||||
>>>
|
||||
>>> # sets the seed, and tells split to use this seed when randomizing. This
|
||||
>>> # is needed because we are sharding later
|
||||
>>> ds.config.set_seed(58)
|
||||
>>> train, test = data.split([0.9, 0.1])
|
||||
>>>
|
||||
>>> # if we want to shard the train dataset, we can use a DistributedSampler
|
||||
>>> train_sampler = ds.DistributedSampler(10, 2)
|
||||
>>> train.use_sampler(train_sampler)
|
||||
"""
|
||||
if self.is_shuffled():
|
||||
logger.warning("dataset is shuffled before split.")
|
||||
|
||||
if self.is_sharded():
|
||||
raise RuntimeError("dataset should not be sharded before split.")
|
||||
|
||||
absolute_sizes = self._get_absolute_split_sizes(sizes)
|
||||
splits = []
|
||||
current_split_start_index = 0
|
||||
for size in absolute_sizes:
|
||||
ds = copy.deepcopy(self)
|
||||
if randomize:
|
||||
# want to shuffle the same way every epoch before split, we are assuming
|
||||
# that the user will call set_seed
|
||||
random_sampler = samplers.RandomSampler()
|
||||
random_sampler.reshuffle_each_epoch = False
|
||||
ds.add_sampler(random_sampler)
|
||||
|
||||
subset_sampler = samplers.SubsetSampler(current_split_start_index, size)
|
||||
ds.add_sampler(subset_sampler)
|
||||
|
||||
# add sequential sampler, so that if user calls use_sampler, we will
|
||||
# get rid of the sequential sampler instead of something we need
|
||||
ds.add_sampler(samplers.SequentialSampler())
|
||||
|
||||
splits.append(ds)
|
||||
|
||||
current_split_start_index += size
|
||||
|
||||
return tuple(splits)
|
||||
|
||||
|
||||
class DatasetOp(Dataset):
|
||||
"""
|
||||
|
@ -1336,6 +1605,7 @@ class SyncWaitDataset(DatasetOp):
|
|||
flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
|
||||
return flag
|
||||
|
||||
|
||||
class ShuffleDataset(DatasetOp):
|
||||
"""
|
||||
The result of applying Shuffle operator to the input Dataset.
|
||||
|
@ -1352,6 +1622,7 @@ class ShuffleDataset(DatasetOp):
|
|||
super().__init__()
|
||||
self.buffer_size = buffer_size
|
||||
self.input.append(input_dataset)
|
||||
self.reshuffle_each_epoch = None
|
||||
input_dataset.output.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
if self.is_sync():
|
||||
|
@ -1360,8 +1631,14 @@ class ShuffleDataset(DatasetOp):
|
|||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["buffer_size"] = self.buffer_size
|
||||
if self.reshuffle_each_epoch is not None:
|
||||
args["reshuffle_each_epoch"] = self.reshuffle_each_epoch
|
||||
|
||||
return args
|
||||
|
||||
def is_shuffled(self):
|
||||
return True
|
||||
|
||||
|
||||
# Pyfunc collection for multiprocess pyfunc
|
||||
# This global variable will only be used within subprocesses
|
||||
|
@ -1991,8 +2268,14 @@ class StorageDataset(SourceDataset):
|
|||
self._get_pipeline_info()
|
||||
return self._num_classes
|
||||
|
||||
def is_shuffled(self):
|
||||
return False
|
||||
|
||||
class RangeDataset(SourceDataset):
|
||||
def is_sharded(self):
|
||||
return False
|
||||
|
||||
|
||||
class RangeDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset that reads and parses datasets stored on disk in a range.
|
||||
|
||||
|
@ -2015,6 +2298,12 @@ class RangeDataset(SourceDataset):
|
|||
args["step"] = self.step
|
||||
return args
|
||||
|
||||
def is_shuffled(self):
|
||||
return False
|
||||
|
||||
def is_sharded(self):
|
||||
return False
|
||||
|
||||
|
||||
def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
||||
"""
|
||||
|
@ -2054,7 +2343,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|||
return samplers.SequentialSampler()
|
||||
|
||||
|
||||
class ImageFolderDatasetV2(SourceDataset):
|
||||
class ImageFolderDatasetV2(MappableDataset):
|
||||
"""
|
||||
A source dataset that reads images from a tree of directories.
|
||||
|
||||
|
@ -2192,8 +2481,20 @@ class ImageFolderDatasetV2(SourceDataset):
|
|||
num_samples = self.num_samples
|
||||
return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1]
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
return True
|
||||
|
||||
class MnistDataset(SourceDataset):
|
||||
return self.shuffle_level or self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class MnistDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing the Mnist dataset.
|
||||
|
||||
|
@ -2296,6 +2597,18 @@ class MnistDataset(SourceDataset):
|
|||
|
||||
return get_num_rows(num_rows, self.num_shards)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
return True
|
||||
|
||||
return self.shuffle_level or self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class MindDataset(SourceDataset):
|
||||
"""
|
||||
|
@ -2402,6 +2715,18 @@ class MindDataset(SourceDataset):
|
|||
num_rows = num_rows // self.partitions[0] + 1
|
||||
return num_rows
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.global_shuffle is None:
|
||||
return True
|
||||
|
||||
return self.global_shuffle or self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
def _iter_fn(dataset, num_samples):
|
||||
"""
|
||||
|
@ -2611,7 +2936,7 @@ class _GeneratorWorker(multiprocessing.Process):
|
|||
self.terminate()
|
||||
|
||||
|
||||
class GeneratorDataset(SourceDataset):
|
||||
class GeneratorDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset that generate data from python by invoking python data source each epoch.
|
||||
|
||||
|
@ -2796,6 +3121,12 @@ class GeneratorDataset(SourceDataset):
|
|||
|
||||
return new_op
|
||||
|
||||
def is_shuffled(self):
|
||||
return self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class TFRecordDataset(SourceDataset):
|
||||
"""
|
||||
|
@ -2922,8 +3253,17 @@ class TFRecordDataset(SourceDataset):
|
|||
else:
|
||||
raise ValueError('set dataset_size with negative value {}'.format(value))
|
||||
|
||||
def is_shuffled(self):
|
||||
return self.shuffle_files
|
||||
|
||||
class ManifestDataset(SourceDataset):
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class ManifestDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset that reads images from a manifest file.
|
||||
|
||||
|
@ -3090,8 +3430,20 @@ class ManifestDataset(SourceDataset):
|
|||
|
||||
return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
return True
|
||||
|
||||
class Cifar10Dataset(SourceDataset):
|
||||
return self.shuffle_level or self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class Cifar10Dataset(MappableDataset):
|
||||
"""
|
||||
A source dataset that reads cifar10 data.
|
||||
|
||||
|
@ -3199,8 +3551,20 @@ class Cifar10Dataset(SourceDataset):
|
|||
|
||||
return get_num_rows(num_rows, self.num_shards)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
return True
|
||||
|
||||
class Cifar100Dataset(SourceDataset):
|
||||
return self.shuffle_level or self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class Cifar100Dataset(MappableDataset):
|
||||
"""
|
||||
A source dataset that reads cifar100 data.
|
||||
|
||||
|
@ -3306,6 +3670,18 @@ class Cifar100Dataset(SourceDataset):
|
|||
|
||||
return get_num_rows(num_rows, self.num_shards)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
return True
|
||||
|
||||
return self.shuffle_level or self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class RandomDataset(SourceDataset):
|
||||
"""
|
||||
|
@ -3357,6 +3733,11 @@ class RandomDataset(SourceDataset):
|
|||
"""
|
||||
return num_samples
|
||||
|
||||
def is_shuffled(self):
|
||||
return True
|
||||
|
||||
def is_sharded(self):
|
||||
return False
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
|
@ -3536,7 +3917,7 @@ class Schema:
|
|||
return self.to_json()
|
||||
|
||||
|
||||
class VOCDataset(SourceDataset):
|
||||
class VOCDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing VOC dataset.
|
||||
|
||||
|
@ -3683,8 +4064,20 @@ class VOCDataset(SourceDataset):
|
|||
|
||||
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
return True
|
||||
|
||||
class CelebADataset(SourceDataset):
|
||||
return self.shuffle_level or self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class CelebADataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing CelebA dataset.Only support list_attr_celeba.txt currently.
|
||||
|
||||
|
@ -3737,6 +4130,18 @@ class CelebADataset(SourceDataset):
|
|||
args["shard_id"] = self.shard_id
|
||||
return args
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
return True
|
||||
|
||||
return self.shuffle_level or self.sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class TextFileDataset(SourceDataset):
|
||||
"""
|
||||
|
@ -3816,3 +4221,12 @@ class TextFileDataset(SourceDataset):
|
|||
return num_rows
|
||||
return min(self.num_samples, num_rows)
|
||||
return self._dataset_size
|
||||
|
||||
def is_shuffled(self):
|
||||
return self.shuffle_files
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return False
|
||||
|
|
|
@ -47,6 +47,7 @@ class Sampler:
|
|||
def __init__(self):
|
||||
self.dataset_size = 0
|
||||
self.num_samples = 0
|
||||
self.child_sampler = None
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
|
@ -83,7 +84,35 @@ class Sampler:
|
|||
# Instance fetcher
|
||||
# Do not override this method!
|
||||
def create(self):
|
||||
return cde.PythonSampler(self)
|
||||
c_sampler = cde.PythonSampler(self)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def add_child(self, sampler):
|
||||
self.child_sampler = sampler
|
||||
|
||||
def get_child(self):
|
||||
return self.child_sampler
|
||||
|
||||
def create_child(self):
|
||||
c_child_sampler = None
|
||||
if self.child_sampler is not None:
|
||||
c_child_sampler = self.child_sampler.create()
|
||||
|
||||
return c_child_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
|
||||
class BuiltinSampler:
|
||||
|
@ -93,11 +122,30 @@ class BuiltinSampler:
|
|||
User should not extend this class.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
self.child_sampler = None
|
||||
|
||||
def create(self):
|
||||
pass
|
||||
|
||||
def add_child(self, sampler):
|
||||
self.child_sampler = sampler
|
||||
|
||||
def get_child(self):
|
||||
return self.child_sampler
|
||||
|
||||
def create_child(self):
|
||||
c_child_sampler = None
|
||||
if self.child_sampler is not None:
|
||||
c_child_sampler = self.child_sampler.create()
|
||||
|
||||
return c_child_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
raise NotImplementedError("Sampler must implement is_shuffled.")
|
||||
|
||||
def is_sharded(self):
|
||||
raise NotImplementedError("Sampler must implement is_sharded.")
|
||||
|
||||
|
||||
class DistributedSampler(BuiltinSampler):
|
||||
"""
|
||||
|
@ -142,7 +190,22 @@ class DistributedSampler(BuiltinSampler):
|
|||
def create(self):
|
||||
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
|
||||
self.seed += 1
|
||||
return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
|
||||
c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.child_sampler is None:
|
||||
return self.shuffle
|
||||
|
||||
return self.child_sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
|
||||
class PKSampler(BuiltinSampler):
|
||||
|
@ -186,7 +249,22 @@ class PKSampler(BuiltinSampler):
|
|||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
return cde.PKSampler(self.num_val, self.shuffle)
|
||||
c_sampler = cde.PKSampler(self.num_val, self.shuffle)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.child_sampler is None:
|
||||
return self.shuffle
|
||||
|
||||
return self.child_sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def _create_for_minddataset(self):
|
||||
if not self.class_column or not isinstance(self.class_column, str):
|
||||
|
@ -226,15 +304,31 @@ class RandomSampler(BuiltinSampler):
|
|||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
|
||||
self.deterministic = False
|
||||
self.replacement = replacement
|
||||
self.num_samples = num_samples
|
||||
self.reshuffle_each_epoch = True
|
||||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
# If num_samples is not specified, then call constructor #2
|
||||
c_sampler = None
|
||||
if self.num_samples is None:
|
||||
return cde.RandomSampler(self.replacement)
|
||||
return cde.RandomSampler(self.replacement, self.num_samples)
|
||||
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch)
|
||||
else:
|
||||
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples)
|
||||
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
return True
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
|
||||
class SequentialSampler(BuiltinSampler):
|
||||
|
@ -252,7 +346,80 @@ class SequentialSampler(BuiltinSampler):
|
|||
"""
|
||||
|
||||
def create(self):
|
||||
return cde.SequentialSampler()
|
||||
c_sampler = cde.SequentialSampler()
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
|
||||
class SubsetSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples a subset of elements consecutively from a given index.
|
||||
|
||||
Args:
|
||||
start_index (int): Index to start sampling at.
|
||||
subset_size (int): How many samples to include in this subset.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
>>> dataset_dir = "path/to/imagefolder_directory"
|
||||
>>>
|
||||
>>> # creates a SubsetSampler, will sample the next 5 images from the 100th image.
|
||||
>>> sampler = ds.SubsetSampler(100, 5)
|
||||
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
|
||||
|
||||
Raises:
|
||||
ValueError: If start_index is not a positive int.
|
||||
ValueError: If subset_size is not a positive int.
|
||||
"""
|
||||
|
||||
def __init__(self, start_index, subset_size):
|
||||
if not isinstance(start_index, int):
|
||||
raise ValueError("start_index should be an int.")
|
||||
|
||||
if start_index < 0:
|
||||
raise ValueError("start_index should not be negative.")
|
||||
|
||||
if not isinstance(subset_size, int):
|
||||
raise ValueError("start_index should be an int")
|
||||
|
||||
if subset_size < 0:
|
||||
raise ValueError("subset_size should not be negative.")
|
||||
|
||||
self.start_index = start_index
|
||||
self.subset_size = subset_size
|
||||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
c_sampler = cde.SubsetSampler(self.start_index, self.subset_size)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
|
||||
class SubsetRandomSampler(BuiltinSampler):
|
||||
|
@ -282,7 +449,19 @@ class SubsetRandomSampler(BuiltinSampler):
|
|||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
return cde.SubsetRandomSampler(self.indices)
|
||||
c_sampler = cde.SubsetRandomSampler(self.indices)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
return True
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def _create_for_minddataset(self):
|
||||
return cde.MindrecordSubsetRandomSampler(self.indices)
|
||||
|
@ -330,4 +509,16 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
super().__init__()
|
||||
|
||||
def create(self):
|
||||
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
|
||||
c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
return True
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
|
|
@ -1031,3 +1031,44 @@ def check_textfiledataset(method):
|
|||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
def check_split(method):
|
||||
"""check the input arguments of split."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
nreq_param_list = ['sizes']
|
||||
nreq_param_bool = ['randomize']
|
||||
check_param_type(nreq_param_list, param_dict, list)
|
||||
check_param_type(nreq_param_bool, param_dict, bool)
|
||||
|
||||
# check sizes: must be list of float or list of int
|
||||
sizes = param_dict.get('sizes')
|
||||
|
||||
if not sizes:
|
||||
raise ValueError("sizes cannot be empty.")
|
||||
all_int = all(isinstance(item, int) for item in sizes)
|
||||
all_float = all(isinstance(item, float) for item in sizes)
|
||||
|
||||
if not (all_int or all_float):
|
||||
raise ValueError("sizes should be list of int or list of float.")
|
||||
|
||||
if all_int:
|
||||
all_positive = all(item > 0 for item in sizes)
|
||||
if not all_positive:
|
||||
raise ValueError("sizes is a list of int, but there should be no negative numbers.")
|
||||
|
||||
if all_float:
|
||||
all_valid_percentages = all(0 < item <= 1 for item in sizes)
|
||||
if not all_valid_percentages:
|
||||
raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].")
|
||||
|
||||
epsilon = 0.00001
|
||||
if not abs(sum(sizes) - 1) < epsilon:
|
||||
raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -92,7 +92,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
|
|||
TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
GlobalContext::config_manager()->set_seed(0);
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
|
||||
tree->Prepare();
|
||||
|
|
|
@ -138,7 +138,7 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
|
|||
TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
|
||||
int32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
GlobalContext::config_manager()->set_seed(0);
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, 12);
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
|
||||
int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
|
||||
|
|
|
@ -164,9 +164,36 @@ def test_python_sampler():
|
|||
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_sampler_chain():
|
||||
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
|
||||
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
|
||||
def test_config(num_shards, shard_id):
|
||||
sampler = ds.DistributedSampler(num_shards, shard_id, False)
|
||||
child_sampler = ds.SequentialSampler()
|
||||
sampler.add_child(child_sampler)
|
||||
|
||||
data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)
|
||||
|
||||
res = []
|
||||
for item in data1.create_dict_iterator():
|
||||
logger.info("item[image].shape[0]: {}, item[label].item(): {}"
|
||||
.format(item["image"].shape[0], item["label"].item()))
|
||||
res.append(map[(item["image"].shape[0], item["label"].item())])
|
||||
return res
|
||||
|
||||
assert test_config(2, 0) == [0, 2, 4]
|
||||
assert test_config(2, 1) == [1, 3, 0]
|
||||
assert test_config(5, 0) == [0]
|
||||
assert test_config(5, 1) == [1]
|
||||
assert test_config(5, 2) == [2]
|
||||
assert test_config(5, 3) == [3]
|
||||
assert test_config(5, 4) == [4]
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sequential_sampler(True)
|
||||
test_random_sampler(True)
|
||||
test_random_sampler_multi_iter(True)
|
||||
test_sampler_py_api()
|
||||
test_python_sampler()
|
||||
test_sampler_chain()
|
||||
|
|
|
@ -0,0 +1,342 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
|
||||
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
|
||||
# the label of each image is [0,0,0,1,1] each image can be uniquely identified
|
||||
# via the following lookup table (dict){(83554, 0): 0, (54214, 0): 1, (54214, 1): 2, (65512, 0): 3, (64631, 1): 4}
|
||||
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
|
||||
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
|
||||
def split_with_invalid_inputs(d):
|
||||
with pytest.raises(ValueError) as info:
|
||||
s1, s2 = d.split([])
|
||||
assert "sizes cannot be empty" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
s1, s2 = d.split([5, 0.6])
|
||||
assert "sizes should be list of int or list of float" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
s1, s2 = d.split([-1, 6])
|
||||
assert "there should be no negative numbers" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
s1, s2 = d.split([3, 1])
|
||||
assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
s1, s2 = d.split([5, 1])
|
||||
assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
|
||||
assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
s1, s2 = d.split([-0.5, 0.5])
|
||||
assert "there should be no numbers outside the range [0, 1]" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
s1, s2 = d.split([1.5, 0.5])
|
||||
assert "there should be no numbers outside the range [0, 1]" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
s1, s2 = d.split([0.5, 0.6])
|
||||
assert "percentages do not sum up to 1" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
s1, s2 = d.split([0.3, 0.6])
|
||||
assert "percentages do not sum up to 1" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
s1, s2 = d.split([0.05, 0.95])
|
||||
assert "percentage 0.05 is too small" in str(info.value)
|
||||
|
||||
def test_unmappable_invalid_input():
|
||||
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
|
||||
d = ds.TextFileDataset(text_file_dataset_path)
|
||||
split_with_invalid_inputs(d)
|
||||
|
||||
d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
s1, s2 = d.split([4, 1])
|
||||
assert "dataset should not be sharded before split" in str(info.value)
|
||||
|
||||
def test_unmappable_split():
|
||||
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
|
||||
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
|
||||
"End of file.", "Good luck to everyone."]
|
||||
ds.config.set_num_parallel_workers(4)
|
||||
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
|
||||
s1, s2 = d.split([4, 1], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
assert s1_output == text_file_data[0:4]
|
||||
assert s2_output == text_file_data[4:]
|
||||
|
||||
# exact percentages
|
||||
s1, s2 = d.split([0.8, 0.2], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
assert s1_output == text_file_data[0:4]
|
||||
assert s2_output == text_file_data[4:]
|
||||
|
||||
# fuzzy percentages
|
||||
s1, s2 = d.split([0.33, 0.67], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(item["text"].item().decode("utf8"))
|
||||
|
||||
assert s1_output == text_file_data[0:2]
|
||||
assert s2_output == text_file_data[2:]
|
||||
|
||||
def test_mappable_invalid_input():
|
||||
d = ds.ManifestDataset(manifest_file)
|
||||
split_with_invalid_inputs(d)
|
||||
|
||||
d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
s1, s2 = d.split([4, 1])
|
||||
assert "dataset should not be sharded before split" in str(info.value)
|
||||
|
||||
def test_mappable_split_general():
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
d = d.take(5)
|
||||
|
||||
# absolute rows
|
||||
s1, s2 = d.split([4, 1], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s1_output == [0, 1, 2, 3]
|
||||
assert s2_output == [4]
|
||||
|
||||
# exact percentages
|
||||
s1, s2 = d.split([0.8, 0.2], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s1_output == [0, 1, 2, 3]
|
||||
assert s2_output == [4]
|
||||
|
||||
# fuzzy percentages
|
||||
s1, s2 = d.split([0.33, 0.67], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s1_output == [0, 1]
|
||||
assert s2_output == [2, 3, 4]
|
||||
|
||||
def test_mappable_split_optimized():
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
|
||||
# absolute rows
|
||||
s1, s2 = d.split([4, 1], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s1_output == [0, 1, 2, 3]
|
||||
assert s2_output == [4]
|
||||
|
||||
# exact percentages
|
||||
s1, s2 = d.split([0.8, 0.2], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s1_output == [0, 1, 2, 3]
|
||||
assert s2_output == [4]
|
||||
|
||||
# fuzzy percentages
|
||||
s1, s2 = d.split([0.33, 0.67], randomize=False)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s1_output == [0, 1]
|
||||
assert s2_output == [2, 3, 4]
|
||||
|
||||
def test_mappable_randomize_deterministic():
|
||||
# set arbitrary seed for shard after split
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
s1, s2 = d.split([0.8, 0.2])
|
||||
|
||||
for _ in range(10):
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
# note no overlap
|
||||
assert s1_output == [0, 1, 3, 4]
|
||||
assert s2_output == [2]
|
||||
|
||||
def test_mappable_randomize_repeatable():
|
||||
# set arbitrary seed for shard after split
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
s1, s2 = d.split([0.8, 0.2])
|
||||
|
||||
num_epochs = 5
|
||||
s1 = s1.repeat(num_epochs)
|
||||
s2 = s2.repeat(num_epochs)
|
||||
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
# note no overlap
|
||||
assert s1_output == [0, 1, 3, 4] * num_epochs
|
||||
assert s2_output == [2] * num_epochs
|
||||
|
||||
def test_mappable_sharding():
|
||||
# set arbitrary seed for repeatability for shard after split
|
||||
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
|
||||
ds.config.set_seed(53)
|
||||
|
||||
num_epochs = 5
|
||||
first_split_num_rows = 4
|
||||
|
||||
d = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
s1, s2 = d.split([first_split_num_rows, 1])
|
||||
|
||||
distributed_sampler = ds.DistributedSampler(2, 0)
|
||||
s1.use_sampler(distributed_sampler)
|
||||
|
||||
s1 = s1.repeat(num_epochs)
|
||||
|
||||
# testing sharding, second dataset to simulate another instance
|
||||
d2 = ds.ManifestDataset(manifest_file, shuffle=False)
|
||||
d2s1, d2s2 = d2.split([first_split_num_rows, 1])
|
||||
|
||||
distributed_sampler = ds.DistributedSampler(2, 1)
|
||||
d2s1.use_sampler(distributed_sampler)
|
||||
|
||||
d2s1 = d2s1.repeat(num_epochs)
|
||||
|
||||
# shard 0
|
||||
s1_output = []
|
||||
for item in s1.create_dict_iterator():
|
||||
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
# shard 1
|
||||
d2s1_output = []
|
||||
for item in d2s1.create_dict_iterator():
|
||||
d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
rows_per_shard_per_epoch = 2
|
||||
assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
|
||||
assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs
|
||||
|
||||
# verify each epoch that
|
||||
# 1. shards contain no common elements
|
||||
# 2. the data was split the same way, and that the union of shards equal the split
|
||||
correct_sorted_split_result = [0, 1, 3, 4]
|
||||
for i in range(num_epochs):
|
||||
combined_data = []
|
||||
for j in range(rows_per_shard_per_epoch):
|
||||
combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
|
||||
combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])
|
||||
|
||||
assert sorted(combined_data) == correct_sorted_split_result
|
||||
|
||||
# test other split
|
||||
s2_output = []
|
||||
for item in s2.create_dict_iterator():
|
||||
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
d2s2_output = []
|
||||
for item in d2s2.create_dict_iterator():
|
||||
d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
|
||||
|
||||
assert s2_output == [2]
|
||||
assert d2s2_output == [2]
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_unmappable_invalid_input()
|
||||
test_unmappable_split()
|
||||
test_mappable_invalid_input()
|
||||
test_mappable_split_general()
|
||||
test_mappable_split_optimized()
|
||||
test_mappable_randomize_deterministic()
|
||||
test_mappable_randomize_repeatable()
|
||||
test_mappable_sharding()
|
Loading…
Reference in New Issue