!1983 Remove inheritance of Sampler from DatasetOp

Merge pull request !1983 from JesseKLee/sampler
This commit is contained in:
mindspore-ci-bot 2020-06-11 23:14:56 +08:00 committed by Gitee
commit beefb20c01
25 changed files with 77 additions and 115 deletions

View File

@ -263,7 +263,7 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
Status CelebAOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
std::unique_ptr<DataBuffer> data_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&data_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer));
RETURN_IF_NOT_OK(AddIOBlock(&data_buffer));
return Status::OK();
}
@ -291,7 +291,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
keys.clear();
}
}
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
}
if (!keys.empty()) {
@ -313,7 +313,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
}
}
}

View File

@ -100,7 +100,7 @@ CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const
Status CifarOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
while (true) { // each iterator is 1 epoch
std::vector<int64_t> keys;
keys.reserve(rows_per_buffer_);
@ -118,7 +118,7 @@ Status CifarOp::operator()() {
keys.clear();
}
}
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
if (keys.empty() == false) {
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
@ -139,7 +139,7 @@ Status CifarOp::operator()() {
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
}
}

View File

@ -126,7 +126,7 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
Status ImageFolderOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
while (true) { // each iterator is 1 epoch
std::vector<int64_t> keys;
keys.reserve(rows_per_buffer_);
@ -145,7 +145,7 @@ Status ImageFolderOp::operator()() {
keys.clear();
}
}
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
if (keys.empty() == false) {
RETURN_IF_NOT_OK(
@ -166,7 +166,7 @@ Status ImageFolderOp::operator()() {
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
}
}

View File

@ -88,7 +88,7 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
Status ManifestOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
return AddIoBlock(&sampler_buffer);
}
@ -110,7 +110,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
keys.clear();
}
}
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer));
}
if (keys.empty() == false) {
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
@ -131,7 +131,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer));
}
}
}

View File

@ -98,7 +98,7 @@ Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, st
Status MnistOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
while (true) { // each iterator is 1 epoch
std::vector<int64_t> keys;
keys.reserve(rows_per_buffer_);
@ -109,7 +109,7 @@ Status MnistOp::operator()() {
RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't UINT64");
}
RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys));
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
if (keys.empty() == false) {
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
@ -130,7 +130,7 @@ Status MnistOp::operator()() {
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
}
}

View File

@ -55,14 +55,14 @@ Status DistributedSampler::InitSampler() {
return Status::OK();
}
Status DistributedSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (cnt_ > samples_per_buffer_) {
RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
} else if (cnt_ == samples_per_buffer_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);

View File

@ -40,7 +40,7 @@ class DistributedSampler : public Sampler {
// @param std::unique_ptr<DataBuffer> * pBuffer
// @param int32_t workerId
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
// Init sampler, called by base class or python
Status InitSampler() override;

View File

@ -59,14 +59,14 @@ Status PKSampler::InitSampler() {
return Status::OK();
}
Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_samples_ || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler");
} else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);

View File

@ -41,7 +41,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED
// @param std::unique_ptr<DataBuffer pBuffer
// @param int32_t workerId
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.

View File

@ -23,12 +23,12 @@ namespace dataset {
PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (need_to_reset_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
}
std::shared_ptr<Tensor> sample_ids;

View File

@ -48,7 +48,7 @@ class PythonSampler : public Sampler {
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
private:
bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer()

View File

@ -31,14 +31,14 @@ RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuff
reshuffle_each_epoch_(reshuffle_each_epoch),
dist(nullptr) {}
Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error");
} else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);

View File

