!12353 Move sampler IR code to engine/ir

From: @mhmotallebi
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-13 06:44:47 +08:00 committed by Gitee
commit ced5575387
11 changed files with 852 additions and 792 deletions

View File

@ -97,6 +97,7 @@ add_dependencies(text-ir-kernels core)
add_dependencies(cpp-API core)
add_dependencies(engine-ir-datasetops core)
add_dependencies(engine-ir-datasetops-source core)
add_dependencies(engine-ir-datasetops-source-samplers core)
add_dependencies(engine-ir-cache core)
add_dependencies(kernels-ir core)
add_dependencies(kernels-ir-data core)
@ -135,6 +136,7 @@ set(submodules
$<TARGET_OBJECTS:cpp-API>
$<TARGET_OBJECTS:engine-ir-datasetops>
$<TARGET_OBJECTS:engine-ir-datasetops-source>
$<TARGET_OBJECTS:engine-ir-datasetops-source-samplers>
$<TARGET_OBJECTS:engine-ir-cache>
$<TARGET_OBJECTS:kernels-soft-dvpp-image>
$<TARGET_OBJECTS:soft-dvpp-utils>

View File

@ -42,7 +42,7 @@
#endif
// Sampler headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

View File

@ -23,7 +23,7 @@
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
namespace mindspore {
namespace dataset {

View File

@ -15,69 +15,11 @@
*/
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "minddata/dataset/util/random.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
namespace mindspore {
namespace dataset {
#define RETURN_NULL_IF_ERROR(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return nullptr; \
} \
} while (false)
// Constructor
SamplerObj::SamplerObj() {}
void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
for (auto child : children_) {
auto sampler_rt = child->SamplerBuild();
sampler->AddChild(sampler_rt);
}
}
Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
if (child == nullptr) {
return Status::OK();
}
// Only samplers can be added, not any other DatasetOp.
std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child);
if (!sampler) {
RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object.");
}
// Samplers can have at most 1 child.
if (!children_.empty()) {
RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
}
children_.push_back(child);
return Status::OK();
}
/// Function to create a Distributed Sampler.
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
int64_t num_samples, uint32_t seed, int64_t offset,
@ -152,421 +94,5 @@ std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<doub
return sampler;
}
/* ####################################### Derived Sampler classes ################################# */
// DistributedSampler
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed, int64_t offset, bool even_dist)
: num_shards_(num_shards),
shard_id_(shard_id),
shuffle_(shuffle),
num_samples_(num_samples),
seed_(seed),
offset_(offset),
even_dist_(even_dist) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
Status DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " +
std::to_string(num_shards_));
}
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) +
"), but got: " + std::to_string(shard_id_));
}
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
if (offset_ > num_shards_) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" +
std::to_string(num_shards_) + "), but got: " + std::to_string(offset_));
}
return Status::OK();
}
std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
num_samples_, offset_);
return mind_sampler;
}
#endif
Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
args["offset"] = offset_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
Status PKSamplerObj::ValidateParams() {
if (num_val_ <= 0) {
RETURN_STATUS_UNEXPECTED("PKSampler: num_val must be greater than 0, but got: " + std::to_string(num_val_));
}
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("PKSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}
Status PKSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "PKSampler";
args["num_val"] = num_val_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}
return mind_sampler;
}
#endif
// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
#ifndef ENABLE_ANDROID
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_minddataset_(std::move(sampler)) {}
#endif
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
BuildChildren(sp_);
return sp_;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
#ifndef ENABLE_ANDROID
if (sp_minddataset_ != nullptr) {
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#endif
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
RETURN_IF_NOT_OK(sp_->to_json(out_json));
return Status::OK();
}
// RandomSampler
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
: replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}
Status RandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}
Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_;
args["num_samples"] = num_samples_;
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler =
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
return mind_sampler;
}
#endif
// SequentialSampler
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
: start_index_(start_index), num_samples_(num_samples) {}
Status SequentialSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
if (start_index_ < 0) {
RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " +
std::to_string(start_index_));
}
return Status::OK();
}
Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
return mind_sampler;
}
#endif
// SubsetSampler
SubsetSamplerObj::SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {}
Status SubsetSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}
std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_);
return mind_sampler;
}
#endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: SubsetSamplerObj(std::move(indices), num_samples) {}
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
return mind_sampler;
}
#endif
Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetRandomSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
Status WeightedRandomSamplerObj::ValidateParams() {
if (weights_.empty()) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
}
int32_t zero_elem = 0;
for (int32_t i = 0; i < weights_.size(); ++i) {
if (weights_[i] < 0) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
std::to_string(weights_[i]));
}
if (weights_[i] == 0.0) {
zero_elem++;
}
}
if (zero_elem == weights_.size()) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
}
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}
Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_;
args["num_samples"] = num_samples_;
args["replacement"] = replacement_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler);
return sampler;
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,5 +1,6 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(samplers)
set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
album_node.cc

View File

@ -0,0 +1,8 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES
samplers_ir.cc
)
add_library(engine-ir-datasetops-source-samplers OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES})

View File

@ -0,0 +1,490 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#endif
namespace mindspore {
namespace dataset {
// Constructor
SamplerObj::SamplerObj() {}
void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
for (auto child : children_) {
auto sampler_rt = child->SamplerBuild();
sampler->AddChild(sampler_rt);
}
}
Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
if (child == nullptr) {
return Status::OK();
}
// Only samplers can be added, not any other DatasetOp.
std::shared_ptr<SamplerObj> sampler = std::dynamic_pointer_cast<SamplerObj>(child);
if (!sampler) {
RETURN_STATUS_UNEXPECTED("Cannot add child, child is not a sampler object.");
}
// Samplers can have at most 1 child.
if (!children_.empty()) {
RETURN_STATUS_UNEXPECTED("Cannot add child sampler, this sampler already has a child.");
}
children_.push_back(child);
return Status::OK();
}
/* ####################################### Derived Sampler classes ################################# */
// DistributedSampler
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
uint32_t seed, int64_t offset, bool even_dist)
: num_shards_(num_shards),
shard_id_(shard_id),
shuffle_(shuffle),
num_samples_(num_samples),
seed_(seed),
offset_(offset),
even_dist_(even_dist) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
Status DistributedSamplerObj::ValidateParams() {
if (num_shards_ <= 0) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " +
std::to_string(num_shards_));
}
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) +
"), but got: " + std::to_string(shard_id_));
}
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
if (offset_ > num_shards_) {
RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" +
std::to_string(num_shards_) + "), but got: " + std::to_string(offset_));
}
return Status::OK();
}
std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
num_samples_, offset_);
return mind_sampler;
}
#endif
Status DistributedSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "DistributedSampler";
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
args["offset"] = offset_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
: num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
Status PKSamplerObj::ValidateParams() {
if (num_val_ <= 0) {
RETURN_STATUS_UNEXPECTED("PKSampler: num_val must be greater than 0, but got: " + std::to_string(num_val_));
}
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("PKSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}
Status PKSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "PKSampler";
args["num_val"] = num_val_;
args["shuffle"] = shuffle_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}
return mind_sampler;
}
#endif
// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
#ifndef ENABLE_ANDROID
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_minddataset_(std::move(sampler)) {}
#endif
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
BuildChildren(sp_);
return sp_;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
#ifndef ENABLE_ANDROID
if (sp_minddataset_ != nullptr) {
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#endif
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) {
RETURN_IF_NOT_OK(sp_->to_json(out_json));
return Status::OK();
}
// RandomSampler
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
: replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}
Status RandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}
Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "RandomSampler";
args["replacement"] = replacement_;
args["num_samples"] = num_samples_;
args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler =
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
return mind_sampler;
}
#endif
// SequentialSampler
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
: start_index_(start_index), num_samples_(num_samples) {}
Status SequentialSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("SequentialSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
if (start_index_ < 0) {
RETURN_STATUS_UNEXPECTED("SequentialSampler: start_index_ must be greater than or equal to 0, but got: " +
std::to_string(start_index_));
}
return Status::OK();
}
Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SequentialSampler";
args["start_index"] = start_index_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
return mind_sampler;
}
#endif
// SubsetSampler
SubsetSamplerObj::SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {}
Status SubsetSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("SubsetRandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}
std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_);
return mind_sampler;
}
#endif
Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: SubsetSamplerObj(std::move(indices), num_samples) {}
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
return mind_sampler;
}
#endif
Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "SubsetRandomSampler";
args["indices"] = indices_;
args["num_samples"] = num_samples_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
Status WeightedRandomSamplerObj::ValidateParams() {
if (weights_.empty()) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
}
int32_t zero_elem = 0;
for (int32_t i = 0; i < weights_.size(); ++i) {
if (weights_[i] < 0) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
std::to_string(weights_[i]));
}
if (weights_[i] == 0.0) {
zero_elem++;
}
}
if (zero_elem == weights_.size()) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
}
if (num_samples_ < 0) {
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
std::to_string(num_samples_));
}
return Status::OK();
}
Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "WeightedRandomSampler";
args["weights"] = weights_;
args["num_samples"] = num_samples_;
args["replacement"] = replacement_;
if (!children_.empty()) {
std::vector<nlohmann::json> children_args;
for (auto child : children_) {
nlohmann::json child_arg;
RETURN_IF_NOT_OK(child->to_json(&child_arg));
children_args.push_back(child_arg);
}
args["child_sampler"] = children_args;
}
*out_json = args;
return Status::OK();
}
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler);
return sampler;
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,344 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>
#include "include/api/status.h"
#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_operator.h"
#endif
namespace mindspore {
namespace dataset {
// Internal Sampler class forward declaration
class SamplerRT;
class SamplerObj {
public:
/// \brief Constructor
SamplerObj();
/// \brief Destructor
~SamplerObj() = default;
/// \brief Pure virtual function for derived class to implement parameters validation
/// \return The Status code of the function. It returns OK status if parameters are valid.
virtual Status ValidateParams() = 0;
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;
/// \brief Pure virtual function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj
virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;
/// \brief Function for derived class to get the shard id of sampler
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }
/// \brief Adds a child to the sampler
/// \param[in] child The sampler to be added as child
/// \return the Status code returned
Status AddChildSampler(std::shared_ptr<SamplerObj> child);
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
#endif
protected:
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
std::vector<std::shared_ptr<SamplerObj>> children_;
};
/* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj {
public:
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
int64_t offset, bool even_dist);
~DistributedSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
offset_, even_dist_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
/// \brief Function to get the shard id of sampler
/// \return The shard id of sampler
int64_t ShardId() override { return shard_id_; }
private:
int64_t num_shards_;
int64_t shard_id_;
bool shuffle_;
int64_t num_samples_;
uint32_t seed_;
int64_t offset_;
bool even_dist_;
};
class PKSamplerObj : public SamplerObj {
public:
PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples);
~PKSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
int64_t num_val_;
bool shuffle_;
int64_t num_samples_;
};
class PreBuiltSamplerObj : public SamplerObj {
public:
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
#endif
~PreBuiltSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
std::shared_ptr<SamplerObj> SamplerCopy() override;
Status ValidateParams() override;
Status to_json(nlohmann::json *out_json) override;
private:
std::shared_ptr<SamplerRT> sp_;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
#endif
};
class RandomSamplerObj : public SamplerObj {
public:
RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);
~RandomSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
bool replacement_;
int64_t num_samples_;
bool reshuffle_each_epoch_;
};
class SequentialSamplerObj : public SamplerObj {
public:
SequentialSamplerObj(int64_t start_index, int64_t num_samples);
~SequentialSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
int64_t start_index_;
int64_t num_samples_;
};
class SubsetSamplerObj : public SamplerObj {
public:
SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
~SubsetSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
protected:
const std::vector<int64_t> indices_;
int64_t num_samples_;
};
class SubsetRandomSamplerObj : public SubsetSamplerObj {
public:
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
~SubsetRandomSamplerObj() = default;
Status to_json(nlohmann::json *out_json) override;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
private:
};
class WeightedRandomSamplerObj : public SamplerObj {
public:
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
~WeightedRandomSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
const std::vector<double> weights_;
int64_t num_samples_;
bool replacement_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SAMPLERS_IR_H_

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_SERDES_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_SERDES_H_
#include <fstream>
#include <memory>
#include <string>
#include <vector>

View File

@ -18,72 +18,14 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
#include <memory>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>
#include "include/api/status.h"
#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_reader.h"
#endif
// FIXME - This internal IR header will be removed when external API classes are provided
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
namespace mindspore {
namespace dataset {
// Internal Sampler class forward declaration
class SamplerRT;
class SamplerObj {
public:
/// \brief Constructor
SamplerObj();
/// \brief Destructor
~SamplerObj() = default;
/// \brief Pure virtual function for derived class to implement parameters validation
/// \return The Status code of the function. It returns OK status if parameters are valid.
virtual Status ValidateParams() = 0;
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;
/// \brief Pure virtual function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj
virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;
/// \brief Function for derived class to get the shard id of sampler
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }
/// \brief Adds a child to the sampler
/// \param[in] child The sampler to be added as child
/// \return the Status code returned
Status AddChildSampler(std::shared_ptr<SamplerObj> child);
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
#endif
protected:
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
std::vector<std::shared_ptr<SamplerObj>> children_;
};
class DistributedSamplerObj;
class PKSamplerObj;
class PreBuiltSamplerObj;
@ -155,261 +97,6 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t>
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
bool replacement = true);
/* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj {
public:
DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
int64_t offset, bool even_dist);
~DistributedSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
offset_, even_dist_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
/// \brief Function to get the shard id of sampler
/// \return The shard id of sampler
int64_t ShardId() override { return shard_id_; }
private:
int64_t num_shards_;
int64_t shard_id_;
bool shuffle_;
int64_t num_samples_;
uint32_t seed_;
int64_t offset_;
bool even_dist_;
};
class PKSamplerObj : public SamplerObj {
public:
PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples);
~PKSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
int64_t num_val_;
bool shuffle_;
int64_t num_samples_;
};
class PreBuiltSamplerObj : public SamplerObj {
public:
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
#endif
~PreBuiltSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
std::shared_ptr<SamplerObj> SamplerCopy() override;
Status ValidateParams() override;
Status to_json(nlohmann::json *out_json) override;
private:
std::shared_ptr<SamplerRT> sp_;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
#endif
};
class RandomSamplerObj : public SamplerObj {
public:
RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);
~RandomSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
bool replacement_;
int64_t num_samples_;
bool reshuffle_each_epoch_;
};
class SequentialSamplerObj : public SamplerObj {
public:
SequentialSamplerObj(int64_t start_index, int64_t num_samples);
~SequentialSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
int64_t start_index_;
int64_t num_samples_;
};
class SubsetSamplerObj : public SamplerObj {
public:
SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
~SubsetSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
protected:
const std::vector<int64_t> indices_;
int64_t num_samples_;
};
class SubsetRandomSamplerObj : public SubsetSamplerObj {
public:
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
~SubsetRandomSamplerObj() = default;
Status to_json(nlohmann::json *out_json) override;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
private:
};
class WeightedRandomSamplerObj : public SamplerObj {
public:
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
~WeightedRandomSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
for (auto child : children_) {
sampler->AddChildSampler(child);
}
return sampler;
}
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
Status ValidateParams() override;
private:
const std::vector<double> weights_;
int64_t num_samples_;
bool replacement_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_

View File

@ -136,6 +136,7 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/engine/ir/datasetops/shuffle_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/album_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/mnist_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/samplers_ir.cc
${MINDDATA_DIR}/engine/datasetops/dataset_op.cc
${MINDDATA_DIR}/engine/datasetops/repeat_op.cc
${MINDDATA_DIR}/engine/datasetops/epoch_ctrl_op.cc