!9687 Adding AddChild functionality to SamplerObj

From: @mahdirahmanihanzaki
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-06 01:46:08 +08:00 committed by Gitee
commit 5fd77356c9
12 changed files with 211 additions and 73 deletions

View File

@ -48,6 +48,34 @@ namespace dataset {
// Default constructor. Nothing is initialized explicitly here; children_ is a std::vector and
// therefore starts out empty.
SamplerObj::SamplerObj() {}
void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
for (auto child : children_) {
auto sampler_rt = child->Build();
sampler->AddChild(sampler_rt);
}
}
/// \brief Attaches a child sampler to this sampler.
/// \param[in] child The sampler to add as child. A null pointer is accepted and ignored.
/// \return Status::OK() when the child was added or was null; the error Status produced by
///     RETURN_STATUS_UNEXPECTED when this sampler already has a child.
Status SamplerObj::AddChild(std::shared_ptr<SamplerObj> child) {
  if (child == nullptr) {
    return Status::OK();
  }
  // The parameter type already guarantees that a non-null child is a SamplerObj, so no runtime type check
  // (dynamic_pointer_cast to the very same type) is needed here.

  // Samplers can have at most 1 child.
  if (!children_.empty()) {
    RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
  }
  // The by-value parameter is ours to consume: move it instead of copying the shared_ptr again.
  children_.push_back(std::move(child));
  return Status::OK();
}
/// Function to create a Distributed Sampler.
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
int64_t num_samples, uint32_t seed, int64_t offset,
@ -55,7 +83,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in
auto sampler =
std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
// Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr;
}
return sampler;
@ -65,7 +93,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in
std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) {
auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
// Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr;
}
return sampler;
@ -75,7 +103,7 @@ std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t n
std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) {
auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples);
// Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr;
}
return sampler;
@ -85,7 +113,7 @@ std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_sa
std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) {
auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples);
// Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr;
}
return sampler;
@ -95,7 +123,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples);
// Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr;
}
return sampler;
@ -106,7 +134,7 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub
bool replacement) {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement);
// Input validation
if (!sampler->ValidateParams()) {
if (sampler->ValidateParams().IsError()) {
return nullptr;
}
return sampler;
@ -131,35 +159,33 @@ DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_i
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
// Validates the DistributedSampler parameters.
// Diff view: the removed bool-returning lines (MS_LOG(ERROR) + return false/true) and the added
// Status-returning lines (RETURN_STATUS_UNEXPECTED / Status::OK()) are shown together below.
// Checks, in order:
//   - num_shards_ must be greater than 0
//   - shard_id_ must lie in [0, num_shards_)
//   - num_samples_ must not be negative
//   - offset_ must not exceed num_shards_
bool DistributedSamplerObj::ValidateParams() {
Status DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) {
MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_;
return false;
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_shards: " + std::to_string(num_shards_));
}
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
return false;
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid input, shard_id: " + std::to_string(shard_id_) +
", num_shards: " + std::to_string(num_shards_));
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_samples: " + std::to_string(num_samples_));
}
// NOTE(review): only the upper bound of offset_ is checked; negative offsets pass validation
// (the test in this commit passes -1) — confirm the intended lower bound.
if (offset_ > num_shards_) {
MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_
<< ", which should be no more than num_shards: " << num_shards_;
return false;
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid offset: " + std::to_string(offset_) +
", which should be no more than num_shards: " + std::to_string(num_shards_));
}
return true;
return Status::OK();
}
/// \brief Creates the runtime DistributedSamplerRT for this object and attaches the built children to it.
/// \return Shared pointer to the newly created runtime sampler.
std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
  std::shared_ptr<dataset::DistributedSamplerRT> runtime_sampler = std::make_shared<dataset::DistributedSamplerRT>(
    num_samples_, num_shards_, shard_id_, shuffle_, seed_, offset_, even_dist_);
  BuildChildren(runtime_sampler);
  return runtime_sampler;
}
@ -176,23 +202,21 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
// Stores the arguments as-is; no checking happens here (see PKSamplerObj::ValidateParams()).
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
// Validates the PKSampler parameters: num_val_ must be greater than 0 and num_samples_ must not be negative.
// Diff view: the removed bool-returning lines and the added Status-returning lines are shown together below.
bool PKSamplerObj::ValidateParams() {
Status PKSamplerObj::ValidateParams() {
if (num_val_ <= 0) {
MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_;
return false;
RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_val: " + std::to_string(num_val_));
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_samples: " + std::to_string(num_samples_));
}
return true;
return Status::OK();
}
/// \brief Creates the runtime PKSamplerRT for this object and attaches the built children to it.
/// \return Shared pointer to the newly created runtime sampler.
std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
  std::shared_ptr<dataset::PKSamplerRT> runtime_sampler =
    std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
  BuildChildren(runtime_sampler);
  return runtime_sampler;
}
@ -204,9 +228,12 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator
: sp_minddataset_(std::move(sampler)) {}
#endif
// A pre-built sampler has no parameters to check, so validation always succeeds.
// Diff view: the removed bool version and the added Status version are shown together below.
bool PreBuiltSamplerObj::ValidateParams() { return true; }
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }
// Returns the stored runtime sampler sp_. The added version first attaches the built children to it.
// Diff view: the removed one-line version and the added multi-line version are shown together below.
// NOTE(review): sp_ is handed to BuildChildren() without a null check; the mindrecord-based constructor
// only initializes sp_minddataset_, so sp_ looks like it can be null here — confirm.
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
BuildChildren(sp_);
return sp_;
}
#ifndef ENABLE_ANDROID
// Returns the stored mindrecord operator unchanged; children_ are not consulted here.
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
@ -214,9 +241,19 @@ std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDatas
// Creates a new PreBuiltSamplerObj around the same underlying object: the mindrecord operator when one is
// set (non-Android builds only), otherwise the runtime sampler sp_. The added version also re-attaches
// the children to the copy.
// Diff view: the removed single-statement returns and the added multi-line versions are shown together below.
// NOTE(review): children are shared with the copy (the same shared_ptr is added, not a Copy() of it), and the
// Status returned by AddChild() is ignored.
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
#ifndef ENABLE_ANDROID
if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
if (sp_minddataset_ != nullptr) {
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#endif
return std::make_shared<PreBuiltSamplerObj>(sp_);
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
@ -238,19 +275,18 @@ std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// Stores the arguments as-is; no checking happens here (see RandomSamplerObj::ValidateParams()).
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
: replacement_(replacement), num_samples_(num_samples) {}
// Validates the RandomSampler parameters: num_samples_ must not be negative.
// Diff view: the removed bool-returning lines and the added Status-returning lines are shown together below.
bool RandomSamplerObj::ValidateParams() {
Status RandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("RandomSampler: invalid num_samples: " + std::to_string(num_samples_));
}
return true;
return Status::OK();
}
/// \brief Creates the runtime RandomSamplerRT for this object and attaches the built children to it.
/// \return Shared pointer to the newly created runtime sampler.
std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
  // The runtime sampler is always created with reshuffling on each epoch enabled.
  const bool reshuffle_each_epoch = true;
  std::shared_ptr<dataset::RandomSamplerRT> runtime_sampler =
    std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
  BuildChildren(runtime_sampler);
  return runtime_sampler;
}
@ -269,24 +305,22 @@ std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset
// Stores the arguments as-is; no checking happens here (see SequentialSamplerObj::ValidateParams()).
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
: start_index_(start_index), num_samples_(num_samples) {}
// Validates the SequentialSampler parameters: neither num_samples_ nor start_index_ may be negative.
// Diff view: the removed bool-returning lines and the added Status-returning lines are shown together below.
bool SequentialSamplerObj::ValidateParams() {
Status SequentialSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid num_samples: " + std::to_string(num_samples_));
}
if (start_index_ < 0) {
MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_;
return false;
RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid start_index: " + std::to_string(start_index_));
}
return true;
return Status::OK();
}
/// \brief Creates the runtime SequentialSamplerRT for this object and attaches the built children to it.
/// \return Shared pointer to the newly created runtime sampler.
std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
  std::shared_ptr<dataset::SequentialSamplerRT> runtime_sampler =
    std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
  BuildChildren(runtime_sampler);
  return runtime_sampler;
}
@ -303,19 +337,18 @@ std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDat
// Takes ownership of the index list by moving it into indices_; no checking happens here
// (see SubsetRandomSamplerObj::ValidateParams()).
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {}
// Validates the SubsetRandomSampler parameters: num_samples_ must not be negative.
// The contents of indices_ are not inspected here.
// Diff view: the removed bool-returning lines and the added Status-returning lines are shown together below.
bool SubsetRandomSamplerObj::ValidateParams() {
Status SubsetRandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: invalid num_samples: " + std::to_string(num_samples_));
}
return true;
return Status::OK();
}
/// \brief Creates the runtime SubsetRandomSamplerRT for this object and attaches the built children to it.
/// \return Shared pointer to the newly created runtime sampler.
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
  std::shared_ptr<dataset::SubsetRandomSamplerRT> runtime_sampler =
    std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
  BuildChildren(runtime_sampler);
  return runtime_sampler;
}
@ -332,34 +365,32 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
// Takes ownership of the weight list by moving it into weights_; no checking happens here
// (see WeightedRandomSamplerObj::ValidateParams()).
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
// Validates the WeightedRandomSampler parameters.
// Diff view: the removed bool-returning lines and the added Status-returning lines are shown together below.
// Checks, in order:
//   - weights_ must not be empty
//   - no weight may be negative
//   - the weights must not all be zero
//   - num_samples_ must not be negative
bool WeightedRandomSamplerObj::ValidateParams() {
Status WeightedRandomSamplerObj::ValidateParams() {
if (weights_.empty()) {
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty";
return false;
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
}
// Counts how many weights are exactly zero, to detect an all-zero vector after the loop.
int32_t zero_elem = 0;
// NOTE(review): the int32_t index and counter are compared against weights_.size() (an unsigned type);
// this is a signed/unsigned comparison and would misbehave for vectors larger than INT32_MAX.
for (int32_t i = 0; i < weights_.size(); ++i) {
if (weights_[i] < 0) {
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i];
return false;
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
std::to_string(weights_[i]));
}
if (weights_[i] == 0.0) {
zero_elem++;
}
}
if (zero_elem == weights_.size()) {
MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero";
return false;
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
return false;
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: invalid num_samples: " + std::to_string(num_samples_));
}
return true;
return Status::OK();
}
/// \brief Creates the runtime WeightedRandomSamplerRT for this object and attaches the built children to it.
/// \return Shared pointer to the newly created runtime sampler.
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
  std::shared_ptr<dataset::WeightedRandomSamplerRT> runtime_sampler =
    std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
  BuildChildren(runtime_sampler);
  return runtime_sampler;
}