@ -41,7 +41,7 @@ class RandomSampler : public Sampler {
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
// meant to be called by base class or python
Status InitSampler() override;

View File

@ -33,11 +33,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
}
Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
: DatasetOp(0),
num_rows_(0),
num_samples_(num_samples),
samples_per_buffer_(samples_per_buffer),
col_desc_(nullptr) {}
: num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<Sampler> child_sampler;
@ -97,7 +93,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
std::shared_ptr<Tensor> sample_ids;
// A call to derived class to get sample ids wrapped inside a buffer
RETURN_IF_NOT_OK(GetNextBuffer(&db));
RETURN_IF_NOT_OK(GetNextSample(&db));
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0));
// check this buffer is not a ctrl buffer
@ -114,7 +110,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
}
}
// perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch
RETURN_IF_NOT_OK(GetNextBuffer(&db));
RETURN_IF_NOT_OK(GetNextSample(&db));
CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
// Reset Sampler since this is the end of the epoch
RETURN_IF_NOT_OK(Reset());
@ -133,17 +129,7 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
return Status::OK();
}
// inline op doesn't have it's own consumer, it's assigned from parent
int32_t Sampler::num_consumers() const {
if (parent_.empty() || parent_[0] == nullptr) {
MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0.";
return 0;
} else {
return parent_[0]->num_consumers();
}
}
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
Status Sampler::AddChild(std::shared_ptr<Sampler> child) {
if (child == nullptr) {
return Status::OK();
}
@ -182,14 +168,5 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
return Status::OK();
}
// inline op doesn't have it's own producers, it's assigned from child
int32_t Sampler::num_producers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(WARNING) << "Sampler with no child, num_producers is 0.";
return 0;
} else {
return child_[0]->num_producers();
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -54,7 +54,7 @@ class RandomAccessOp {
int64_t num_rows_;
};
class Sampler : public DatasetOp {
class Sampler {
public:
// Constructor
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
@ -70,14 +70,14 @@ class Sampler : public DatasetOp {
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0;
virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0;
// return all ids in one epoch as a numpy array, then call reset
Status GetAllIdsThenReset(py::array *data);
// for next epoch of sampleIds
// @return - The error code return
Status Reset() override = 0;
virtual Status Reset() = 0;
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
@ -98,26 +98,10 @@ class Sampler : public DatasetOp {
// @return status error code
Status SetNumRowsInDataset(int64_t num_rows);
// Sampler is an inlined op and has no workers. Producers and consumers are computed.
// @return
int32_t num_workers() const final { return 0; }
// Identify num consumers (inlined op)
// @return
int32_t num_consumers() const final;
// Identify num producers (inlined op)
// @return
int32_t num_producers() const final;
// Not meant to be called!
// @return - The error code return
Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); }
// Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned.
Status AddChild(std::shared_ptr<DatasetOp> child);
Status AddChild(std::shared_ptr<Sampler> child);
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds
@ -125,7 +109,7 @@ class Sampler : public DatasetOp {
// @return - The error code returned.
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
void Print(std::ostream &out, bool show_all) const override;
virtual void Print(std::ostream &out, bool show_all) const;
friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
sampler.Print(out, false);
@ -156,6 +140,7 @@ class Sampler : public DatasetOp {
int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_;
std::vector<std::shared_ptr<Sampler>> child_; // Child nodes
std::unique_ptr<DataBuffer> child_ids_;
};
} // namespace dataset

View File

@ -23,14 +23,14 @@ namespace dataset {
SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status SequentialSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (id_count_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error");
} else if (id_count_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone);

View File

@ -47,7 +47,7 @@ class SequentialSampler : public Sampler {
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
void Print(std::ostream &out, bool show_all) const override;

View File

@ -72,13 +72,13 @@ Status SubsetRandomSampler::Reset() {
}
// Get the sample ids.
Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
// All samples have been drawn
if (sample_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);

View File

@ -49,7 +49,7 @@ class SubsetRandomSampler : public Sampler {
// Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
private:
// A list of indices (already randomized in constructor).

View File

@ -95,7 +95,7 @@ Status WeightedRandomSampler::Reset() {
}
// Get the sample ids.
Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
if (weights_.size() > static_cast<size_t>(num_rows_)) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
@ -109,7 +109,7 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);

View File

@ -51,7 +51,7 @@ class WeightedRandomSampler : public Sampler {
// Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
private:
// A list of weights for each sample.

View File

@ -123,7 +123,7 @@ Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::
Status VOCOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
while (true) {
std::vector<int64_t> keys;
keys.reserve(rows_per_buffer_);
@ -134,7 +134,7 @@ Status VOCOp::operator()() {
RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
}
RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys));
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
if (keys.empty() == false) {
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
@ -155,7 +155,7 @@ Status VOCOp::operator()() {
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(wp_.Wait());
wp_.Clear();
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer));
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
}
}

View File

@ -68,7 +68,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
for (int i = 0; i < 6; i++) {
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true));
sampler->HandshakeRandomAccessOp(&mock);
sampler->GetNextBuffer(&db);
sampler->GetNextSample(&db);
db->GetTensor(&tensor, 0, 0);
MS_LOG(DEBUG) << (*tensor);
if(i < 3) { // This is added due to std::shuffle()
@ -90,17 +90,17 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> tensor;
sampler->HandshakeRandomAccessOp(&mock);
sampler->GetNextBuffer(&db);
sampler->GetNextSample(&db);
db->GetTensor(&tensor, 0, 0);
EXPECT_TRUE((*tensor) == (*label1));
sampler->GetNextBuffer(&db);
sampler->GetNextSample(&db);
db->GetTensor(&tensor, 0, 0);
EXPECT_TRUE((*tensor) == (*label2));
sampler->Reset();
sampler->GetNextBuffer(&db);
sampler->GetNextSample(&db);
db->GetTensor(&tensor, 0, 0);
EXPECT_TRUE((*tensor) == (*label1));
sampler->GetNextBuffer(&db);
sampler->GetNextSample(&db);
db->GetTensor(&tensor, 0, 0);
EXPECT_TRUE((*tensor) == (*label2));
}

View File

@ -49,7 +49,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<int64_t> out;
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
@ -61,7 +61,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
ASSERT_NE(in_set.find(out[i]), in_set.end());
}
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}
@ -79,7 +79,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
TensorRow row;
std::vector<int64_t> out;
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
int epoch = 0;
while (!db->eoe()) {
epoch++;
@ -91,7 +91,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
}
db.reset();
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
}
ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer);
@ -111,7 +111,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
TensorRow row;
std::vector<int64_t> out;
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
@ -125,7 +125,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
sampler.Reset();
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), false);
db->PopRow(&row);
out.clear();
@ -139,6 +139,6 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
ASSERT_NE(in_set.find(out[i]), in_set.end());
}
ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}

