forked from mindspore-Ecosystem/mindspore
!12353 Move sampler IR code to engine/ir
From: @mhmotallebi Reviewed-by: Signed-off-by:
This commit is contained in:
commit
ced5575387
|
@ -97,6 +97,7 @@ add_dependencies(text-ir-kernels core)
|
|||
add_dependencies(cpp-API core)
|
||||
add_dependencies(engine-ir-datasetops core)
|
||||
add_dependencies(engine-ir-datasetops-source core)
|
||||
add_dependencies(engine-ir-datasetops-source-samplers core)
|
||||
add_dependencies(engine-ir-cache core)
|
||||
add_dependencies(kernels-ir core)
|
||||
add_dependencies(kernels-ir-data core)
|
||||
|
@ -135,6 +136,7 @@ set(submodules
|
|||
$<TARGET_OBJECTS:cpp-API>
|
||||
$<TARGET_OBJECTS:engine-ir-datasetops>
|
||||
$<TARGET_OBJECTS:engine-ir-datasetops-source>
|
||||
$<TARGET_OBJECTS:engine-ir-datasetops-source-samplers>
|
||||
$<TARGET_OBJECTS:engine-ir-cache>
|
||||
$<TARGET_OBJECTS:kernels-soft-dvpp-image>
|
||||
$<TARGET_OBJECTS:soft-dvpp-utils>
|
||||
|
|
|
@ -42,7 +42,7 @@
|
|||
#endif
|
||||
|
||||
// Sampler headers (in alphabetical order)
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
|
|
@ -15,69 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "minddata/dataset/include/samplers.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/mindrecord/include/shard_distributed_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_operator.h"
|
||||
#include "minddata/mindrecord/include/shard_pk_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_sequential_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
#define RETURN_NULL_IF_ERROR(_s) \
|
||||
do { \
|
||||
Status __rc = (_s); \
|
||||
if (__rc.IsError()) { \
|
||||
MS_LOG(ERROR) << __rc; \
|
||||
return nullptr; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
// Constructor
|
||||
SamplerObj::SamplerObj() {}
|
||||
|
||||
void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
|
||||
for (auto child : children_) {
|
||||
auto sampler_rt = child->SamplerBuild();
|
||||
sampler->AddChild(sampler_rt);
|
||||
}
|
||||
}
|
||||
|
||||
Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
|
||||
if (child == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Only samplers can be added, not any other DatasetOp.
|
||||
std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child);
|
||||
if (!sampler) {
|
||||
RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object.");
|
||||
}
|
||||
|
||||
// Samplers can have at most 1 child.
|
||||
if (!children_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
|
||||
}
|
||||
|
||||
children_.push_back(child);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Function to create a Distributed Sampler.
|
||||
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
|
||||
int64_t num_samples, uint32_t seed, int64_t offset,
|
||||
|
@ -152,421 +94,5 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub
|
|||
return sampler;
|
||||
}
|
||||
|
||||
/* ####################################### Derived Sampler classes ################################# */
|
||||
|
||||
// DistributedSampler
|
||||
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
|
||||
uint32_t seed, int64_t offset, bool even_dist)
|
||||
: num_shards_(num_shards),
|
||||
shard_id_(shard_id),
|
||||
shuffle_(shuffle),
|
||||
num_samples_(num_samples),
|
||||
seed_(seed),
|
||||
offset_(offset),
|
||||
even_dist_(even_dist) {
|
||||
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
|
||||
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
|
||||
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
|
||||
// PreBuildSampler is phased out, this can be cleaned up.
|
||||
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
|
||||
}
|
||||
|
||||
Status DistributedSamplerObj::ValidateParams() {
|
||||
if (num_shards_ <= 0) {
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " +
|
||||
std::to_string(num_shards_));
|
||||
}
|
||||
|
||||
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) +
|
||||
"), but got: " + std::to_string(shard_id_));
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
if (offset_ > num_shards_) {
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" +
|
||||
std::to_string(num_shards_) + "), but got: " + std::to_string(offset_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
|
||||
offset_, even_dist_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
|
||||
num_samples_, offset_);
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "DistributedSampler";
|
||||
args["num_shards"] = num_shards_;
|
||||
args["shard_id"] = shard_id_;
|
||||
args["shuffle"] = shuffle_;
|
||||
args["num_samples"] = num_samples_;
|
||||
args["offset"] = offset_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// PKSampler
|
||||
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
|
||||
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
|
||||
|
||||
Status PKSamplerObj::ValidateParams() {
|
||||
if (num_val_ <= 0) {
|
||||
RETURN_STATUS_UNEXPECTED("PKSampler: num_val must be greater than 0, but got: " + std::to_string(num_val_));
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("PKSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PKSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "PKSampler";
|
||||
args["num_val"] = num_val_;
|
||||
args["shuffle"] = shuffle_;
|
||||
args["num_samples"] = num_samples_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
|
||||
if (shuffle_ == true) {
|
||||
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
|
||||
GetSeed(), num_samples_);
|
||||
} else {
|
||||
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
|
||||
}
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// PreBuiltOperation
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
|
||||
: sp_minddataset_(std::move(sampler)) {}
|
||||
#endif
|
||||
|
||||
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
|
||||
BuildChildren(sp_);
|
||||
return sp_;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
|
||||
#ifndef ENABLE_ANDROID
|
||||
if (sp_minddataset_ != nullptr) {
|
||||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
#endif
|
||||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
RETURN_IF_NOT_OK(sp_->to_json(out_json));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// RandomSampler
|
||||
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
|
||||
: replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}
|
||||
|
||||
Status RandomSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "RandomSampler";
|
||||
args["replacement"] = replacement_;
|
||||
args["num_samples"] = num_samples_;
|
||||
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler =
|
||||
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// SequentialSampler
|
||||
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
|
||||
: start_index_(start_index), num_samples_(num_samples) {}
|
||||
|
||||
Status SequentialSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
if (start_index_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " +
|
||||
std::to_string(start_index_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "SequentialSampler";
|
||||
args["start_index"] = start_index_;
|
||||
args["num_samples"] = num_samples_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// SubsetSampler
|
||||
SubsetSamplerObj::SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
|
||||
: indices_(std::move(indices)), num_samples_(num_samples) {}
|
||||
|
||||
Status SubsetSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_);
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "SubsetSampler";
|
||||
args["indices"] = indices_;
|
||||
args["num_samples"] = num_samples_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// SubsetRandomSampler
|
||||
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
|
||||
: SubsetSamplerObj(std::move(indices), num_samples) {}
|
||||
|
||||
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "SubsetRandomSampler";
|
||||
args["indices"] = indices_;
|
||||
args["num_samples"] = num_samples_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// WeightedRandomSampler
|
||||
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
|
||||
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
|
||||
|
||||
Status WeightedRandomSamplerObj::ValidateParams() {
|
||||
if (weights_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
|
||||
}
|
||||
int32_t zero_elem = 0;
|
||||
for (int32_t i = 0; i < weights_.size(); ++i) {
|
||||
if (weights_[i] < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
|
||||
std::to_string(weights_[i]));
|
||||
}
|
||||
if (weights_[i] == 0.0) {
|
||||
zero_elem++;
|
||||
}
|
||||
}
|
||||
if (zero_elem == weights_.size()) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
|
||||
}
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "WeightedRandomSampler";
|
||||
args["weights"] = weights_;
|
||||
args["num_samples"] = num_samples_;
|
||||
args["replacement"] = replacement_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
|
||||
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_subdirectory(samplers)
|
||||
|
||||
set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
||||
album_node.cc
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
||||
set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES
|
||||
samplers_ir.cc
|
||||
)
|
||||
|
||||
add_library(engine-ir-datasetops-source-samplers OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES})
|
|
@ -0,0 +1,490 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/mindrecord/include/shard_distributed_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_operator.h"
|
||||
#include "minddata/mindrecord/include/shard_pk_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_sequential_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor
|
||||
SamplerObj::SamplerObj() {}
|
||||
|
||||
void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
|
||||
for (auto child : children_) {
|
||||
auto sampler_rt = child->SamplerBuild();
|
||||
sampler->AddChild(sampler_rt);
|
||||
}
|
||||
}
|
||||
|
||||
Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
|
||||
if (child == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Only samplers can be added, not any other DatasetOp.
|
||||
std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child);
|
||||
if (!sampler) {
|
||||
RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object.");
|
||||
}
|
||||
|
||||
// Samplers can have at most 1 child.
|
||||
if (!children_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
|
||||
}
|
||||
|
||||
children_.push_back(child);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* ####################################### Derived Sampler classes ################################# */
|
||||
|
||||
// DistributedSampler
|
||||
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
|
||||
uint32_t seed, int64_t offset, bool even_dist)
|
||||
: num_shards_(num_shards),
|
||||
shard_id_(shard_id),
|
||||
shuffle_(shuffle),
|
||||
num_samples_(num_samples),
|
||||
seed_(seed),
|
||||
offset_(offset),
|
||||
even_dist_(even_dist) {
|
||||
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
|
||||
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
|
||||
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
|
||||
// PreBuildSampler is phased out, this can be cleaned up.
|
||||
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
|
||||
}
|
||||
|
||||
Status DistributedSamplerObj::ValidateParams() {
|
||||
if (num_shards_ <= 0) {
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " +
|
||||
std::to_string(num_shards_));
|
||||
}
|
||||
|
||||
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) +
|
||||
"), but got: " + std::to_string(shard_id_));
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
if (offset_ > num_shards_) {
|
||||
RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" +
|
||||
std::to_string(num_shards_) + "), but got: " + std::to_string(offset_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
|
||||
offset_, even_dist_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
|
||||
num_samples_, offset_);
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "DistributedSampler";
|
||||
args["num_shards"] = num_shards_;
|
||||
args["shard_id"] = shard_id_;
|
||||
args["shuffle"] = shuffle_;
|
||||
args["num_samples"] = num_samples_;
|
||||
args["offset"] = offset_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// PKSampler
|
||||
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
|
||||
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
|
||||
|
||||
Status PKSamplerObj::ValidateParams() {
|
||||
if (num_val_ <= 0) {
|
||||
RETURN_STATUS_UNEXPECTED("PKSampler: num_val must be greater than 0, but got: " + std::to_string(num_val_));
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("PKSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PKSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "PKSampler";
|
||||
args["num_val"] = num_val_;
|
||||
args["shuffle"] = shuffle_;
|
||||
args["num_samples"] = num_samples_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
|
||||
if (shuffle_ == true) {
|
||||
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
|
||||
GetSeed(), num_samples_);
|
||||
} else {
|
||||
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
|
||||
}
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// PreBuiltOperation
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
|
||||
: sp_minddataset_(std::move(sampler)) {}
|
||||
#endif
|
||||
|
||||
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
|
||||
BuildChildren(sp_);
|
||||
return sp_;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
|
||||
#ifndef ENABLE_ANDROID
|
||||
if (sp_minddataset_ != nullptr) {
|
||||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
#endif
|
||||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
RETURN_IF_NOT_OK(sp_->to_json(out_json));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// RandomSampler
|
||||
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
|
||||
: replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}
|
||||
|
||||
Status RandomSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "RandomSampler";
|
||||
args["replacement"] = replacement_;
|
||||
args["num_samples"] = num_samples_;
|
||||
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler =
|
||||
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// SequentialSampler
|
||||
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
|
||||
: start_index_(start_index), num_samples_(num_samples) {}
|
||||
|
||||
Status SequentialSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
if (start_index_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " +
|
||||
std::to_string(start_index_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "SequentialSampler";
|
||||
args["start_index"] = start_index_;
|
||||
args["num_samples"] = num_samples_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// SubsetSampler
|
||||
SubsetSamplerObj::SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
|
||||
: indices_(std::move(indices)), num_samples_(num_samples) {}
|
||||
|
||||
Status SubsetSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_);
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "SubsetSampler";
|
||||
args["indices"] = indices_;
|
||||
args["num_samples"] = num_samples_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// SubsetRandomSampler
|
||||
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
|
||||
: SubsetSamplerObj(std::move(indices), num_samples) {}
|
||||
|
||||
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "SubsetRandomSampler";
|
||||
args["indices"] = indices_;
|
||||
args["num_samples"] = num_samples_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// WeightedRandomSampler
|
||||
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
|
||||
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
|
||||
|
||||
Status WeightedRandomSamplerObj::ValidateParams() {
|
||||
if (weights_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
|
||||
}
|
||||
int32_t zero_elem = 0;
|
||||
for (int32_t i = 0; i < weights_.size(); ++i) {
|
||||
if (weights_[i] < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
|
||||
std::to_string(weights_[i]));
|
||||
}
|
||||
if (weights_[i] == 0.0) {
|
||||
zero_elem++;
|
||||
}
|
||||
}
|
||||
if (zero_elem == weights_.size()) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
|
||||
}
|
||||
if (num_samples_ < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
|
||||
std::to_string(num_samples_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "WeightedRandomSampler";
|
||||
args["weights"] = weights_;
|
||||
args["num_samples"] = num_samples_;
|
||||
args["replacement"] = replacement_;
|
||||
if (!children_.empty()) {
|
||||
std::vector<nlohmann::json> children_args;
|
||||
for (auto child : children_) {
|
||||
nlohmann::json child_arg;
|
||||
RETURN_IF_NOT_OK(child->to_json(&child_arg));
|
||||
children_args.push_back(child_arg);
|
||||
}
|
||||
args["child_sampler"] = children_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
|
||||
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,344 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
|
||||
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/mindrecord/include/shard_operator.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Internal Sampler class forward declaration
|
||||
class SamplerRT;
|
||||
|
||||
class SamplerObj {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
SamplerObj();
|
||||
|
||||
/// \brief Destructor
|
||||
~SamplerObj() = default;
|
||||
|
||||
/// \brief Pure virtual function for derived class to implement parameters validation
|
||||
/// \return The Status code of the function. It returns OK status if parameters are valid.
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;
|
||||
|
||||
/// \brief Pure virtual function to copy a SamplerObj class
|
||||
/// \return Shared pointers to the newly copied SamplerObj
|
||||
virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;
|
||||
|
||||
/// \brief Function for derived class to get the shard id of sampler
|
||||
/// \return The shard id of the derived sampler
|
||||
virtual int64_t ShardId() { return 0; }
|
||||
|
||||
/// \brief Adds a child to the sampler
|
||||
/// \param[in] child The sampler to be added as child
|
||||
/// \return the Status code returned
|
||||
Status AddChildSampler(std::shared_ptr<SamplerObj> child);
|
||||
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
|
||||
std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
|
||||
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
|
||||
#endif
|
||||
|
||||
protected:
|
||||
/// \brief A function that calls build on the children of this sampler
|
||||
/// \param[in] sampler The samplerRT object built from this sampler
|
||||
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
std::vector<std::shared_ptr<SamplerObj>> children_;
|
||||
};
|
||||
|
||||
/* ####################################### Derived Sampler classes ################################# */
|
||||
class DistributedSamplerObj : public SamplerObj {
|
||||
public:
|
||||
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
|
||||
int64_t offset, bool even_dist);
|
||||
|
||||
~DistributedSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
|
||||
offset_, even_dist_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Function to get the shard id of sampler
|
||||
/// \return The shard id of sampler
|
||||
int64_t ShardId() override { return shard_id_; }
|
||||
|
||||
private:
|
||||
int64_t num_shards_;
|
||||
int64_t shard_id_;
|
||||
bool shuffle_;
|
||||
int64_t num_samples_;
|
||||
uint32_t seed_;
|
||||
int64_t offset_;
|
||||
bool even_dist_;
|
||||
};
|
||||
|
||||
class PKSamplerObj : public SamplerObj {
|
||||
public:
|
||||
PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples);
|
||||
|
||||
~PKSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int64_t num_val_;
|
||||
bool shuffle_;
|
||||
int64_t num_samples_;
|
||||
};
|
||||
|
||||
class PreBuiltSamplerObj : public SamplerObj {
|
||||
public:
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
|
||||
#ifndef ENABLE_ANDROID
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
|
||||
#endif
|
||||
|
||||
~PreBuiltSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerRT> sp_;
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
|
||||
#endif
|
||||
};
|
||||
|
||||
class RandomSamplerObj : public SamplerObj {
|
||||
public:
|
||||
RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);
|
||||
|
||||
~RandomSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
bool replacement_;
|
||||
int64_t num_samples_;
|
||||
bool reshuffle_each_epoch_;
|
||||
};
|
||||
|
||||
class SequentialSamplerObj : public SamplerObj {
|
||||
public:
|
||||
SequentialSamplerObj(int64_t start_index, int64_t num_samples);
|
||||
|
||||
~SequentialSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int64_t start_index_;
|
||||
int64_t num_samples_;
|
||||
};
|
||||
|
||||
class SubsetSamplerObj : public SamplerObj {
|
||||
public:
|
||||
SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
|
||||
|
||||
~SubsetSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
protected:
|
||||
const std::vector<int64_t> indices_;
|
||||
int64_t num_samples_;
|
||||
};
|
||||
|
||||
class SubsetRandomSamplerObj : public SubsetSamplerObj {
|
||||
public:
|
||||
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
|
||||
|
||||
~SubsetRandomSamplerObj() = default;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class WeightedRandomSamplerObj : public SamplerObj {
|
||||
public:
|
||||
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
|
||||
|
||||
~WeightedRandomSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
const std::vector<double> weights_;
|
||||
int64_t num_samples_;
|
||||
bool replacement_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
|
|
@ -16,6 +16,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_SERDES_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_SERDES_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
|
@ -18,72 +18,14 @@
|
|||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/mindrecord/include/shard_column.h"
|
||||
#include "minddata/mindrecord/include/shard_error.h"
|
||||
#include "minddata/mindrecord/include/shard_operator.h"
|
||||
#include "minddata/mindrecord/include/shard_reader.h"
|
||||
#endif
|
||||
// FIXME - This internal IR header will be removed when external API classes are provided
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Internal Sampler class forward declaration
|
||||
class SamplerRT;
|
||||
|
||||
class SamplerObj {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
SamplerObj();
|
||||
|
||||
/// \brief Destructor
|
||||
~SamplerObj() = default;
|
||||
|
||||
/// \brief Pure virtual function for derived class to implement parameters validation
|
||||
/// \return The Status code of the function. It returns OK status if parameters are valid.
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;
|
||||
|
||||
/// \brief Pure virtual function to copy a SamplerObj class
|
||||
/// \return Shared pointers to the newly copied SamplerObj
|
||||
virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;
|
||||
|
||||
/// \brief Function for derived class to get the shard id of sampler
|
||||
/// \return The shard id of the derived sampler
|
||||
virtual int64_t ShardId() { return 0; }
|
||||
|
||||
/// \brief Adds a child to the sampler
|
||||
/// \param[in] child The sampler to be added as child
|
||||
/// \return the Status code returned
|
||||
Status AddChildSampler(std::shared_ptr<SamplerObj> child);
|
||||
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
|
||||
std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
|
||||
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
|
||||
#endif
|
||||
|
||||
protected:
|
||||
/// \brief A function that calls build on the children of this sampler
|
||||
/// \param[in] sampler The samplerRT object built from this sampler
|
||||
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
std::vector<std::shared_ptr<SamplerObj>> children_;
|
||||
};
|
||||
|
||||
class DistributedSamplerObj;
|
||||
class PKSamplerObj;
|
||||
class PreBuiltSamplerObj;
|
||||
|
@ -155,261 +97,6 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t>
|
|||
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
|
||||
bool replacement = true);
|
||||
|
||||
/* ####################################### Derived Sampler classes ################################# */
|
||||
class DistributedSamplerObj : public SamplerObj {
|
||||
public:
|
||||
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
|
||||
int64_t offset, bool even_dist);
|
||||
|
||||
~DistributedSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
|
||||
offset_, even_dist_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Function to get the shard id of sampler
|
||||
/// \return The shard id of sampler
|
||||
int64_t ShardId() override { return shard_id_; }
|
||||
|
||||
private:
|
||||
int64_t num_shards_;
|
||||
int64_t shard_id_;
|
||||
bool shuffle_;
|
||||
int64_t num_samples_;
|
||||
uint32_t seed_;
|
||||
int64_t offset_;
|
||||
bool even_dist_;
|
||||
};
|
||||
|
||||
class PKSamplerObj : public SamplerObj {
|
||||
public:
|
||||
PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples);
|
||||
|
||||
~PKSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int64_t num_val_;
|
||||
bool shuffle_;
|
||||
int64_t num_samples_;
|
||||
};
|
||||
|
||||
class PreBuiltSamplerObj : public SamplerObj {
|
||||
public:
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
|
||||
#ifndef ENABLE_ANDROID
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
|
||||
#endif
|
||||
|
||||
~PreBuiltSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerRT> sp_;
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
|
||||
#endif
|
||||
};
|
||||
|
||||
class RandomSamplerObj : public SamplerObj {
|
||||
public:
|
||||
RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);
|
||||
|
||||
~RandomSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
bool replacement_;
|
||||
int64_t num_samples_;
|
||||
bool reshuffle_each_epoch_;
|
||||
};
|
||||
|
||||
class SequentialSamplerObj : public SamplerObj {
|
||||
public:
|
||||
SequentialSamplerObj(int64_t start_index, int64_t num_samples);
|
||||
|
||||
~SequentialSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int64_t start_index_;
|
||||
int64_t num_samples_;
|
||||
};
|
||||
|
||||
class SubsetSamplerObj : public SamplerObj {
|
||||
public:
|
||||
SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
|
||||
|
||||
~SubsetSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
protected:
|
||||
const std::vector<int64_t> indices_;
|
||||
int64_t num_samples_;
|
||||
};
|
||||
|
||||
class SubsetRandomSamplerObj : public SubsetSamplerObj {
|
||||
public:
|
||||
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
|
||||
|
||||
~SubsetRandomSamplerObj() = default;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class WeightedRandomSamplerObj : public SamplerObj {
|
||||
public:
|
||||
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
|
||||
|
||||
~WeightedRandomSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
const std::vector<double> weights_;
|
||||
int64_t num_samples_;
|
||||
bool replacement_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
|
|
|
@ -136,6 +136,7 @@ if(BUILD_MINDDATA STREQUAL "full")
|
|||
${MINDDATA_DIR}/engine/ir/datasetops/shuffle_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/album_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/mnist_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/samplers_ir.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/dataset_op.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/repeat_op.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/epoch_ctrl_op.cc
|
||||
|
|
Loading…
Reference in New Issue