forked from mindspore-Ecosystem/mindspore
Rename SamplerObj::Build() to SamplerBuild()
This commit is contained in:
parent
91842af2ad
commit
273865a72b
|
@ -52,12 +52,12 @@ SamplerObj::SamplerObj() {}
|
|||
|
||||
void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
|
||||
for (auto child : children_) {
|
||||
auto sampler_rt = child->Build();
|
||||
auto sampler_rt = child->SamplerBuild();
|
||||
sampler->AddChild(sampler_rt);
|
||||
}
|
||||
}
|
||||
|
||||
Status SamplerObj::AddChild(std::shared_ptr<SamplerObj> child) {
|
||||
Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
|
||||
if (child == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -183,7 +183,7 @@ Status DistributedSamplerObj::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
|
||||
std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
|
||||
offset_, even_dist_);
|
||||
|
@ -215,7 +215,7 @@ Status PKSamplerObj::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
|
||||
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
|
||||
BuildChildren(sampler);
|
||||
|
@ -232,7 +232,7 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator
|
|||
|
||||
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
|
||||
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
|
||||
BuildChildren(sp_);
|
||||
return sp_;
|
||||
}
|
||||
|
@ -241,19 +241,19 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
|
|||
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
|
||||
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
|
||||
#ifndef ENABLE_ANDROID
|
||||
if (sp_minddataset_ != nullptr) {
|
||||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
#endif
|
||||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
@ -289,7 +289,7 @@ Status RandomSamplerObj::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
|
||||
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
bool reshuffle_each_epoch = true;
|
||||
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
|
||||
|
@ -324,7 +324,7 @@ Status SequentialSamplerObj::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
|
||||
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
|
||||
BuildChildren(sampler);
|
||||
|
@ -352,7 +352,7 @@ Status SubsetRandomSamplerObj::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
|
||||
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
|
||||
// runtime sampler object
|
||||
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
|
||||
BuildChildren(sampler);
|
||||
|
@ -395,7 +395,7 @@ Status WeightedRandomSamplerObj::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
|
||||
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
|
||||
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
|
||||
BuildChildren(sampler);
|
||||
return sampler;
|
||||
|
|
|
@ -40,7 +40,7 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
|
|||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> ConcatNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
// create an empty vector to copy a concat
|
||||
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>(), sampler,
|
||||
children_flag_and_nums_, children_start_end_index_);
|
||||
|
@ -77,8 +77,8 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
|
||||
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
} else {
|
||||
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_,
|
||||
children_start_end_index_));
|
||||
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(),
|
||||
children_flag_and_nums_, children_start_end_index_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -567,9 +567,14 @@ Status DatasetNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
|
|||
}
|
||||
|
||||
Status DatasetNode::GetShardId(int32_t *const shard_id) {
|
||||
if (!Children().empty()) {
|
||||
if (children_.size() == 1) {
|
||||
// Get shard id from the child node
|
||||
return Children()[0]->GetShardId(shard_id);
|
||||
return children_[0]->GetShardId(shard_id);
|
||||
} else if (children_.size() > 1) {
|
||||
// It is okay for dataset to have more than 1 child, GetShardId shouldn't fail in this case.
|
||||
// This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
|
||||
// always be in front of the child_ structure, so we get the dataset size from the last child.
|
||||
return children_.back()->GetShardId(shard_id);
|
||||
} else {
|
||||
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
|
||||
}
|
||||
|
@ -621,7 +626,7 @@ Status MappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
|
|||
}
|
||||
|
||||
Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
|
||||
return p->Visit(shared_from_base<MappableSourceNode>(), modified);
|
||||
return p->Visit(shared_from_base<NonMappableSourceNode>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -324,6 +324,13 @@ class MappableSourceNode : public DatasetNode {
|
|||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
virtual std::shared_ptr<SamplerObj> Sampler() = 0;
|
||||
|
||||
/// \brief Sampler setter
|
||||
virtual void SetSampler(std::shared_ptr<SamplerObj> sampler) = 0;
|
||||
};
|
||||
|
||||
// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
|
||||
|
|
|
@ -40,7 +40,7 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch
|
|||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> AlbumNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
@ -75,7 +75,8 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, extensions, std::move(schema), std::move(sampler_->Build())));
|
||||
decode_, extensions, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -66,6 +66,12 @@ class AlbumNode : public MappableSourceNode {
|
|||
const std::string &SchemaPath() const { return schema_path_; }
|
||||
const std::vector<std::string> &ColumnNames() const { return column_names_; }
|
||||
bool Decode() const { return decode_; }
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
|
|
|
@ -40,7 +40,7 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
|
|||
extensions_(extensions) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> CelebANode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
|
||||
node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, usage_, extensions_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
|
|||
num_rows = std::min(num_rows, partition_num);
|
||||
}
|
||||
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -82,6 +82,13 @@ class CelebANode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -33,7 +33,7 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us
|
|||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
|
||||
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -89,7 +89,7 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
|
|||
}
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows));
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -78,6 +78,13 @@ class Cifar100Node : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -33,7 +33,7 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag
|
|||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
|
|||
|
||||
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
|
|||
}
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -78,6 +78,13 @@ class Cifar10Node : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -205,7 +205,7 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
|
||||
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
|
||||
|
||||
RETURN_IF_NOT_OK(clue_op->Init());
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation
|
|||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> CocoNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
}
|
||||
std::shared_ptr<CocoOp> op =
|
||||
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(op);
|
||||
|
@ -145,7 +145,7 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
|
|||
}
|
||||
int64_t num_rows = 0, sample_size;
|
||||
RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows));
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -80,6 +80,13 @@ class CocoNode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string annotation_file_;
|
||||
|
|
|
@ -122,7 +122,7 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
std::shared_ptr<CsvOp> csv_op =
|
||||
std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
|
||||
rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
|
||||
num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
|
||||
|
||||
RETURN_IF_NOT_OK(csv_op->Init());
|
||||
|
||||
|
|
|
@ -89,6 +89,13 @@ class GeneratorNode : public MappableSourceNode {
|
|||
const std::vector<DataType> &ColumnTypes() const { return column_types_; }
|
||||
const std::shared_ptr<SchemaObj> &Schema() const { return schema_; }
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return nullptr; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override {}
|
||||
|
||||
private:
|
||||
py::function generator_function_;
|
||||
std::vector<std::string> column_names_;
|
||||
|
|
|
@ -42,7 +42,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
|
|||
exts_(extensions) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node =
|
||||
std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
|
||||
return node;
|
||||
|
@ -74,7 +74,7 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
|
|||
|
||||
node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
recursive_, decode_, exts_, class_indexing_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,7 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
|
|||
}
|
||||
int64_t sample_size, num_rows;
|
||||
RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {}));
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -79,7 +79,6 @@ class ImageFolderNode : public MappableSourceNode {
|
|||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
bool Decode() const { return decode_; }
|
||||
bool Recursive() const { return recursive_; }
|
||||
const std::shared_ptr<SamplerObj> &Sampler() const { return sampler_; }
|
||||
const std::map<std::string, int32_t> &ClassIndexing() const { return class_indexing_; }
|
||||
const std::set<std::string> &Exts() const { return exts_; }
|
||||
|
||||
|
@ -88,6 +87,13 @@ class ImageFolderNode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
bool decode_;
|
||||
|
|
|
@ -40,7 +40,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
|
|||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> ManifestNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
std::shared_ptr<ManifestOp> manifest_op;
|
||||
manifest_op =
|
||||
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
|
||||
class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
|
||||
class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_);
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(manifest_op);
|
||||
|
@ -118,7 +118,7 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
|
|||
int64_t num_rows, sample_size;
|
||||
int64_t num_classes; // dummy variable
|
||||
RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes));
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -81,6 +81,13 @@ class ManifestNode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_file_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -54,7 +54,7 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st
|
|||
|
||||
std::shared_ptr<DatasetNode> MindDataNode::Copy() {
|
||||
std::shared_ptr<MindDataNode> node;
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
if (dataset_files_.empty()) {
|
||||
node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
|
||||
} else {
|
||||
|
|
|
@ -85,6 +85,13 @@ class MindDataNode : public MappableSourceNode {
|
|||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_file_; // search_for_pattern_ will be true in this mode
|
||||
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode
|
||||
|
|
|
@ -32,7 +32,7 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr
|
|||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> MnistNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
@ -60,7 +60,8 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_->Build())));
|
||||
connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -81,7 +82,7 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
|
|||
}
|
||||
int64_t num_rows, sample_size;
|
||||
RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -72,13 +72,19 @@ class MnistNode : public MappableSourceNode {
|
|||
/// \brief Getter functions
|
||||
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||
const std::string &Usage() const { return usage_; }
|
||||
const std::shared_ptr<SamplerObj> &Sampler() const { return sampler_; }
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -114,7 +114,7 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
|
||||
std::shared_ptr<RandomDataOp> op;
|
||||
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
|
||||
std::move(data_schema_), std::move(sampler_->Build()));
|
||||
std::move(data_schema_), std::move(sampler_->SamplerBuild()));
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(op);
|
||||
|
@ -124,8 +124,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
|
||||
// Get the shard id of node
|
||||
Status RandomNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
// RandomDataset doesn't support multiple shards
|
||||
*shard_id = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -138,13 +138,7 @@ Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
|
|||
}
|
||||
int64_t num_rows;
|
||||
num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
|
||||
if (sampler_ != nullptr) {
|
||||
int64_t sample_size;
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
} else {
|
||||
*dataset_size = num_rows;
|
||||
}
|
||||
*dataset_size = num_rows;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -110,7 +110,6 @@ class RandomNode : public NonMappableSourceNode {
|
|||
std::string schema_path_;
|
||||
std::shared_ptr<SchemaObj> schema_;
|
||||
std::vector<std::string> columns_list_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::mt19937 rand_gen_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
};
|
||||
|
|
|
@ -90,7 +90,7 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
// Create and initalize TextFileOp
|
||||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
|
||||
RETURN_IF_NOT_OK(text_file_op->Init());
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
|
||||
|
|
|
@ -131,7 +131,7 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
std::shared_ptr<TFReaderOp> tf_reader_op =
|
||||
std::make_shared<TFReaderOp>(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files,
|
||||
std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
|
||||
shard_id_, shard_equal_rows_, std::move(sampler_->Build()));
|
||||
shard_id_, shard_equal_rows_, std::move(sampler_->SamplerBuild()));
|
||||
|
||||
RETURN_IF_NOT_OK(tf_reader_op->Init());
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
|
|||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> VOCNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
@ -110,8 +110,9 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
}
|
||||
|
||||
std::shared_ptr<VOCOp> voc_op;
|
||||
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
voc_op =
|
||||
std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(voc_op);
|
||||
|
@ -134,7 +135,7 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
|
|||
}
|
||||
int64_t num_rows = 0, sample_size;
|
||||
RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows));
|
||||
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
|
||||
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -83,6 +83,13 @@ class VOCNode : public MappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
const std::string kColumnImage = "image";
|
||||
const std::string kColumnTarget = "target";
|
||||
|
|
|
@ -101,7 +101,7 @@ Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MappableSourceNode> n
|
|||
}
|
||||
|
||||
Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
|
||||
auto itr = weight_profile_.find("NonMappableSourceNode");
|
||||
auto itr = weight_profile_.find("NonMappableSource");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(),
|
||||
"NonLeafSource::" + node->Name() + "'s weight doesn't exist.");
|
||||
int32_t weight = itr->second;
|
||||
|
|
|
@ -49,11 +49,11 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
|
||||
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<SamplerRT> Build() = 0;
|
||||
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;
|
||||
|
||||
/// \brief Pure virtual function to copy a SamplerObj class
|
||||
/// \return Shared pointers to the newly copied SamplerObj
|
||||
virtual std::shared_ptr<SamplerObj> Copy() = 0;
|
||||
virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;
|
||||
|
||||
/// \brief Function for derived class to get the shard id of sampler
|
||||
/// \return The shard id of the derived sampler
|
||||
|
@ -62,7 +62,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
/// \brief Adds a child to the sampler
|
||||
/// \param[in] child The sampler to be added as child
|
||||
/// \return the Status code returned
|
||||
Status AddChild(std::shared_ptr<SamplerObj> child);
|
||||
Status AddChildSampler(std::shared_ptr<SamplerObj> child);
|
||||
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
|
||||
|
@ -152,13 +152,13 @@ class DistributedSamplerObj : public SamplerObj {
|
|||
|
||||
~DistributedSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
|
||||
offset_, even_dist_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
@ -189,12 +189,12 @@ class PKSamplerObj : public SamplerObj {
|
|||
|
||||
~PKSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
@ -220,13 +220,13 @@ class PreBuiltSamplerObj : public SamplerObj {
|
|||
|
||||
~PreBuiltSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override;
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
|
@ -245,12 +245,12 @@ class RandomSamplerObj : public SamplerObj {
|
|||
|
||||
~RandomSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
@ -272,12 +272,12 @@ class SequentialSamplerObj : public SamplerObj {
|
|||
|
||||
~SequentialSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
@ -299,12 +299,12 @@ class SubsetRandomSamplerObj : public SamplerObj {
|
|||
|
||||
~SubsetRandomSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
@ -326,12 +326,12 @@ class WeightedRandomSamplerObj : public SamplerObj {
|
|||
|
||||
~WeightedRandomSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override {
|
||||
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
||||
for (auto child : children_) {
|
||||
sampler->AddChild(child);
|
||||
sampler->AddChildSampler(child);
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -87,67 +87,67 @@ TEST_F(MindDataTestPipeline, TestCalculateNumSamples) {
|
|||
int64_t num_rows = 30; // dummy variable for number of rows in the dataset
|
||||
std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
std::shared_ptr<SamplerRT> sampler_rt = sampl->Build();
|
||||
std::shared_ptr<SamplerRT> sampler_rt = sampl->SamplerBuild();
|
||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6);
|
||||
|
||||
sampl = PKSampler(3, false);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
sampler_rt = sampl->Build();
|
||||
sampler_rt = sampl->SamplerBuild();
|
||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
|
||||
|
||||
sampl = RandomSampler(false, 12);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
sampler_rt = sampl->Build();
|
||||
sampler_rt = sampl->SamplerBuild();
|
||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
|
||||
|
||||
sampl = SequentialSampler(0, 10);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
sampler_rt = sampl->Build();
|
||||
sampler_rt = sampl->SamplerBuild();
|
||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);
|
||||
|
||||
std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
|
||||
sampl = WeightedRandomSampler(weights, 12);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
sampler_rt = sampl->Build();
|
||||
sampler_rt = sampl->SamplerBuild();
|
||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
|
||||
|
||||
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
|
||||
sampl = SubsetRandomSampler(indices, 11);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
sampler_rt = sampl->Build();
|
||||
sampler_rt = sampl->SamplerBuild();
|
||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);
|
||||
|
||||
// Testing chains
|
||||
// Parent and child have num_samples
|
||||
std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights, 12);
|
||||
EXPECT_NE(sampl1, nullptr);
|
||||
std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->Build();
|
||||
std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->SamplerBuild();
|
||||
|
||||
std::shared_ptr<SamplerObj> sampl2 = SequentialSampler(0, 10);
|
||||
EXPECT_NE(sampl2, nullptr);
|
||||
std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->Build();
|
||||
std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->SamplerBuild();
|
||||
sampler_rt2->AddChild(sampler_rt1);
|
||||
EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);
|
||||
|
||||
// Parent doesn't have num_samples
|
||||
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights, 12);
|
||||
EXPECT_NE(sampl3, nullptr);
|
||||
std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->Build();
|
||||
std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->SamplerBuild();
|
||||
|
||||
std::shared_ptr<SamplerObj> sampl4 = SubsetRandomSampler(indices);
|
||||
EXPECT_NE(sampl4, nullptr);
|
||||
std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->Build();
|
||||
std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->SamplerBuild();
|
||||
sampler_rt4->AddChild(sampler_rt3);
|
||||
EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 12);
|
||||
|
||||
// Child doesn't have num_samples
|
||||
std::shared_ptr<SamplerObj> sampl5 = RandomSampler(false);
|
||||
EXPECT_NE(sampl5, nullptr);
|
||||
std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->Build();
|
||||
std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->SamplerBuild();
|
||||
|
||||
std::shared_ptr<SamplerObj> sampl6 = PKSampler(3, false, 7);
|
||||
EXPECT_NE(sampl6, nullptr);
|
||||
std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->Build();
|
||||
std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->SamplerBuild();
|
||||
sampler_rt6->AddChild(sampler_rt5);
|
||||
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
|
||||
}
|
||||
|
@ -156,10 +156,10 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
|
|||
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
|
||||
std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices);
|
||||
EXPECT_FALSE(indices.empty());
|
||||
EXPECT_NE(sampl1->Build(), nullptr);
|
||||
EXPECT_NE(sampl1->SamplerBuild(), nullptr);
|
||||
std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices));
|
||||
EXPECT_TRUE(indices.empty());
|
||||
EXPECT_NE(sampl2->Build(), nullptr);
|
||||
EXPECT_NE(sampl2->SamplerBuild(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
|
||||
|
@ -216,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestSamplerAddChild) {
|
|||
EXPECT_NE(sampler, nullptr);
|
||||
|
||||
auto child_sampler = SequentialSampler();
|
||||
sampler->AddChild(child_sampler);
|
||||
sampler->AddChildSampler(child_sampler);
|
||||
EXPECT_NE(child_sampler, nullptr);
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
|
|
|
@ -406,7 +406,7 @@ def test_cache_map_failure5():
|
|||
num_iter = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant" in str(e.value)
|
||||
assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value)
|
||||
|
||||
assert num_iter == 0
|
||||
logger.info('test_cache_failure5 Ended.\n')
|
||||
|
|
Loading…
Reference in New Issue