View File

@ -58,7 +58,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
@ -69,7 +69,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
ASSERT_EQ(num_samples, out.size());
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}
@ -88,7 +88,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
@ -105,7 +105,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
}
}
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}
@ -124,7 +124,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
int epoch = 0;
while (!db->eoe()) {
epoch++;
@ -135,7 +135,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
}
}
db.reset();
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
}
ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer);
@ -160,7 +160,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
int epoch = 0;
while (!db->eoe()) {
epoch++;
@ -172,7 +172,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
}
}
db.reset();
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
}
// Without replacement, each sample only drawn once.
@ -201,7 +201,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
@ -211,13 +211,13 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
}
ASSERT_EQ(num_samples, out.size());
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
m_sampler.Reset();
out.clear();
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
@ -227,7 +227,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
}
ASSERT_EQ(num_samples, out.size());
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}
@ -246,7 +246,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
std::unique_ptr<DataBuffer> db;
TensorRow row;
std::vector<uint64_t> out;
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
@ -256,7 +256,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
}
ASSERT_EQ(num_samples, out.size());
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
m_sampler.Reset();
@ -265,7 +265,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
freq.resize(total_samples, 0);
MS_LOG(INFO) << "Resetting sampler";
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
db->PopRow(&row);
for (const auto &t : row) {
for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
@ -282,6 +282,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
}
}
ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK());
ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK());
ASSERT_EQ(db->eoe(), true);
}