View File

@ -43,6 +43,9 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev,
}
Status DistributedSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@ -78,6 +81,7 @@ Status DistributedSamplerRT::InitSampler() {
}
if (!samples_per_buffer_) non_empty_ = false;
is_initialized = true;
return Status::OK();
}

View File

@ -28,6 +28,9 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t
samples_per_class_(val) {}
Status PKSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
labels_.reserve(label_to_ids_.size());
for (const auto &pair : label_to_ids_) {
if (!pair.second.empty()) {
@ -58,6 +61,7 @@ Status PKSamplerRT::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(
num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " +
std::to_string(num_samples_));
is_initialized = true;
return Status::OK();
}

View File

@ -65,6 +65,9 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status PythonSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_));
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
@ -83,6 +86,8 @@ Status PythonSamplerRT::InitSampler() {
return Status(StatusCode::kPyFuncException, e.what());
}
}
is_initialized = true;
return Status::OK();
}

View File

@ -69,6 +69,9 @@ Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status RandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@ -91,6 +94,7 @@ Status RandomSamplerRT::InitSampler() {
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
}
is_initialized = true;
return Status::OK();
}

View File

@ -34,7 +34,11 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
}
// Constructs a runtime sampler: num_rows_ starts at 0, col_desc_ at nullptr, and the two arguments are stored.
// The added version additionally starts with is_initialized set to false.
// Diff view: the removed single-line initializer list and the added multi-line one are shown together below.
// NOTE(review): in the class declaration is_initialized appears before samples_per_buffer_, but it is listed
// last here. Members are initialized in declaration order, so compilers may warn about the mismatch (-Wreorder).
SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer)
: num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
: num_rows_(0),
num_samples_(num_samples),
samples_per_buffer_(samples_per_buffer),
col_desc_(nullptr),
is_initialized(false) {}
Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<SamplerRT> child_sampler;

View File

@ -160,6 +160,7 @@ class SamplerRT {
// amount.
int64_t num_samples_;
bool is_initialized;
int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_;
std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes

View File

@ -63,6 +63,9 @@ Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
}
Status SequentialSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
"Invalid parameter, start_index must be greater than or equal to 0, but got " +
std::to_string(start_index_) + ".\n");
@ -82,6 +85,8 @@ Status SequentialSamplerRT::InitSampler() {
num_samples_ > 0 && samples_per_buffer_ > 0,
"Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_));
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
is_initialized = true;
return Status::OK();
}

View File

@ -32,6 +32,9 @@ SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vec
// Initializes this sampler.
Status SubsetRandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");
@ -51,6 +54,7 @@ Status SubsetRandomSamplerRT::InitSampler() {
// We will shuffle the full set of id's, but only select the first num_samples_ of them later.
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
is_initialized = true;
return Status::OK();
}

View File

@ -37,6 +37,9 @@ WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std:
// Initializes this sampler.
Status WeightedRandomSamplerRT::InitSampler() {
if (is_initialized) {
return Status::OK();
}
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
@ -75,6 +78,7 @@ Status WeightedRandomSamplerRT::InitSampler() {
discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
}
is_initialized = true;
return Status::OK();
}

View File

@ -22,7 +22,10 @@
#include <vector>
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_reader.h"
#endif
namespace mindspore {
@ -40,8 +43,8 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
~SamplerObj() = default;
/// \brief Pure virtual function for derived class to implement parameters validation
/// \return bool true if all the parameters are valid
virtual bool ValidateParams() = 0;
/// \return The Status code of the function. It returns OK status if parameters are valid.
virtual Status ValidateParams() = 0;
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
@ -55,12 +58,24 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }
/// \brief Adds a child to the sampler
/// \param[in] child The sampler to be added as child
/// \return the Status code returned
Status AddChild(std::shared_ptr<SamplerObj> child);
#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
#endif
protected:
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
std::vector<std::shared_ptr<SamplerObj>> children_;
};
class DistributedSamplerObj;
@ -137,15 +152,19 @@ class DistributedSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Creates a new DistributedSamplerObj with the same parameters; the added version also re-attaches
///     this sampler's children to the copy.
/// \note Diff view: the removed two-line return and the added multi-line version are shown together below.
/// \note NOTE(review): the children are shared with the copy (the same shared_ptr is added, not a Copy() of
///     it), and the Status returned by AddChild() is ignored.
/// \return Shared pointer to the new sampler object
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_,
even_dist_);
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
offset_, even_dist_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
/// \brief Function to get the shard id of sampler
/// \return The shard id of sampler
@ -170,14 +189,18 @@ class PKSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Creates a new PKSamplerObj with the same parameters; the added version also re-attaches
///     this sampler's children to the copy.
/// \note Diff view: the removed one-line return and the added multi-line version are shown together below.
/// \note NOTE(review): the children are shared with the copy (the same shared_ptr is added, not a Copy() of
///     it), and the Status returned by AddChild() is ignored.
/// \return Shared pointer to the new sampler object
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
private:
int64_t num_val_;
@ -202,7 +225,7 @@ class PreBuiltSamplerObj : public SamplerObj {
std::shared_ptr<SamplerObj> Copy() override;
bool ValidateParams() override;
Status ValidateParams() override;
private:
std::shared_ptr<SamplerRT> sp_;
@ -219,13 +242,19 @@ class RandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Creates a new RandomSamplerObj with the same parameters; the added version also re-attaches
///     this sampler's children to the copy.
/// \note Diff view: the removed one-line definition and the added multi-line version are shown together below.
/// \note NOTE(review): the children are shared with the copy (the same shared_ptr is added, not a Copy() of
///     it), and the Status returned by AddChild() is ignored.
/// \return Shared pointer to the new sampler object
std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); }
std::shared_ptr<SamplerObj> Copy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
private:
bool replacement_;
@ -241,14 +270,18 @@ class SequentialSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Creates a new SequentialSamplerObj with the same parameters; the added version also re-attaches
///     this sampler's children to the copy.
/// \note Diff view: the removed one-line return and the added multi-line version are shown together below.
/// \note NOTE(review): the children are shared with the copy (the same shared_ptr is added, not a Copy() of
///     it), and the Status returned by AddChild() is ignored.
/// \return Shared pointer to the new sampler object
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
private:
int64_t start_index_;
@ -264,14 +297,18 @@ class SubsetRandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Creates a new SubsetRandomSamplerObj with the same parameters (indices_ is copied); the added
///     version also re-attaches this sampler's children to the copy.
/// \note Diff view: the removed one-line return and the added multi-line version are shown together below.
/// \note NOTE(review): the children are shared with the copy (the same shared_ptr is added, not a Copy() of
///     it), and the Status returned by AddChild() is ignored.
/// \return Shared pointer to the new sampler object
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
Status ValidateParams() override;
private:
const std::vector<int64_t> indices_;
@ -287,10 +324,14 @@ class WeightedRandomSamplerObj : public SamplerObj {
std::shared_ptr<SamplerRT> Build() override;
/// \brief Creates a new WeightedRandomSamplerObj with the same parameters (weights_ is copied); the added
///     version also re-attaches this sampler's children to the copy.
/// \note Diff view: the removed one-line return and the added multi-line version are shown together below.
/// \note NOTE(review): the children are shared with the copy (the same shared_ptr is added, not a Copy() of
///     it), and the Status returned by AddChild() is ignored.
/// \return Shared pointer to the new sampler object
std::shared_ptr<SamplerObj> Copy() override {
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
for (auto child : children_) {
sampler->AddChild(child);
}
return sampler;
}
bool ValidateParams() override;
Status ValidateParams() override;
private:
const std::vector<double> weights_;

View File

@ -208,6 +208,37 @@ TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) {
iter->Stop();
}
// Checks that a child sampler can be attached to a sampler and that a pipeline using the parent sampler
// can still be created, iterated, and reports the expected dataset size.
TEST_F(MindDataTestPipeline, TestSamplerAddChild) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSamplerAddChild.";

  auto sampler = DistributedSampler(1, 0, false, 5, 0, -1, true);
  // ASSERT rather than EXPECT: the pointer is dereferenced right below.
  ASSERT_NE(sampler, nullptr);

  auto child_sampler = SequentialSampler();
  // Validate the child BEFORE handing it over: AddChild() silently accepts a null child, so checking
  // afterwards would let the test pass without any child having been attached.
  ASSERT_NE(child_sampler, nullptr);
  EXPECT_FALSE(sampler->AddChild(child_sampler).IsError());

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
  ASSERT_NE(ds, nullptr);

  // Iterate the dataset and get each row
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    iter->GetNextRow(&row);
  }
  // The row count is reported for diagnosis only; the checked contract is the dataset size below.
  MS_LOG(INFO) << "Number of rows iterated: " << i;

  EXPECT_EQ(ds->GetDatasetSize(), 5);
  iter->Stop();
}
TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail.";
// Test invalid offset setting of distributed_sampler