From 0f2b5d8cac49fe04bf97cc31670a1e5be678ecd5 Mon Sep 17 00:00:00 2001 From: Mahdi Date: Tue, 8 Dec 2020 15:55:17 -0500 Subject: [PATCH] Changed SamplerObj validate params to return status and added AddChild to it --- .../ccsrc/minddata/dataset/api/samplers.cc | 141 +++++++++++------- .../source/sampler/distributed_sampler.cc | 4 + .../datasetops/source/sampler/pk_sampler.cc | 4 + .../source/sampler/python_sampler.cc | 5 + .../source/sampler/random_sampler.cc | 4 + .../datasetops/source/sampler/sampler.cc | 6 +- .../datasetops/source/sampler/sampler.h | 1 + .../source/sampler/sequential_sampler.cc | 5 + .../source/sampler/subset_random_sampler.cc | 4 + .../source/sampler/weighted_random_sampler.cc | 4 + .../ccsrc/minddata/dataset/include/samplers.h | 75 +++++++--- tests/ut/cpp/dataset/c_api_samplers_test.cc | 31 ++++ 12 files changed, 211 insertions(+), 73 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc index 98a1dc219ec..58c518f8818 100644 --- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -48,6 +48,34 @@ namespace dataset { // Constructor SamplerObj::SamplerObj() {} +void SamplerObj::BuildChildren(std::shared_ptr sampler) { + for (auto child : children_) { + auto sampler_rt = child->Build(); + sampler->AddChild(sampler_rt); + } +} + +Status SamplerObj::AddChild(std::shared_ptr child) { + if (child == nullptr) { + return Status::OK(); + } + + // Only samplers can be added, not any other DatasetOp. + std::shared_ptr sampler = std::dynamic_pointer_cast(child); + if (!sampler) { + RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object."); + } + + // Samplers can have at most 1 child. + if (!children_.empty()) { + RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child."); + } + + children_.push_back(child); + + return Status::OK(); +} + /// Function to create a Distributed Sampler. std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed, int64_t offset, @@ -55,7 +83,7 @@ std::shared_ptr DistributedSampler(int64_t num_shards, in auto sampler = std::make_shared(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist); // Input validation - if (!sampler->ValidateParams()) { + if (sampler->ValidateParams().IsError()) { return nullptr; } return sampler; @@ -65,7 +93,7 @@ std::shared_ptr DistributedSampler(int64_t num_shards, in std::shared_ptr PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) { auto sampler = std::make_shared(num_val, shuffle, num_samples); // Input validation - if (!sampler->ValidateParams()) { + if (sampler->ValidateParams().IsError()) { return nullptr; } return sampler; @@ -75,7 +103,7 @@ std::shared_ptr PKSampler(int64_t num_val, bool shuffle, int64_t n std::shared_ptr RandomSampler(bool replacement, int64_t num_samples) { auto sampler = std::make_shared(replacement, num_samples); // Input validation - if (!sampler->ValidateParams()) { + if (sampler->ValidateParams().IsError()) { return nullptr; } return sampler; @@ -85,7 +113,7 @@ std::shared_ptr RandomSampler(bool replacement, int64_t num_sa std::shared_ptr SequentialSampler(int64_t start_index, int64_t num_samples) { auto sampler = std::make_shared(start_index, num_samples); // Input validation - if (!sampler->ValidateParams()) { + if (sampler->ValidateParams().IsError()) { return nullptr; } return sampler; @@ -95,7 +123,7 @@ std::shared_ptr SequentialSampler(int64_t start_index, int std::shared_ptr SubsetRandomSampler(std::vector indices, int64_t num_samples) { auto sampler = std::make_shared(std::move(indices), num_samples); // Input validation - if (!sampler->ValidateParams()) { + if (sampler->ValidateParams().IsError()) { return nullptr; } return sampler; @@ -106,7 +134,7 @@ std::shared_ptr WeightedRandomSampler(std::vector(std::move(weights), num_samples, replacement); // Input validation - if (!sampler->ValidateParams()) { + if (sampler->ValidateParams().IsError()) { return nullptr; } return sampler; @@ -125,35 +153,33 @@ DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_i offset_(offset), even_dist_(even_dist) {} -bool DistributedSamplerObj::ValidateParams() { +Status DistributedSamplerObj::ValidateParams() { if (num_shards_ <= 0) { - MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_; - return false; + RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_shards: " + std::to_string(num_shards_)); } if (shard_id_ < 0 || shard_id_ >= num_shards_) { - MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; - return false; + RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid input, shard_id: " + std::to_string(shard_id_) + + ", num_shards: " + std::to_string(num_shards_)); } if (num_samples_ < 0) { - MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_; - return false; + RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_samples: " + std::to_string(num_samples_)); } if (offset_ > num_shards_) { - MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_ - << ", which should be no more than num_shards: " << num_shards_; - return false; + RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid offset: " + std::to_string(offset_) + + ", which should be no more than num_shards: " + std::to_string(num_shards_)); } - return true; + return Status::OK(); } std::shared_ptr DistributedSamplerObj::Build() { // runtime sampler object auto sampler = std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_, offset_, even_dist_); + BuildChildren(sampler); return sampler; } @@ -170,23 +196,21 @@ std::shared_ptr DistributedSamplerObj::BuildForMindDa PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} -bool PKSamplerObj::ValidateParams() { +Status PKSamplerObj::ValidateParams() { if (num_val_ <= 0) { - MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_; - return false; + RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_val: " + std::to_string(num_val_)); } if (num_samples_ < 0) { - MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_; - return false; + RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_samples: " + std::to_string(num_samples_)); } - return true; + return Status::OK(); } std::shared_ptr PKSamplerObj::Build() { // runtime sampler object auto sampler = std::make_shared(num_samples_, num_val_, shuffle_); - + BuildChildren(sampler); return sampler; } @@ -198,9 +222,12 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr PreBuiltSamplerObj::Build() { return sp_; } +std::shared_ptr PreBuiltSamplerObj::Build() { + BuildChildren(sp_); + return sp_; +} #ifndef ENABLE_ANDROID std::shared_ptr PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } @@ -208,9 +235,19 @@ std::shared_ptr PreBuiltSamplerObj::BuildForMindDatas std::shared_ptr PreBuiltSamplerObj::Copy() { #ifndef ENABLE_ANDROID - if (sp_minddataset_ != nullptr) return std::make_shared(sp_minddataset_); + if (sp_minddataset_ != nullptr) { + auto sampler = std::make_shared(sp_minddataset_); + for (auto child : children_) { + sampler->AddChild(child); + } + return sampler; + } #endif - return std::make_shared(sp_); + auto sampler = std::make_shared(sp_); + for (auto child : children_) { + sampler->AddChild(child); + } + return sampler; } #ifndef ENABLE_ANDROID @@ -232,19 +269,18 @@ std::shared_ptr PKSamplerObj::BuildForMindDataset() { RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) : replacement_(replacement), num_samples_(num_samples) {} -bool RandomSamplerObj::ValidateParams() { +Status RandomSamplerObj::ValidateParams() { if (num_samples_ < 0) { - MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_; - return false; + RETURN_STATUS_UNEXPECTED("RandomSampler: invalid num_samples: " + std::to_string(num_samples_)); } - return true; + return Status::OK(); } std::shared_ptr RandomSamplerObj::Build() { // runtime sampler object bool reshuffle_each_epoch = true; auto sampler = std::make_shared(num_samples_, replacement_, reshuffle_each_epoch); - + BuildChildren(sampler); return sampler; } @@ -263,24 +299,22 @@ std::shared_ptr RandomSamplerObj::BuildForMindDataset SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) : start_index_(start_index), num_samples_(num_samples) {} -bool SequentialSamplerObj::ValidateParams() { +Status SequentialSamplerObj::ValidateParams() { if (num_samples_ < 0) { - MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_; - return false; + RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid num_samples: " + std::to_string(num_samples_)); } if (start_index_ < 0) { - MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_; - return false; + RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid start_index: " + std::to_string(start_index_)); } - return true; + return Status::OK(); } std::shared_ptr SequentialSamplerObj::Build() { // runtime sampler object auto sampler = std::make_shared(num_samples_, start_index_); - + BuildChildren(sampler); return sampler; } @@ -297,19 +331,18 @@ std::shared_ptr SequentialSamplerObj::BuildForMindDat SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector indices, int64_t num_samples) : indices_(std::move(indices)), num_samples_(num_samples) {} -bool SubsetRandomSamplerObj::ValidateParams() { +Status SubsetRandomSamplerObj::ValidateParams() { if (num_samples_ < 0) { - MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_; - return false; + RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: invalid num_samples: " + std::to_string(num_samples_)); } - return true; + return Status::OK(); } std::shared_ptr SubsetRandomSamplerObj::Build() { // runtime sampler object auto sampler = std::make_shared(num_samples_, indices_); - + BuildChildren(sampler); return sampler; } @@ -326,34 +359,32 @@ std::shared_ptr SubsetRandomSamplerObj::BuildForMindD WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector weights, int64_t num_samples, bool replacement) : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} -bool WeightedRandomSamplerObj::ValidateParams() { +Status WeightedRandomSamplerObj::ValidateParams() { if (weights_.empty()) { - MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty"; - return false; + RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty"); } int32_t zero_elem = 0; for (int32_t i = 0; i < weights_.size(); ++i) { if (weights_[i] < 0) { - MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i]; - return false; + RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " + + std::to_string(weights_[i])); } if (weights_[i] == 0.0) { zero_elem++; } } if (zero_elem == weights_.size()) { - MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero"; - return false; + RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero"); } if (num_samples_ < 0) { - MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_; - return false; + RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: invalid num_samples: " + std::to_string(num_samples_)); } - return true; + return Status::OK(); } std::shared_ptr WeightedRandomSamplerObj::Build() { auto sampler = std::make_shared(num_samples_, weights_, replacement_); + BuildChildren(sampler); return sampler; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 9251d79433e..4f5fb490cd6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -37,6 +37,9 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev, non_empty_(true) {} Status DistributedSamplerRT::InitSampler() { + if (is_initialized) { + return Status::OK(); + } // Special value of 0 for num_samples means that the user wants to sample the entire set of data. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. if (num_samples_ == 0 || num_samples_ > num_rows_) { @@ -72,6 +75,7 @@ Status DistributedSamplerRT::InitSampler() { } if (!samples_per_buffer_) non_empty_ = false; + is_initialized = true; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc index 046f19de1bc..373ccd55fc6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -28,6 +28,9 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_class_(val) {} Status PKSamplerRT::InitSampler() { + if (is_initialized) { + return Status::OK(); + } labels_.reserve(label_to_ids_.size()); for (const auto &pair : label_to_ids_) { if (!pair.second.empty()) { @@ -58,6 +61,7 @@ Status PKSamplerRT::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED( num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " + std::to_string(num_samples_)); + is_initialized = true; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc index 7637c840921..c4107e7bcd9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -65,6 +65,9 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { } Status PythonSamplerRT::InitSampler() { + if (is_initialized) { + return Status::OK(); + } CHECK_FAIL_RETURN_UNEXPECTED( num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_)); // Special value of 0 for num_samples means that the user wants to sample the entire set of data. @@ -83,6 +86,8 @@ Status PythonSamplerRT::InitSampler() { return Status(StatusCode::kPyFuncException, e.what()); } } + + is_initialized = true; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc index a525a3f9071..1157e41c439 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -69,6 +69,9 @@ Status RandomSamplerRT::GetNextSample(std::unique_ptr *out_buffer) { } Status RandomSamplerRT::InitSampler() { + if (is_initialized) { + return Status::OK(); + } // Special value of 0 for num_samples means that the user wants to sample the entire set of data. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. if (num_samples_ == 0 || num_samples_ > num_rows_) { @@ -91,6 +94,7 @@ Status RandomSamplerRT::InitSampler() { dist = std::make_unique>(0, num_rows_ - 1); } + is_initialized = true; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc index 06accbd5cca..a8a779aa032 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -34,7 +34,11 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { } SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer) - : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} + : num_rows_(0), + num_samples_(num_samples), + samples_per_buffer_(samples_per_buffer), + col_desc_(nullptr), + is_initialized(false) {} Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { std::shared_ptr child_sampler; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h index 89bdc30665b..a526aaa1444 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h @@ -160,6 +160,7 @@ class SamplerRT { // amount. int64_t num_samples_; + bool is_initialized; int64_t samples_per_buffer_; std::unique_ptr col_desc_; std::vector> child_; // Child nodes diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index 2752a36412e..f57e1f6b681 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -63,6 +63,9 @@ Status SequentialSamplerRT::GetNextSample(std::unique_ptr *out_buffe } Status SequentialSamplerRT::InitSampler() { + if (is_initialized) { + return Status::OK(); + } CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "Invalid parameter, start_index must be greater than or equal to 0, but got " + std::to_string(start_index_) + ".\n"); @@ -82,6 +85,8 @@ Status SequentialSamplerRT::InitSampler() { num_samples_ > 0 && samples_per_buffer_ > 0, "Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_)); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; + + is_initialized = true; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc index d8b15d4426f..1ac5dcbd3eb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -32,6 +32,9 @@ SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vec // Initialized this Sampler. Status SubsetRandomSamplerRT::InitSampler() { + if (is_initialized) { + return Status::OK(); + } CHECK_FAIL_RETURN_UNEXPECTED( num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n"); @@ -51,6 +54,7 @@ Status SubsetRandomSamplerRT::InitSampler() { // We will shuffle the full set of id's, but only select the first num_samples_ of them later. std::shuffle(indices_.begin(), indices_.end(), rand_gen_); + is_initialized = true; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index 2e9da033921..1c32bfe5480 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -37,6 +37,9 @@ WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std: // Initialized this Sampler. Status WeightedRandomSamplerRT::InitSampler() { + if (is_initialized) { + return Status::OK(); + } // Special value of 0 for num_samples means that the user wants to sample the entire set of data. // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. if (num_samples_ == 0 || num_samples_ > num_rows_) { @@ -75,6 +78,7 @@ Status WeightedRandomSamplerRT::InitSampler() { discrete_dist_ = std::make_unique>(weights_.begin(), weights_.end()); } + is_initialized = true; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index e3150fa7512..b52860d58d4 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -22,7 +22,10 @@ #include #ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_column.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_reader.h" #endif namespace mindspore { @@ -40,8 +43,8 @@ class SamplerObj : public std::enable_shared_from_this { ~SamplerObj() = default; /// \brief Pure virtual function for derived class to implement parameters validation - /// \return bool true if all the parameters are valid - virtual bool ValidateParams() = 0; + /// \return The Status code of the function. It returns OK status if parameters are valid. + virtual Status ValidateParams() = 0; /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object /// \return Shared pointers to the newly created Sampler @@ -55,12 +58,24 @@ class SamplerObj : public std::enable_shared_from_this { /// \return The shard id of the derived sampler virtual int64_t ShardId() { return 0; } + /// \brief Adds a child to the sampler + /// \param[in] child The sampler to be added as child + /// \return the Status code returned + Status AddChild(std::shared_ptr child); + #ifndef ENABLE_ANDROID /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, /// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler /// \return Shared pointers to the newly created Sampler virtual std::shared_ptr BuildForMindDataset() { return nullptr; } #endif + + protected: + /// \brief A function that calls build on the children of this sampler + /// \param[in] sampler The samplerRT object built from this sampler + void BuildChildren(std::shared_ptr sampler); + + std::vector> children_; }; class DistributedSamplerObj; @@ -137,15 +152,19 @@ class DistributedSamplerObj : public SamplerObj { std::shared_ptr Build() override; std::shared_ptr Copy() override { - return std::make_shared(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, - even_dist_); + auto sampler = std::make_shared(num_shards_, shard_id_, shuffle_, num_samples_, seed_, + offset_, even_dist_); + for (auto child : children_) { + sampler->AddChild(child); + } + return sampler; } #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif - bool ValidateParams() override; + Status ValidateParams() override; /// \brief Function to get the shard id of sampler /// \return The shard id of sampler @@ -170,14 +189,18 @@ class PKSamplerObj : public SamplerObj { std::shared_ptr Build() override; std::shared_ptr Copy() override { - return std::make_shared(num_val_, shuffle_, num_samples_); + auto sampler = std::make_shared(num_val_, shuffle_, num_samples_); + for (auto child : children_) { + sampler->AddChild(child); + } + return sampler; } #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif - bool ValidateParams() override; + Status ValidateParams() override; private: int64_t num_val_; @@ -202,7 +225,7 @@ class PreBuiltSamplerObj : public SamplerObj { std::shared_ptr Copy() override; - bool ValidateParams() override; + Status ValidateParams() override; private: std::shared_ptr sp_; @@ -219,13 +242,19 @@ class RandomSamplerObj : public SamplerObj { std::shared_ptr Build() override; - std::shared_ptr Copy() override { return std::make_shared(replacement_, num_samples_); } + std::shared_ptr Copy() override { + auto sampler = std::make_shared(replacement_, num_samples_); + for (auto child : children_) { + sampler->AddChild(child); + } + return sampler; + } #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif - bool ValidateParams() override; + Status ValidateParams() override; private: bool replacement_; @@ -241,14 +270,18 @@ class SequentialSamplerObj : public SamplerObj { std::shared_ptr Build() override; std::shared_ptr Copy() override { - return std::make_shared(start_index_, num_samples_); + auto sampler = std::make_shared(start_index_, num_samples_); + for (auto child : children_) { + sampler->AddChild(child); + } + return sampler; } #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif - bool ValidateParams() override; + Status ValidateParams() override; private: int64_t start_index_; @@ -264,14 +297,18 @@ class SubsetRandomSamplerObj : public SamplerObj { std::shared_ptr Build() override; std::shared_ptr Copy() override { - return std::make_shared(indices_, num_samples_); + auto sampler = std::make_shared(indices_, num_samples_); + for (auto child : children_) { + sampler->AddChild(child); + } + return sampler; } #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif - bool ValidateParams() override; + Status ValidateParams() override; private: const std::vector indices_; @@ -287,10 +324,14 @@ class WeightedRandomSamplerObj : public SamplerObj { std::shared_ptr Build() override; std::shared_ptr Copy() override { - return std::make_shared(weights_, num_samples_, replacement_); + auto sampler = std::make_shared(weights_, num_samples_, replacement_); + for (auto child : children_) { + sampler->AddChild(child); + } + return sampler; } - bool ValidateParams() override; + Status ValidateParams() override; private: const std::vector weights_; diff --git a/tests/ut/cpp/dataset/c_api_samplers_test.cc b/tests/ut/cpp/dataset/c_api_samplers_test.cc index 01a3092f22a..d771aba2325 100644 --- a/tests/ut/cpp/dataset/c_api_samplers_test.cc +++ b/tests/ut/cpp/dataset/c_api_samplers_test.cc @@ -208,6 +208,37 @@ TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestSamplerAddChild) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSamplerAddChild."; + + auto sampler = DistributedSampler(1, 0, false, 5, 0, -1, true); + EXPECT_NE(sampler, nullptr); + + auto child_sampler = SequentialSampler(); + sampler->AddChild(child_sampler); + EXPECT_NE(child_sampler, nullptr); + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, false, sampler); + EXPECT_NE(ds, nullptr); + + // Iterate the dataset and get each row + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + iter->GetNextRow(&row); + } + + EXPECT_EQ(ds->GetDatasetSize(), 5); + iter->Stop(); +} + TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail."; // Test invalid offset setting of distributed_sampler