forked from mindspore-Ecosystem/mindspore
!1983 Remove inheritance of Sampler from DatasetOp
Merge pull request !1983 from JesseKLee/sampler
This commit is contained in:
commit
beefb20c01
|
@ -263,7 +263,7 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
|
|||
Status CelebAOp::operator()() {
|
||||
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
|
||||
std::unique_ptr<DataBuffer> data_buffer;
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&data_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer));
|
||||
RETURN_IF_NOT_OK(AddIOBlock(&data_buffer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -291,7 +291,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
|
|||
keys.clear();
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
|
||||
}
|
||||
|
||||
if (!keys.empty()) {
|
||||
|
@ -313,7 +313,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
|
|||
io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
|
||||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const
|
|||
Status CifarOp::operator()() {
|
||||
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
|
||||
std::unique_ptr<DataBuffer> sampler_buffer;
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
while (true) { // each iterator is 1 epoch
|
||||
std::vector<int64_t> keys;
|
||||
keys.reserve(rows_per_buffer_);
|
||||
|
@ -118,7 +118,7 @@ Status CifarOp::operator()() {
|
|||
keys.clear();
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
if (keys.empty() == false) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
|
@ -139,7 +139,7 @@ Status CifarOp::operator()() {
|
|||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
|
||||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -126,7 +126,7 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
|
|||
Status ImageFolderOp::operator()() {
|
||||
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
|
||||
std::unique_ptr<DataBuffer> sampler_buffer;
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
while (true) { // each iterator is 1 epoch
|
||||
std::vector<int64_t> keys;
|
||||
keys.reserve(rows_per_buffer_);
|
||||
|
@ -145,7 +145,7 @@ Status ImageFolderOp::operator()() {
|
|||
keys.clear();
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
if (keys.empty() == false) {
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -166,7 +166,7 @@ Status ImageFolderOp::operator()() {
|
|||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
|
||||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
|
|||
Status ManifestOp::operator()() {
|
||||
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
|
||||
std::unique_ptr<DataBuffer> sampler_buffer;
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
return AddIoBlock(&sampler_buffer);
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
|
|||
keys.clear();
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer));
|
||||
}
|
||||
if (keys.empty() == false) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
|
@ -131,7 +131,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
|
|||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
|
||||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,7 +98,7 @@ Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, st
|
|||
Status MnistOp::operator()() {
|
||||
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
|
||||
std::unique_ptr<DataBuffer> sampler_buffer;
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
while (true) { // each iterator is 1 epoch
|
||||
std::vector<int64_t> keys;
|
||||
keys.reserve(rows_per_buffer_);
|
||||
|
@ -109,7 +109,7 @@ Status MnistOp::operator()() {
|
|||
RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't UINT64");
|
||||
}
|
||||
RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
if (keys.empty() == false) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
|
@ -130,7 +130,7 @@ Status MnistOp::operator()() {
|
|||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
|
||||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,14 +55,14 @@ Status DistributedSampler::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (cnt_ > samples_per_buffer_) {
|
||||
RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
|
||||
} else if (cnt_ == samples_per_buffer_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
|
||||
|
|
|
@ -40,7 +40,7 @@ class DistributedSampler : public Sampler {
|
|||
// @param std::unique_ptr<DataBuffer> * pBuffer
|
||||
// @param int32_t workerId
|
||||
// @return - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// Init sampler, called by base class or python
|
||||
Status InitSampler() override;
|
||||
|
|
|
@ -59,14 +59,14 @@ Status PKSampler::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (next_id_ > num_samples_ || num_samples_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler");
|
||||
} else if (next_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
|
||||
|
|
|
@ -41,7 +41,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED
|
|||
// @param std::unique_ptr<DataBuffer> pBuffer
|
||||
// @param int32_t workerId
|
||||
// @return - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// first handshake between leaf source op and Sampler. This func will determine the amount of data
|
||||
// in the dataset that we can sample from.
|
||||
|
|
|
@ -23,12 +23,12 @@ namespace dataset {
|
|||
PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
|
||||
: Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
|
||||
|
||||
Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (need_to_reset_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
|
|
|
@ -48,7 +48,7 @@ class PythonSampler : public Sampler {
|
|||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
private:
|
||||
bool need_to_reset_; // Whether Reset() should be called before calling GetNextSample()
|
||||
|
|
|
@ -31,14 +31,14 @@ RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuff
|
|||
reshuffle_each_epoch_(reshuffle_each_epoch),
|
||||
dist(nullptr) {}
|
||||
|
||||
Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (next_id_ > num_samples_) {
|
||||
RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error");
|
||||
} else if (next_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
}
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ class RandomSampler : public Sampler {
|
|||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// meant to be called by base class or python
|
||||
Status InitSampler() override;
|
||||
|
|
|
@ -33,11 +33,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
|
|||
}
|
||||
|
||||
Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
|
||||
: DatasetOp(0),
|
||||
num_rows_(0),
|
||||
num_samples_(num_samples),
|
||||
samples_per_buffer_(samples_per_buffer),
|
||||
col_desc_(nullptr) {}
|
||||
: num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
|
||||
|
||||
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
||||
std::shared_ptr<Sampler> child_sampler;
|
||||
|
@ -97,7 +93,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
|
|||
std::shared_ptr<Tensor> sample_ids;
|
||||
|
||||
// A call to derived class to get sample ids wrapped inside a buffer
|
||||
RETURN_IF_NOT_OK(GetNextBuffer(&db));
|
||||
RETURN_IF_NOT_OK(GetNextSample(&db));
|
||||
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
|
||||
RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0));
|
||||
// check this buffer is not a ctrl buffer
|
||||
|
@ -114,7 +110,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
|
|||
}
|
||||
}
|
||||
// Perform error checking: the next buffer is supposed to be EOE since the last one already contains all ids for the current epoch
|
||||
RETURN_IF_NOT_OK(GetNextBuffer(&db));
|
||||
RETURN_IF_NOT_OK(GetNextSample(&db));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
|
||||
// Reset Sampler since this is the end of the epoch
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
|
@ -133,17 +129,7 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// inline op doesn't have its own consumer, it's assigned from parent
|
||||
int32_t Sampler::num_consumers() const {
|
||||
if (parent_.empty() || parent_[0] == nullptr) {
|
||||
MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
}
|
||||
}
|
||||
|
||||
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
|
||||
Status Sampler::AddChild(std::shared_ptr<Sampler> child) {
|
||||
if (child == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -182,14 +168,5 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// inline op doesn't have its own producers, it's assigned from child
|
||||
int32_t Sampler::num_producers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(WARNING) << "Sampler with no child, num_producers is 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -54,7 +54,7 @@ class RandomAccessOp {
|
|||
int64_t num_rows_;
|
||||
};
|
||||
|
||||
class Sampler : public DatasetOp {
|
||||
class Sampler {
|
||||
public:
|
||||
// Constructor
|
||||
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
|
||||
|
@ -70,14 +70,14 @@ class Sampler : public DatasetOp {
|
|||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0;
|
||||
virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0;
|
||||
|
||||
// return all ids in one epoch as a numpy array, then call reset
|
||||
Status GetAllIdsThenReset(py::array *data);
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
Status Reset() override = 0;
|
||||
virtual Status Reset() = 0;
|
||||
|
||||
// first handshake between leaf source op and Sampler. This func will determine the amount of data
|
||||
// in the dataset that we can sample from.
|
||||
|
@ -98,26 +98,10 @@ class Sampler : public DatasetOp {
|
|||
// @return status error code
|
||||
Status SetNumRowsInDataset(int64_t num_rows);
|
||||
|
||||
// Sampler is an inlined op and has no workers. Producers and consumers are computed.
|
||||
// @return
|
||||
int32_t num_workers() const final { return 0; }
|
||||
|
||||
// Identify num consumers (inlined op)
|
||||
// @return
|
||||
int32_t num_consumers() const final;
|
||||
|
||||
// Identify num producers (inlined op)
|
||||
// @return
|
||||
int32_t num_producers() const final;
|
||||
|
||||
// Not meant to be called!
|
||||
// @return - The error code return
|
||||
Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); }
|
||||
|
||||
// Adds a sampler to become our child.
|
||||
// @param std::shared_ptr<Sampler> - The sampler to add as a child.
|
||||
// @return - The error code returned.
|
||||
Status AddChild(std::shared_ptr<DatasetOp> child);
|
||||
Status AddChild(std::shared_ptr<Sampler> child);
|
||||
|
||||
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
|
||||
// @param std::shared_ptr<Tensor>* sampleIds
|
||||
|
@ -125,7 +109,7 @@ class Sampler : public DatasetOp {
|
|||
// @return - The error code returned.
|
||||
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
|
||||
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
virtual void Print(std::ostream &out, bool show_all) const;
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
|
||||
sampler.Print(out, false);
|
||||
|
@ -156,6 +140,7 @@ class Sampler : public DatasetOp {
|
|||
|
||||
int64_t samples_per_buffer_;
|
||||
std::unique_ptr<ColDescriptor> col_desc_;
|
||||
std::vector<std::shared_ptr<Sampler>> child_; // Child nodes
|
||||
std::unique_ptr<DataBuffer> child_ids_;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -23,14 +23,14 @@ namespace dataset {
|
|||
SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
|
||||
: Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
|
||||
|
||||
Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (id_count_ > num_samples_) {
|
||||
RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error");
|
||||
} else if (id_count_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone);
|
||||
|
|
|
@ -47,7 +47,7 @@ class SequentialSampler : public Sampler {
|
|||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
|
|
|
@ -72,13 +72,13 @@ Status SubsetRandomSampler::Reset() {
|
|||
}
|
||||
|
||||
// Get the sample ids.
|
||||
Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
// All samples have been drawn
|
||||
if (sample_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
|
||||
|
|
|
@ -49,7 +49,7 @@ class SubsetRandomSampler : public Sampler {
|
|||
// Get the sample ids.
|
||||
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
|
||||
// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
private:
|
||||
// A list of indices (already randomized in constructor).
|
||||
|
|
|
@ -95,7 +95,7 @@ Status WeightedRandomSampler::Reset() {
|
|||
}
|
||||
|
||||
// Get the sample ids.
|
||||
Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (weights_.size() > static_cast<size_t>(num_rows_)) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
|
||||
|
@ -109,7 +109,7 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
|
|||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
|
||||
|
|
|
@ -51,7 +51,7 @@ class WeightedRandomSampler : public Sampler {
|
|||
// Get the sample ids.
|
||||
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
|
||||
// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
private:
|
||||
// A list of weights for each sample.
|
||||
|
|
|
@ -123,7 +123,7 @@ Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::
|
|||
Status VOCOp::operator()() {
|
||||
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
|
||||
std::unique_ptr<DataBuffer> sampler_buffer;
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
while (true) {
|
||||
std::vector<int64_t> keys;
|
||||
keys.reserve(rows_per_buffer_);
|
||||
|
@ -134,7 +134,7 @@ Status VOCOp::operator()() {
|
|||
RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
|
||||
}
|
||||
RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
if (keys.empty() == false) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
|
||||
|
@ -155,7 +155,7 @@ Status VOCOp::operator()() {
|
|||
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(wp_.Wait());
|
||||
wp_.Clear();
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
|
|||
for (int i = 0; i < 6; i++) {
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true));
|
||||
sampler->HandshakeRandomAccessOp(&mock);
|
||||
sampler->GetNextBuffer(&db);
|
||||
sampler->GetNextSample(&db);
|
||||
db->GetTensor(&tensor, 0, 0);
|
||||
MS_LOG(DEBUG) << (*tensor);
|
||||
if(i < 3) { // This is added due to std::shuffle()
|
||||
|
@ -90,17 +90,17 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
sampler->HandshakeRandomAccessOp(&mock);
|
||||
sampler->GetNextBuffer(&db);
|
||||
sampler->GetNextSample(&db);
|
||||
db->GetTensor(&tensor, 0, 0);
|
||||
EXPECT_TRUE((*tensor) == (*label1));
|
||||
sampler->GetNextBuffer(&db);
|
||||
sampler->GetNextSample(&db);
|
||||
db->GetTensor(&tensor, 0, 0);
|
||||
EXPECT_TRUE((*tensor) == (*label2));
|
||||
sampler->Reset();
|
||||
sampler->GetNextBuffer(&db);
|
||||
sampler->GetNextSample(&db);
|
||||
db->GetTensor(&tensor, 0, 0);
|
||||
EXPECT_TRUE((*tensor) == (*label1));
|
||||
sampler->GetNextBuffer(&db);
|
||||
sampler->GetNextSample(&db);
|
||||
db->GetTensor(&tensor, 0, 0);
|
||||
EXPECT_TRUE((*tensor) == (*label2));
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<int64_t> out;
|
||||
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
|
||||
|
@ -61,7 +61,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
|
|||
ASSERT_NE(in_set.find(out[i]), in_set.end());
|
||||
}
|
||||
|
||||
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
||||
|
@ -79,7 +79,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
|
|||
TensorRow row;
|
||||
std::vector<int64_t> out;
|
||||
|
||||
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
|
||||
int epoch = 0;
|
||||
while (!db->eoe()) {
|
||||
epoch++;
|
||||
|
@ -91,7 +91,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
|
|||
}
|
||||
db.reset();
|
||||
|
||||
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
|
||||
}
|
||||
|
||||
ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer);
|
||||
|
@ -111,7 +111,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
|
|||
TensorRow row;
|
||||
std::vector<int64_t> out;
|
||||
|
||||
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
|
||||
|
@ -125,7 +125,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
|
|||
|
||||
sampler.Reset();
|
||||
|
||||
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), false);
|
||||
db->PopRow(&row);
|
||||
out.clear();
|
||||
|
@ -139,6 +139,6 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
|
|||
ASSERT_NE(in_set.find(out[i]), in_set.end());
|
||||
}
|
||||
|
||||
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
|
@ -69,7 +69,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
|
|||
|
||||
ASSERT_EQ(num_samples, out.size());
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
|
@ -105,7 +105,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
|
|||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
||||
|
@ -124,7 +124,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
int epoch = 0;
|
||||
while (!db->eoe()) {
|
||||
epoch++;
|
||||
|
@ -135,7 +135,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
|
|||
}
|
||||
}
|
||||
db.reset();
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
}
|
||||
|
||||
ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer);
|
||||
|
@ -160,7 +160,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
int epoch = 0;
|
||||
while (!db->eoe()) {
|
||||
epoch++;
|
||||
|
@ -172,7 +172,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
|
|||
}
|
||||
}
|
||||
db.reset();
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
}
|
||||
|
||||
// Without replacement, each sample only drawn once.
|
||||
|
@ -201,7 +201,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
|
@ -211,13 +211,13 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
|||
}
|
||||
ASSERT_EQ(num_samples, out.size());
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
|
||||
m_sampler.Reset();
|
||||
out.clear();
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
|
@ -227,7 +227,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
|||
}
|
||||
ASSERT_EQ(num_samples, out.size());
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
||||
|
@ -246,7 +246,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|||
std::unique_ptr<DataBuffer> db;
|
||||
TensorRow row;
|
||||
std::vector<uint64_t> out;
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
|
@ -256,7 +256,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|||
}
|
||||
ASSERT_EQ(num_samples, out.size());
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
|
||||
m_sampler.Reset();
|
||||
|
@ -265,7 +265,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|||
freq.resize(total_samples, 0);
|
||||
MS_LOG(INFO) << "Resetting sampler";
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
db->PopRow(&row);
|
||||
for (const auto &t : row) {
|
||||
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
|
||||
|
@ -282,6 +282,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
|
||||
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
|
||||
ASSERT_EQ(db->eoe(), true);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue