forked from mindspore-Ecosystem/mindspore
!9687 Adding AddChild functionality to SamplerObj
From: @mahdirahmanihanzaki Reviewed-by: Signed-off-by:
This commit is contained in:
commit
5fd77356c9
|
@ -48,6 +48,34 @@ namespace dataset {
|
|||
// Constructor
|
||||
SamplerObj::SamplerObj() {}
|
||||
|
||||
void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
|
||||
for (auto child : children_) {
|
||||
auto sampler_rt = child->Build();
|
||||
sampler->AddChild(sampler_rt);
|
||||
}
|
||||
}
|
||||
|
||||
Status SamplerObj::AddChild(std::shared_ptr<SamplerObj> child) {
|
||||
if (child == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Only samplers can be added, not any other DatasetOp.
|
||||
std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child);
|
||||
if (!sampler) {
|
||||
RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object.");
|
||||
}
|
||||
|
||||
// Samplers can have at most 1 child.
|
||||
if (!children_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
|
||||
}
|
||||
|
||||
children_.push_back(child);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Function to create a Distributed Sampler.
|
||||
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
|
||||
int64_t num_samples, uint32_t seed, int64_t offset,
|
||||
|
@ -55,7 +83,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in
|
|||
auto sampler =
|
||||
std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
if (sampler->ValidateParams().IsError()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
|
@ -65,7 +93,7 @@ std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, in
|
|||
std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) {
|
||||
auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
if (sampler->ValidateParams().IsError()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
|
@ -75,7 +103,7 @@ std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t n
|
|||
std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) {
|
||||
auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
if (sampler->ValidateParams().IsError()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
|
@ -85,7 +113,7 @@ std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_sa
|
|||
std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) {
|
||||
auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
if (sampler->ValidateParams().IsError()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
|
@ -95,7 +123,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int
|
|||
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) {
|
||||
auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
if (sampler->ValidateParams().IsError()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
|
@ -106,7 +134,7 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub
|
|||
bool replacement) {
|
||||
auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
if (sampler->ValidateParams().IsError()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
|
@ -131,35 +159,33 @@ DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_i
|
|||
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
|
||||
}
|
||||
|
||||
bool DistributedSamplerObj::ValidateParams() {
|
||||
Status DistributedSamplerObj::ValidateParams() {
|
||||
if (num_shards_ <= 0) {
|
||||
MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_shards: " + std::to_string(num_shards_));
|
||||
}
|
||||
|
||||
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
|
||||
MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid input, shard_id: " + std::to_string(shard_id_) +
|
||||
", num_shards: " + std::to_string(num_shards_));
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid num_samples: " + std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
if (offset_ > num_shards_) {
|
||||
MS_LOG(ERROR) << "DistributedSampler: invalid offset: " << offset_
|
||||
<< ", which should be no more than num_shards: " << num_shards_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: invalid offset: " + std::to_string(offset_) +
|
||||
", which should be no more than num_shards: " + std::to_string(num_shards_));
|
||||
}
|
||||
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
|
||||
offset_, even_dist_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
@ -176,23 +202,21 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa
|
|||
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
|
||||
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
|
||||
|
||||
bool PKSamplerObj::ValidateParams() {
|
||||
Status PKSamplerObj::ValidateParams() {
|
||||
if (num_val_ <= 0) {
|
||||
MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_val: " + std::to_string(num_val_));
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("PKSampler: invalid num_samples: " + std::to_string(num_samples_));
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
|
||||
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
@ -204,9 +228,12 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator
|
|||
: sp_minddataset_(std::move(sampler)) {}
|
||||
#endif
|
||||
|
||||
bool PreBuiltSamplerObj::ValidateParams() { return true; }
|
||||
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
|
||||
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
|
||||
BuildChildren(sp_);
|
||||
return sp_;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
|
||||
|
@ -214,9 +241,19 @@ std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDatas
|
|||
|
||||
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
|
||||
#ifndef ENABLE_ANDROID
|
||||
if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
|
||||
if (sp_minddataset_ != nullptr) {
|
||||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
#endif
|
||||
return std::make_shared<PreBuiltSamplerObj>(sp_);
|
||||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -238,19 +275,18 @@ std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
|
|||
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
|
||||
: replacement_(replacement), num_samples_(num_samples) {}
|
||||
|
||||
bool RandomSamplerObj::ValidateParams() {
|
||||
Status RandomSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("RandomSampler: invalid num_samples: " + std::to_string(num_samples_));
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
|
||||
// runtime sampler object
|
||||
bool reshuffle_each_epoch = true;
|
||||
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
|
||||
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
@ -269,24 +305,22 @@ std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset
|
|||
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
|
||||
: start_index_(start_index), num_samples_(num_samples) {}
|
||||
|
||||
bool SequentialSamplerObj::ValidateParams() {
|
||||
Status SequentialSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid num_samples: " + std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
if (start_index_ < 0) {
|
||||
MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("SequentialSampler: invalid start_index: " + std::to_string(start_index_));
|
||||
}
|
||||
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
|
||||
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
@ -303,19 +337,18 @@ std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDat
|
|||
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
|
||||
: indices_(std::move(indices)), num_samples_(num_samples) {}
|
||||
|
||||
bool SubsetRandomSamplerObj::ValidateParams() {
|
||||
Status SubsetRandomSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: invalid num_samples: " + std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
|
||||
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
@ -332,34 +365,32 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD
|
|||
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
|
||||
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
|
||||
|
||||
bool WeightedRandomSamplerObj::ValidateParams() {
|
||||
Status WeightedRandomSamplerObj::ValidateParams() {
|
||||
if (weights_.empty()) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not be empty";
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
|
||||
}
|
||||
int32_t zero_elem = 0;
|
||||
for (int32_t i = 0; i < weights_.size(); ++i) {
|
||||
if (weights_[i] < 0) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: weights vector must not contain negative number, got: " << weights_[i];
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
|
||||
std::to_string(weights_[i]));
|
||||
}
|
||||
if (weights_[i] == 0.0) {
|
||||
zero_elem++;
|
||||
}
|
||||
}
|
||||
if (zero_elem == weights_.size()) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: elements of weights vector must not be all zero";
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
|
||||
}
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: invalid num_samples: " + std::to_string(num_samples_));
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
|
||||
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
|
|
@ -43,6 +43,9 @@ DistributedSamplerRT::DistributedSamplerRT(int64_t num_samples, int64_t num_dev,
|
|||
}
|
||||
|
||||
Status DistributedSamplerRT::InitSampler() {
|
||||
if (is_initialized) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
|
||||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
|
@ -78,6 +81,7 @@ Status DistributedSamplerRT::InitSampler() {
|
|||
}
|
||||
if (!samples_per_buffer_) non_empty_ = false;
|
||||
|
||||
is_initialized = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,9 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t
|
|||
samples_per_class_(val) {}
|
||||
|
||||
Status PKSamplerRT::InitSampler() {
|
||||
if (is_initialized) {
|
||||
return Status::OK();
|
||||
}
|
||||
labels_.reserve(label_to_ids_.size());
|
||||
for (const auto &pair : label_to_ids_) {
|
||||
if (!pair.second.empty()) {
|
||||
|
@ -58,6 +61,7 @@ Status PKSamplerRT::InitSampler() {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_samples_ > 0, "Invalid parameter, num_class or K (num samples per class) must be greater than 0, but got " +
|
||||
std::to_string(num_samples_));
|
||||
is_initialized = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -65,6 +65,9 @@ Status PythonSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
}
|
||||
|
||||
Status PythonSamplerRT::InitSampler() {
|
||||
if (is_initialized) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_));
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
|
@ -83,6 +86,8 @@ Status PythonSamplerRT::InitSampler() {
|
|||
return Status(StatusCode::kPyFuncException, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -69,6 +69,9 @@ Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
}
|
||||
|
||||
Status RandomSamplerRT::InitSampler() {
|
||||
if (is_initialized) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
|
||||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
|
@ -91,6 +94,7 @@ Status RandomSamplerRT::InitSampler() {
|
|||
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
|
||||
}
|
||||
|
||||
is_initialized = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,11 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
|
|||
}
|
||||
|
||||
SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_buffer)
|
||||
: num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
|
||||
: num_rows_(0),
|
||||
num_samples_(num_samples),
|
||||
samples_per_buffer_(samples_per_buffer),
|
||||
col_desc_(nullptr),
|
||||
is_initialized(false) {}
|
||||
|
||||
Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
||||
std::shared_ptr<SamplerRT> child_sampler;
|
||||
|
|
|
@ -160,6 +160,7 @@ class SamplerRT {
|
|||
// amount.
|
||||
int64_t num_samples_;
|
||||
|
||||
bool is_initialized;
|
||||
int64_t samples_per_buffer_;
|
||||
std::unique_ptr<ColDescriptor> col_desc_;
|
||||
std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes
|
||||
|
|
|
@ -63,6 +63,9 @@ Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
|
|||
}
|
||||
|
||||
Status SequentialSamplerRT::InitSampler() {
|
||||
if (is_initialized) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
|
||||
"Invalid parameter, start_index must be greater than or equal to 0, but got " +
|
||||
std::to_string(start_index_) + ".\n");
|
||||
|
@ -82,6 +85,8 @@ Status SequentialSamplerRT::InitSampler() {
|
|||
num_samples_ > 0 && samples_per_buffer_ > 0,
|
||||
"Invalid parameter, samples_per_buffer must be greater than 0, but got " + std::to_string(samples_per_buffer_));
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
|
||||
is_initialized = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -32,6 +32,9 @@ SubsetRandomSamplerRT::SubsetRandomSamplerRT(int64_t num_samples, const std::vec
|
|||
|
||||
// Initialized this Sampler.
|
||||
Status SubsetRandomSamplerRT::InitSampler() {
|
||||
if (is_initialized) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_rows_ > 0, "Invalid parameter, num_rows must be greater than 0, but got " + std::to_string(num_rows_) + ".\n");
|
||||
|
||||
|
@ -51,6 +54,7 @@ Status SubsetRandomSamplerRT::InitSampler() {
|
|||
// We will shuffle the full set of id's, but only select the first num_samples_ of them later.
|
||||
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
|
||||
|
||||
is_initialized = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -37,6 +37,9 @@ WeightedRandomSamplerRT::WeightedRandomSamplerRT(int64_t num_samples, const std:
|
|||
|
||||
// Initialized this Sampler.
|
||||
Status WeightedRandomSamplerRT::InitSampler() {
|
||||
if (is_initialized) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
|
||||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
|
@ -75,6 +78,7 @@ Status WeightedRandomSamplerRT::InitSampler() {
|
|||
discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
|
||||
}
|
||||
|
||||
is_initialized = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,10 @@
|
|||
#include <vector>
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/mindrecord/include/shard_column.h"
|
||||
#include "minddata/mindrecord/include/shard_error.h"
|
||||
#include "minddata/mindrecord/include/shard_reader.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -40,8 +43,8 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
~SamplerObj() = default;
|
||||
|
||||
/// \brief Pure virtual function for derived class to implement parameters validation
|
||||
/// \return bool true if all the parameters are valid
|
||||
virtual bool ValidateParams() = 0;
|
||||
/// \return The Status code of the function. It returns OK status if parameters are valid.
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
|
@ -55,12 +58,24 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
/// \return The shard id of the derived sampler
|
||||
virtual int64_t ShardId() { return 0; }
|
||||
|
||||
/// \brief Adds a child to the sampler
|
||||
/// \param[in] child The sampler to be added as child
|
||||
/// \return the Status code returned
|
||||
Status AddChild(std::shared_ptr<SamplerObj> child);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
|
||||
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
|
||||
#endif
|
||||
|
||||
protected:
|
||||
/// \brief A function that calls build on the children of this sampler
|
||||
/// \param[in] sampler The samplerRT object built from this sampler
|
||||
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
std::vector<std::shared_ptr<SamplerObj>> children_;
|
||||
};
|
||||
|
||||
class DistributedSamplerObj;
|
||||
|
@ -137,15 +152,19 @@ class DistributedSamplerObj : public SamplerObj {
|
|||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_,
|
||||
even_dist_);
|
||||
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
|
||||
offset_, even_dist_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Function to get the shard id of sampler
|
||||
/// \return The shard id of sampler
|
||||
|
@ -170,14 +189,18 @@ class PKSamplerObj : public SamplerObj {
|
|||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
||||
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int64_t num_val_;
|
||||
|
@ -202,7 +225,7 @@ class PreBuiltSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<SamplerObj> Copy() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerRT> sp_;
|
||||
|
@ -219,13 +242,19 @@ class RandomSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); }
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
bool replacement_;
|
||||
|
@ -241,14 +270,18 @@ class SequentialSamplerObj : public SamplerObj {
|
|||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
||||
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int64_t start_index_;
|
||||
|
@ -264,14 +297,18 @@ class SubsetRandomSamplerObj : public SamplerObj {
|
|||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
||||
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
const std::vector<int64_t> indices_;
|
||||
|
@ -287,10 +324,14 @@ class WeightedRandomSamplerObj : public SamplerObj {
|
|||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
||||
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
const std::vector<double> weights_;
|
||||
|
|
|
@ -208,6 +208,37 @@ TEST_F(MindDataTestPipeline, TestDistributedSamplerSuccess) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestSamplerAddChild) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSamplerAddChild.";
|
||||
|
||||
auto sampler = DistributedSampler(1, 0, false, 5, 0, -1, true);
|
||||
EXPECT_NE(sampler, nullptr);
|
||||
|
||||
auto child_sampler = SequentialSampler();
|
||||
sampler->AddChild(child_sampler);
|
||||
EXPECT_NE(child_sampler, nullptr);
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 5);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDistributedSamplerFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDistributedSamplerFail.";
|
||||
// Test invalid offset setting of distributed_sampler
|
||||
|
|
Loading…
Reference in New Issue