diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc index 4849eeac630..69fda9b66cf 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc @@ -263,7 +263,7 @@ std::vector CelebAOp::Split(const std::string &line) { Status CelebAOp::operator()() { RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); std::unique_ptr data_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&data_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer)); RETURN_IF_NOT_OK(AddIOBlock(&data_buffer)); return Status::OK(); } @@ -291,7 +291,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { keys.clear(); } } - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); } if (!keys.empty()) { @@ -313,7 +313,7 @@ Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(data_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); } } } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc index e2ac8dd31e2..0893ad9f821 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc @@ -100,7 +100,7 @@ CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const Status CifarOp::operator()() { RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); while (true) { // each iterator is 1 epoch std::vector keys; keys.reserve(rows_per_buffer_); @@ -118,7 +118,7 @@ Status CifarOp::operator()() { keys.clear(); } } - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } if (keys.empty() == false) { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( @@ -139,7 +139,7 @@ Status CifarOp::operator()() { io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } } } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc index 7ffb63a6ba9..bd7da566b6e 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc @@ -126,7 +126,7 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) { Status ImageFolderOp::operator()() { RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); while (true) { // each iterator is 1 epoch std::vector keys; keys.reserve(rows_per_buffer_); @@ -145,7 +145,7 @@ Status ImageFolderOp::operator()() { keys.clear(); } } - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } if (keys.empty() == false) { RETURN_IF_NOT_OK( @@ -166,7 +166,7 @@ Status ImageFolderOp::operator()() { io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } } } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc index 8b22f9dcfa4..14cc9b22d95 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc @@ -88,7 +88,7 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f Status ManifestOp::operator()() { RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); return AddIoBlock(&sampler_buffer); } @@ -110,7 +110,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { keys.clear(); } } - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); } if (keys.empty() == false) { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( @@ -131,7 +131,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); } } } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc index 07726f5033e..a937d8e5aef 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc @@ -98,7 +98,7 @@ Status MnistOp::TraversalSampleIds(const std::shared_ptr &sample_ids, st Status MnistOp::operator()() { RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); while (true) { // each iterator is 1 epoch std::vector keys; keys.reserve(rows_per_buffer_); @@ -109,7 +109,7 @@ Status MnistOp::operator()() { RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't UINT64"); } RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys)); - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } if (keys.empty() == false) { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( @@ -130,7 +130,7 @@ Status MnistOp::operator()() { io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } } } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 77207e9a6cc..e1f3ed7214f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -55,14 +55,14 @@ Status DistributedSampler::InitSampler() { return Status::OK(); } -Status DistributedSampler::GetNextBuffer(std::unique_ptr *out_buffer) { +Status DistributedSampler::GetNextSample(std::unique_ptr *out_buffer) { if (cnt_ > samples_per_buffer_) { RETURN_STATUS_UNEXPECTED("Distributed Sampler Error"); } else if (cnt_ == samples_per_buffer_) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); } (*out_buffer) = std::make_unique(cnt_, DataBuffer::kDeBFlagNone); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h index aeea2bfe5dd..b4d68362ee2 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -40,7 +40,7 @@ class DistributedSampler : public Sampler { // @param std::unique_ptr * pBuffer // @param int32_t workerId // @return - The error code return - Status GetNextBuffer(std::unique_ptr *out_buffer) override; + Status GetNextSample(std::unique_ptr *out_buffer) override; // Init sampler, called by base class or python Status InitSampler() override; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc index 48c59c45032..0c49c1b3144 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -59,14 +59,14 @@ Status PKSampler::InitSampler() { return Status::OK(); } -Status PKSampler::GetNextBuffer(std::unique_ptr *out_buffer) { +Status PKSampler::GetNextSample(std::unique_ptr *out_buffer) { if (next_id_ > num_samples_ || num_samples_ == 0) { RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); } else if (next_id_ == num_samples_) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); } (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h index a8538874ec4..990242c8e9a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -41,7 +41,7 @@ class PKSampler : public Sampler { // NOT YET FINISHED // @param std::unique_ptr *out_buffer) override; + Status GetNextSample(std::unique_ptr *out_buffer) override; // first handshake between leaf source op and Sampler. This func will determine the amount of data // in the dataset that we can sample from. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc index bff11a0b448..d8078eed9d7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -23,12 +23,12 @@ namespace dataset { PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} -Status PythonSampler::GetNextBuffer(std::unique_ptr *out_buffer) { +Status PythonSampler::GetNextSample(std::unique_ptr *out_buffer) { if (need_to_reset_) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); } std::shared_ptr sample_ids; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h index bba9804952c..0c1331595f4 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -48,7 +48,7 @@ class PythonSampler : public Sampler { // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp // @param int32_t workerId - not meant to be used // @return - The error code return - Status GetNextBuffer(std::unique_ptr *out_buffer) override; + Status GetNextSample(std::unique_ptr *out_buffer) override; private: bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc index 2adf6bc8c79..c78225e010c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -31,14 +31,14 @@ RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuff reshuffle_each_epoch_(reshuffle_each_epoch), dist(nullptr) {} -Status RandomSampler::GetNextBuffer(std::unique_ptr *out_buffer) { +Status RandomSampler::GetNextSample(std::unique_ptr *out_buffer) { if (next_id_ > num_samples_) { RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); } else if (next_id_ == num_samples_) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); } (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h index bb8bb724289..e9961cc51ad 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h @@ -41,7 +41,7 @@ class RandomSampler : public Sampler { // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp // @param int32_t workerId - not meant to be used // @return - The error code return - Status GetNextBuffer(std::unique_ptr *out_buffer) override; + Status GetNextSample(std::unique_ptr *out_buffer) override; // meant to be called by base class or python Status InitSampler() override; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 7c96a2c54aa..d5e1b838b9c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -33,11 +33,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { } Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) - : DatasetOp(0), - num_rows_(0), - num_samples_(num_samples), - samples_per_buffer_(samples_per_buffer), - col_desc_(nullptr) {} + : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { std::shared_ptr child_sampler; @@ -97,7 +93,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { std::shared_ptr sample_ids; // A call to derived class to get sample ids wrapped inside a buffer - RETURN_IF_NOT_OK(GetNextBuffer(&db)); + RETURN_IF_NOT_OK(GetNextSample(&db)); // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); // check this buffer is not a ctrl buffer @@ -114,7 +110,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { } } // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch - RETURN_IF_NOT_OK(GetNextBuffer(&db)); + RETURN_IF_NOT_OK(GetNextSample(&db)); CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); // Reset Sampler since this is the end of the epoch RETURN_IF_NOT_OK(Reset()); @@ -133,17 +129,7 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) { return Status::OK(); } -// inline op doesn't have it's own consumer, it's assigned from parent -int32_t Sampler::num_consumers() const { - if (parent_.empty() || parent_[0] == nullptr) { - MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0."; - return 0; - } else { - return parent_[0]->num_consumers(); - } -} - -Status Sampler::AddChild(std::shared_ptr child) { +Status Sampler::AddChild(std::shared_ptr child) { if (child == nullptr) { return Status::OK(); } @@ -182,14 +168,5 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { return Status::OK(); } -// inline op doesn't have it's own producers, it's assigned from child -int32_t Sampler::num_producers() const { - if (child_.empty() || child_[0] == nullptr) { - MS_LOG(WARNING) << "Sampler with no child, num_producers is 0."; - return 0; - } else { - return child_[0]->num_producers(); - } -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h index 8880e5e9f8a..36be723ae8c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h @@ -54,7 +54,7 @@ class RandomAccessOp { int64_t num_rows_; }; -class Sampler : public DatasetOp { +class Sampler { public: // Constructor // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 @@ -70,14 +70,14 @@ class Sampler : public DatasetOp { // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp // @param int32_t workerId - not meant to be used // @return - The error code return - Status GetNextBuffer(std::unique_ptr *out_buffer) override = 0; + virtual Status GetNextSample(std::unique_ptr *out_buffer) = 0; // return all ids in one epoch as a numpy array, then call reset Status GetAllIdsThenReset(py::array *data); // for next epoch of sampleIds // @return - The error code return - Status Reset() override = 0; + virtual Status Reset() = 0; // first handshake between leaf source op and Sampler. This func will determine the amount of data // in the dataset that we can sample from. @@ -98,26 +98,10 @@ class Sampler : public DatasetOp { // @return status error code Status SetNumRowsInDataset(int64_t num_rows); - // Sampler is an inlined op and has no workers. Producers and consumers are computed. - // @return - int32_t num_workers() const final { return 0; } - - // Identify num consumers (inlined op) - // @return - int32_t num_consumers() const final; - - // Identify num producers (inlined op) - // @return - int32_t num_producers() const final; - - // Not meant to be called! - // @return - The error code return - Status operator()() final { RETURN_STATUS_UNEXPECTED("Functor not supported in Sampler"); } - // Adds a sampler to become our child. // @param std::shared_ptr - The sampler to add as a child. // @return - The error code returned. - Status AddChild(std::shared_ptr child); + Status AddChild(std::shared_ptr child); // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler // @param std::shared_ptr* sampleIds @@ -125,7 +109,7 @@ class Sampler : public DatasetOp { // @return - The error code returned. Status CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t num_elements); - void Print(std::ostream &out, bool show_all) const override; + virtual void Print(std::ostream &out, bool show_all) const; friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { sampler.Print(out, false); @@ -156,6 +140,7 @@ class Sampler : public DatasetOp { int64_t samples_per_buffer_; std::unique_ptr col_desc_; + std::vector> child_; // Child nodes std::unique_ptr child_ids_; }; } // namespace dataset diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index b26fc630671..8a088ef9c33 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -23,14 +23,14 @@ namespace dataset { SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} -Status SequentialSampler::GetNextBuffer(std::unique_ptr *out_buffer) { +Status SequentialSampler::GetNextSample(std::unique_ptr *out_buffer) { if (id_count_ > num_samples_) { RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); } else if (id_count_ == num_samples_) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); } (*out_buffer) = std::make_unique(current_id_, DataBuffer::kDeBFlagNone); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h index 46cfb7a3047..f95a717e3c4 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -47,7 +47,7 @@ class SequentialSampler : public Sampler { // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp // @param int32_t workerId - not meant to be used // @return - The error code return - Status GetNextBuffer(std::unique_ptr *out_buffer) override; + Status GetNextSample(std::unique_ptr *out_buffer) override; void Print(std::ostream &out, bool show_all) const override; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc index 0dfeb1a191b..21eead3a0b5 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -72,13 +72,13 @@ Status SubsetRandomSampler::Reset() { } // Get the sample ids. -Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr *out_buffer) { +Status SubsetRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { // All samples have been drawn if (sample_id_ == num_samples_) { (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); } (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h index d1ab13c5404..253b6ff54b5 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -49,7 +49,7 @@ class SubsetRandomSampler : public Sampler { // Get the sample ids. // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. - Status GetNextBuffer(std::unique_ptr *out_buffer) override; + Status GetNextSample(std::unique_ptr *out_buffer) override; private: // A list of indices (already randomized in constructor). diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index 96b2571786a..cdf728cbcee 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -95,7 +95,7 @@ Status WeightedRandomSampler::Reset() { } // Get the sample ids. -Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr *out_buffer) { +Status WeightedRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { if (weights_.size() > static_cast(num_rows_)) { return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); @@ -109,7 +109,7 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr *out_buf (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); } (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h index 775176ccdac..46a92b973c2 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -51,7 +51,7 @@ class WeightedRandomSampler : public Sampler { // Get the sample ids. // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. - Status GetNextBuffer(std::unique_ptr *out_buffer) override; + Status GetNextSample(std::unique_ptr *out_buffer) override; private: // A list of weights for each sample. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc index 35d44757106..3d52fc7373b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc @@ -123,7 +123,7 @@ Status VOCOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std:: Status VOCOp::operator()() { RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); while (true) { std::vector keys; keys.reserve(rows_per_buffer_); @@ -134,7 +134,7 @@ Status VOCOp::operator()() { RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); } RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } if (keys.empty() == false) { RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( @@ -155,7 +155,7 @@ Status VOCOp::operator()() { io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK(wp_.Wait()); wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } } } diff --git a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc index 39fe56e163a..0c11ef9dcfd 100644 --- a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc +++ b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc @@ -68,7 +68,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { for (int i = 0; i < 6; i++) { std::shared_ptr sampler = std::make_shared(num_samples, 3, i % 3, (i < 3 ? false : true)); sampler->HandshakeRandomAccessOp(&mock); - sampler->GetNextBuffer(&db); + sampler->GetNextSample(&db); db->GetTensor(&tensor, 0, 0); MS_LOG(DEBUG) << (*tensor); if(i < 3) { // This is added due to std::shuffle() @@ -90,17 +90,17 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { std::unique_ptr db; std::shared_ptr tensor; sampler->HandshakeRandomAccessOp(&mock); - sampler->GetNextBuffer(&db); + sampler->GetNextSample(&db); db->GetTensor(&tensor, 0, 0); EXPECT_TRUE((*tensor) == (*label1)); - sampler->GetNextBuffer(&db); + sampler->GetNextSample(&db); db->GetTensor(&tensor, 0, 0); EXPECT_TRUE((*tensor) == (*label2)); sampler->Reset(); - sampler->GetNextBuffer(&db); + sampler->GetNextSample(&db); db->GetTensor(&tensor, 0, 0); EXPECT_TRUE((*tensor) == (*label1)); - sampler->GetNextBuffer(&db); + sampler->GetNextSample(&db); db->GetTensor(&tensor, 0, 0); EXPECT_TRUE((*tensor) == (*label2)); } diff --git a/tests/ut/cpp/dataset/subset_random_sampler_test.cc b/tests/ut/cpp/dataset/subset_random_sampler_test.cc index 10050dbfb4f..55674a55167 100644 --- a/tests/ut/cpp/dataset/subset_random_sampler_test.cc +++ b/tests/ut/cpp/dataset/subset_random_sampler_test.cc @@ -49,7 +49,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { std::unique_ptr db; TensorRow row; std::vector out; - ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); db->PopRow(&row); for (const auto &t : row) { for (auto it = t->begin(); it != t->end(); it++) { @@ -61,7 +61,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { ASSERT_NE(in_set.find(out[i]), in_set.end()); } - ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), true); } @@ -79,7 +79,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { TensorRow row; std::vector out; - ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); int epoch = 0; while (!db->eoe()) { epoch++; @@ -91,7 +91,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { } db.reset(); - ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); } ASSERT_EQ(epoch, (total_samples + samples_per_buffer - 1) / samples_per_buffer); @@ -111,7 +111,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { TensorRow row; std::vector out; - ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); db->PopRow(&row); for (const auto &t : row) { for (auto it = t->begin(); it != t->end(); it++) { @@ -125,7 +125,7 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { sampler.Reset(); - ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), false); db->PopRow(&row); out.clear(); @@ -139,6 +139,6 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { ASSERT_NE(in_set.find(out[i]), in_set.end()); } - ASSERT_EQ(sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), true); } diff --git a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc index a41dae532f3..6cc69c8b8b4 100644 --- a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc +++ b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc @@ -58,7 +58,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { std::unique_ptr db; TensorRow row; std::vector out; - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); db->PopRow(&row); for (const auto &t : row) { for (auto it = t->begin(); it != t->end(); it++) { @@ -69,7 +69,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { ASSERT_EQ(num_samples, out.size()); - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), true); } @@ -88,7 +88,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { std::unique_ptr db; TensorRow row; std::vector out; - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); db->PopRow(&row); for (const auto &t : row) { for (auto it = t->begin(); it != t->end(); it++) { @@ -105,7 +105,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { } } - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), true); } @@ -124,7 +124,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { std::unique_ptr db; TensorRow row; std::vector out; - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); int epoch = 0; while (!db->eoe()) { epoch++; @@ -135,7 +135,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { } } db.reset(); - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); } ASSERT_EQ(epoch, (num_samples + samples_per_buffer - 1) / samples_per_buffer); @@ -160,7 +160,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { std::unique_ptr db; TensorRow row; std::vector out; - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); int epoch = 0; while (!db->eoe()) { epoch++; @@ -172,7 +172,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { } } db.reset(); - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); } // Without replacement, each sample only drawn once. @@ -201,7 +201,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { std::unique_ptr db; TensorRow row; std::vector out; - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); db->PopRow(&row); for (const auto &t : row) { for (auto it = t->begin(); it != t->end(); it++) { @@ -211,13 +211,13 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { } ASSERT_EQ(num_samples, out.size()); - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), true); m_sampler.Reset(); out.clear(); - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); db->PopRow(&row); for (const auto &t : row) { for (auto it = t->begin(); it != t->end(); it++) { @@ -227,7 +227,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { } ASSERT_EQ(num_samples, out.size()); - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), true); } @@ -246,7 +246,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { std::unique_ptr db; TensorRow row; std::vector out; - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); db->PopRow(&row); for (const auto &t : row) { for (auto it = t->begin(); it != t->end(); it++) { @@ -256,7 +256,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { } ASSERT_EQ(num_samples, out.size()); - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), true); m_sampler.Reset(); @@ -265,7 +265,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { freq.resize(total_samples, 0); MS_LOG(INFO) << "Resetting sampler"; - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); db->PopRow(&row); for (const auto &t : row) { for (auto it = t->begin(); it != t->end(); it++) { @@ -282,6 +282,6 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { } } - ASSERT_EQ(m_sampler.GetNextBuffer(&db), Status::OK()); + ASSERT_EQ(m_sampler.GetNextSample(&db), Status::OK()); ASSERT_EQ(db->eoe(), true); }