diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc index ee2f9f17444..26e6c2807b3 100644 --- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -52,12 +52,12 @@ SamplerObj::SamplerObj() {} void SamplerObj::BuildChildren(std::shared_ptr sampler) { for (auto child : children_) { - auto sampler_rt = child->Build(); + auto sampler_rt = child->SamplerBuild(); sampler->AddChild(sampler_rt); } } -Status SamplerObj::AddChild(std::shared_ptr child) { +Status SamplerObj::AddChildSampler(std::shared_ptr child) { if (child == nullptr) { return Status::OK(); } @@ -183,7 +183,7 @@ Status DistributedSamplerObj::ValidateParams() { return Status::OK(); } -std::shared_ptr DistributedSamplerObj::Build() { +std::shared_ptr DistributedSamplerObj::SamplerBuild() { // runtime sampler object auto sampler = std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_, offset_, even_dist_); @@ -215,7 +215,7 @@ Status PKSamplerObj::ValidateParams() { return Status::OK(); } -std::shared_ptr PKSamplerObj::Build() { +std::shared_ptr PKSamplerObj::SamplerBuild() { // runtime sampler object auto sampler = std::make_shared(num_samples_, num_val_, shuffle_); BuildChildren(sampler); @@ -232,7 +232,7 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr PreBuiltSamplerObj::Build() { +std::shared_ptr PreBuiltSamplerObj::SamplerBuild() { BuildChildren(sp_); return sp_; } @@ -241,19 +241,19 @@ std::shared_ptr PreBuiltSamplerObj::Build() { std::shared_ptr PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } #endif -std::shared_ptr PreBuiltSamplerObj::Copy() { +std::shared_ptr PreBuiltSamplerObj::SamplerCopy() { #ifndef ENABLE_ANDROID if (sp_minddataset_ != nullptr) { auto sampler = std::make_shared(sp_minddataset_); for (auto child : children_) { - sampler->AddChild(child); + sampler->AddChildSampler(child); } return sampler; } #endif auto sampler = 
std::make_shared(sp_); for (auto child : children_) { - sampler->AddChild(child); + sampler->AddChildSampler(child); } return sampler; } @@ -289,7 +289,7 @@ Status RandomSamplerObj::ValidateParams() { return Status::OK(); } -std::shared_ptr RandomSamplerObj::Build() { +std::shared_ptr RandomSamplerObj::SamplerBuild() { // runtime sampler object bool reshuffle_each_epoch = true; auto sampler = std::make_shared(num_samples_, replacement_, reshuffle_each_epoch); @@ -324,7 +324,7 @@ Status SequentialSamplerObj::ValidateParams() { return Status::OK(); } -std::shared_ptr SequentialSamplerObj::Build() { +std::shared_ptr SequentialSamplerObj::SamplerBuild() { // runtime sampler object auto sampler = std::make_shared(num_samples_, start_index_); BuildChildren(sampler); @@ -352,7 +352,7 @@ Status SubsetRandomSamplerObj::ValidateParams() { return Status::OK(); } -std::shared_ptr SubsetRandomSamplerObj::Build() { +std::shared_ptr SubsetRandomSamplerObj::SamplerBuild() { // runtime sampler object auto sampler = std::make_shared(num_samples_, indices_); BuildChildren(sampler); @@ -395,7 +395,7 @@ Status WeightedRandomSamplerObj::ValidateParams() { return Status::OK(); } -std::shared_ptr WeightedRandomSamplerObj::Build() { +std::shared_ptr WeightedRandomSamplerObj::SamplerBuild() { auto sampler = std::make_shared(num_samples_, weights_, replacement_); BuildChildren(sampler); return sampler; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc index 997c13fd6e9..569e752cbc8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc @@ -40,7 +40,7 @@ ConcatNode::ConcatNode(const std::vector> &datasets } std::shared_ptr ConcatNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->SamplerCopy(); // create an empty vector to copy a concat auto node = std::make_shared(std::vector>(), sampler, children_flag_and_nums_, children_start_end_index_); @@ -77,8 +77,8 @@ Status ConcatNode::Build(std::vector> *const node_ops if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) { node_ops->push_back(std::make_shared(connector_que_size_)); } else { - node_ops->push_back(std::make_shared(connector_que_size_, sampler_->Build(), children_flag_and_nums_, - children_start_end_index_)); + node_ops->push_back(std::make_shared(connector_que_size_, sampler_->SamplerBuild(), + children_flag_and_nums_, children_start_end_index_)); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index c33e6f0be42..2315ca2efc0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -594,9 +594,14 @@ Status DatasetNode::AcceptAfter(IRNodePass *const p, bool *const modified) { } Status DatasetNode::GetShardId(int32_t *const shard_id) { - if (!Children().empty()) { + if (children_.size() == 1) { // Get shard id from the child node - return Children()[0]->GetShardId(shard_id); + return children_[0]->GetShardId(shard_id); + } else if (children_.size() > 1) { + // It is okay for dataset to have more than 1 child, GetShardId shouldn't fail in this case. + // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will + // always be in front of the children_ structure, so we get the shard id from the last child.
+ return children_.back()->GetShardId(shard_id); } else { RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n"); } @@ -648,7 +653,7 @@ Status MappableSourceNode::Accept(IRNodePass *const p, bool *const modified) { } Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) { - return p->Visit(shared_from_base(), modified); + return p->Visit(shared_from_base(), modified); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 2997427e0fe..f573bda9720 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -330,6 +330,13 @@ class MappableSourceNode : public DatasetNode { /// \brief Node name getter /// \return Name of the current node virtual std::string Name() const = 0; + + /// \brief Sampler getter + /// \return SamplerObj of the current node + virtual std::shared_ptr Sampler() = 0; + + /// \brief Sampler setter + virtual void SetSampler(std::shared_ptr sampler) = 0; }; // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc index ac527e6d261..8cfbb9ba598 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc @@ -40,7 +40,7 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch sampler_(sampler) {} std::shared_ptr AlbumNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_); return node; } @@ -75,7 +75,8 @@ Status AlbumNode::Build(std::vector> *const node_ops) RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, - decode_, extensions, std::move(schema), std::move(sampler_->Build()))); + decode_, extensions, std::move(schema), + std::move(sampler_->SamplerBuild()))); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h index 3cba8ca785f..7f84fbe5c75 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h @@ -66,6 +66,12 @@ class AlbumNode : public MappableSourceNode { const std::string &SchemaPath() const { return schema_path_; } const std::vector &ColumnNames() const { return column_names_; } bool Decode() const { return decode_; } + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } private: std::string dataset_dir_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index 408f877b552..87668c7490a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -40,7 +40,7 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, extensions_(extensions) {} std::shared_ptr CelebANode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_dir_, usage_, sampler, decode_, extensions_, cache_); return node; } @@ -71,7 +71,7 @@ Status CelebANode::Build(std::vector> *const node_ops node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_, extensions_, std::move(schema), - std::move(sampler_->Build()))); + std::move(sampler_->SamplerBuild()))); return Status::OK(); } @@ -139,7 +139,7 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr &size num_rows = std::min(num_rows, partition_num); } - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); *dataset_size = sample_size; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h index 413584f3074..ef9c3b06734 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h @@ -82,6 +82,13 @@ class CelebANode : public MappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index d48cc28a286..5eced16efa9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -33,7 +33,7 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} std::shared_ptr Cifar100Node::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); return node; } @@ -68,7 +68,7 @@ Status Cifar100Node::Build(std::vector> *const node_o node_ops->push_back(std::make_shared(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), - std::move(sampler_->Build()))); + std::move(sampler_->SamplerBuild()))); return Status::OK(); } @@ -89,7 +89,7 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr &si } int64_t num_rows, sample_size; RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows)); - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h index 705e3dd5c4a..17bdfb39e9c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h @@ -78,6 +78,13 @@ class Cifar100Node : public MappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief 
Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 7b0e103c99b..6d99c6c79f3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -33,7 +33,7 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} std::shared_ptr Cifar10Node::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); return node; } @@ -66,7 +66,7 @@ Status Cifar10Node::Build(std::vector> *const node_op node_ops->push_back(std::make_shared(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), - std::move(sampler_->Build()))); + std::move(sampler_->SamplerBuild()))); return Status::OK(); } @@ -87,7 +87,7 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr &siz } int64_t num_rows, sample_size; RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h index 7b4c4161daf..a77eac9b4d7 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h @@ -78,6 +78,13 @@ class Cifar10Node : public MappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index d07ac6984ac..ea3a2f25649 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -205,7 +205,7 @@ Status CLUENode::Build(std::vector> *const node_ops) std::shared_ptr clue_op = std::make_shared( num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files, - connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); + connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); RETURN_IF_NOT_OK(clue_op->Init()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index bb03ce14a67..bde444f1f2f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -38,7 +38,7 @@ CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation sampler_(sampler) {} std::shared_ptr CocoNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_); return node; } @@ -121,7 +121,7 @@ Status CocoNode::Build(std::vector> *const node_ops) } std::shared_ptr op = std::make_shared(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, - connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); + connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(op); @@ -145,7 +145,7 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr &size_g } int64_t num_rows = 0, sample_size; RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows)); - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h index f0660c9d8c1..87ecc8b0c2d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h @@ -80,6 +80,13 @@ class CocoNode : public MappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: std::string dataset_dir_; std::string annotation_file_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index da1d3e66b01..d2ac94b958b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -122,7 +122,7 @@ Status CSVNode::Build(std::vector> *const node_ops) { std::shared_ptr csv_op = std::make_shared(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, - num_shards_, shard_id_, std::move(sampler_->Build())); + num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); RETURN_IF_NOT_OK(csv_op->Init()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h index cd74569a058..5321419cbfd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h @@ -89,6 +89,13 @@ class GeneratorNode : public MappableSourceNode { const std::vector &ColumnTypes() const { return column_types_; } const std::shared_ptr &Schema() const { return schema_; } + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return nullptr; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override {} + private: py::function generator_function_; std::vector column_names_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index d64fa201d1a..9bfbf852bfb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -42,7 +42,7 @@ 
ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar exts_(extensions) {} std::shared_ptr ImageFolderNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_); return node; @@ -74,7 +74,7 @@ Status ImageFolderNode::Build(std::vector> *const nod node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, recursive_, decode_, exts_, class_indexing_, std::move(schema), - std::move(sampler_->Build()))); + std::move(sampler_->SamplerBuild()))); return Status::OK(); } @@ -94,7 +94,7 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr } int64_t sample_size, num_rows; RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {})); - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h index f9a7ef77046..47688ae43ed 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h @@ -79,7 +79,6 @@ class ImageFolderNode : public MappableSourceNode { const std::string &DatasetDir() const { return dataset_dir_; } bool Decode() const { return decode_; } bool Recursive() const { return recursive_; } - const std::shared_ptr &Sampler() const { return sampler_; } const std::map &ClassIndexing() const { return class_indexing_; } const std::set &Exts() const { return exts_; } @@ -88,6 +87,13 
@@ class ImageFolderNode : public MappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: std::string dataset_dir_; bool decode_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index 413cb4dc37f..bd8f0622a75 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -40,7 +40,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u sampler_(sampler) {} std::shared_ptr ManifestNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_file_, usage_, sampler, class_index_, decode_, cache_); return node; } @@ -93,7 +93,7 @@ Status ManifestNode::Build(std::vector> *const node_o std::shared_ptr manifest_op; manifest_op = std::make_shared(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, - class_index_, std::move(schema), std::move(sampler_->Build()), usage_); + class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_); RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(manifest_op); @@ -118,7 +118,7 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr &si int64_t num_rows, sample_size; int64_t num_classes; // dummy variable RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes)); - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h index 0f4cb9ecdc4..ee7012eded2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h @@ -81,6 +81,13 @@ class ManifestNode : public MappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: std::string dataset_file_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index f4ca3503aa8..b76155ec1ee 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -54,7 +54,7 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector MindDataNode::Copy() { std::shared_ptr node; - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); if (dataset_files_.empty()) { node = std::make_shared(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_); } else { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h index 031439c3cf8..abec37ed34a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h @@ -85,6 +85,13 @@ class MindDataNode : public MappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: std::string dataset_file_; // search_for_pattern_ will be true in this mode std::vector dataset_files_; // search_for_pattern_ will be false in this mode diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 2702d674d94..2c32a7fe555 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -32,7 +32,7 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} std::shared_ptr MnistNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); return node; } @@ -60,7 +60,8 @@ Status MnistNode::Build(std::vector> *const node_ops) RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(std::make_shared(usage_, num_workers_, rows_per_buffer_, dataset_dir_, - connector_que_size_, std::move(schema), std::move(sampler_->Build()))); + connector_que_size_, std::move(schema), + std::move(sampler_->SamplerBuild()))); return Status::OK(); } @@ -81,7 +82,7 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr &size_ } int64_t num_rows, sample_size; RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h index 6b21b2d148b..6c1c37a91d1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h @@ -72,13 +72,19 @@ class MnistNode : public MappableSourceNode { /// \brief Getter functions const std::string &DatasetDir() const { return dataset_dir_; } const std::string &Usage() const { return usage_; } - const std::shared_ptr &Sampler() const { return sampler_; } /// 
\brief Get the arguments of node /// \param[out] out_json JSON string of all attributes /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index 6b8a75e9b94..af24433064f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -114,7 +114,7 @@ Status RandomNode::Build(std::vector> *const node_ops std::shared_ptr op; op = std::make_shared(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, - std::move(data_schema_), std::move(sampler_->Build())); + std::move(data_schema_), std::move(sampler_->SamplerBuild())); RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(op); @@ -124,8 +124,8 @@ Status RandomNode::Build(std::vector> *const node_ops // Get the shard id of node Status RandomNode::GetShardId(int32_t *shard_id) { - *shard_id = sampler_->ShardId(); - + // RandomDataset doesn't support multiple shards + *shard_id = 0; return Status::OK(); } @@ -138,13 +138,7 @@ Status RandomNode::GetDatasetSize(const std::shared_ptr &size } int64_t num_rows; num_rows = total_rows_ != 0 ? 
total_rows_ : data_schema_->num_rows(); - if (sampler_ != nullptr) { - int64_t sample_size; - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); - *dataset_size = sample_size; - } else { - *dataset_size = num_rows; - } + *dataset_size = num_rows; dataset_size_ = *dataset_size; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h index 90881392354..3bfbc4ac156 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h @@ -110,7 +110,6 @@ class RandomNode : public NonMappableSourceNode { std::string schema_path_; std::shared_ptr schema_; std::vector columns_list_; - std::shared_ptr sampler_; std::mt19937 rand_gen_; std::unique_ptr data_schema_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index 6fcb4ebab59..e89adac546b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -90,7 +90,7 @@ Status TextFileNode::Build(std::vector> *const node_o // Create and initalize TextFileOp std::shared_ptr text_file_op = std::make_shared( num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, - connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); + connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); RETURN_IF_NOT_OK(text_file_op->Init()); if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index 2ce98b4086f..901a13ee2bb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -131,7 +131,7 @@ Status TFRecordNode::Build(std::vector> *const node_o std::shared_ptr tf_reader_op = std::make_shared(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_, - shard_id_, shard_equal_rows_, std::move(sampler_->Build())); + shard_id_, shard_equal_rows_, std::move(sampler_->SamplerBuild())); RETURN_IF_NOT_OK(tf_reader_op->Init()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index 25f96236401..b2e799747ac 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -41,7 +41,7 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const sampler_(sampler) {} std::shared_ptr VOCNode::Copy() { - std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->SamplerCopy(); auto node = std::make_shared(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_); return node; } @@ -110,8 +110,9 @@ Status VOCNode::Build(std::vector> *const node_ops) { } std::shared_ptr voc_op; - voc_op = std::make_shared(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, - connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); + voc_op = + std::make_shared(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, + connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(voc_op); @@ -134,7 +135,7 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr &size_ge } int64_t num_rows = 0, sample_size; RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows)); - sample_size = sampler_->Build()->CalculateNumSamples(num_rows); + sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); *dataset_size = sample_size; dataset_size_ = *dataset_size; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h index 195ba34ddf3..318e3c29baf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h @@ -83,6 +83,13 @@ class VOCNode : public MappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Sampler getter + /// \return SamplerObj of the current node + std::shared_ptr Sampler() override { return sampler_; } + + /// \brief Sampler setter + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + private: const std::string kColumnImage = "image"; const std::string kColumnTarget = "target"; diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc index f195686458a..3c70bc7f6c0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc @@ -101,7 +101,7 @@ Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr n } Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr node, bool *const modified) { - auto itr = weight_profile_.find("NonMappableSourceNode"); + auto itr = weight_profile_.find("NonMappableSource"); CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(), "NonLeafSource::" + node->Name() + "'s weight doesn't exist."); int32_t weight = itr->second; diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 217659f30a7..90f975e44f4 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -49,11 +49,11 @@ class SamplerObj : public std::enable_shared_from_this { /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object /// \return Shared pointers to the newly created Sampler - virtual std::shared_ptr Build() = 0; + virtual std::shared_ptr SamplerBuild() = 0; /// \brief Pure virtual function to copy a SamplerObj class /// \return Shared pointers to the newly copied SamplerObj - virtual std::shared_ptr Copy() = 0; + virtual std::shared_ptr SamplerCopy() = 0; /// \brief Function for derived class to get the shard id of sampler /// \return The shard id of the derived sampler @@ -62,7 +62,7 @@ class SamplerObj : public std::enable_shared_from_this { /// \brief Adds a child to the sampler /// \param[in] child The sampler to be added as child /// \return the Status code returned - Status AddChild(std::shared_ptr child); + Status AddChildSampler(std::shared_ptr child); virtual Status 
to_json(nlohmann::json *out_json) { return Status::OK(); } @@ -152,13 +152,13 @@ class DistributedSamplerObj : public SamplerObj { ~DistributedSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr SamplerBuild() override; - std::shared_ptr Copy() override { + std::shared_ptr SamplerCopy() override { auto sampler = std::make_shared(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, even_dist_); for (auto child : children_) { - sampler->AddChild(child); + sampler->AddChildSampler(child); } return sampler; } @@ -189,12 +189,12 @@ class PKSamplerObj : public SamplerObj { ~PKSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr SamplerBuild() override; - std::shared_ptr Copy() override { + std::shared_ptr SamplerCopy() override { auto sampler = std::make_shared(num_val_, shuffle_, num_samples_); for (auto child : children_) { - sampler->AddChild(child); + sampler->AddChildSampler(child); } return sampler; } @@ -220,13 +220,13 @@ class PreBuiltSamplerObj : public SamplerObj { ~PreBuiltSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr SamplerBuild() override; #ifndef ENABLE_ANDROID std::shared_ptr BuildForMindDataset() override; #endif - std::shared_ptr Copy() override; + std::shared_ptr SamplerCopy() override; Status ValidateParams() override; @@ -245,12 +245,12 @@ class RandomSamplerObj : public SamplerObj { ~RandomSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr SamplerBuild() override; - std::shared_ptr Copy() override { + std::shared_ptr SamplerCopy() override { auto sampler = std::make_shared(replacement_, num_samples_); for (auto child : children_) { - sampler->AddChild(child); + sampler->AddChildSampler(child); } return sampler; } @@ -272,12 +272,12 @@ class SequentialSamplerObj : public SamplerObj { ~SequentialSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr SamplerBuild() override; - std::shared_ptr Copy() override { 
+ std::shared_ptr SamplerCopy() override { auto sampler = std::make_shared(start_index_, num_samples_); for (auto child : children_) { - sampler->AddChild(child); + sampler->AddChildSampler(child); } return sampler; } @@ -299,12 +299,12 @@ class SubsetRandomSamplerObj : public SamplerObj { ~SubsetRandomSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr SamplerBuild() override; - std::shared_ptr Copy() override { + std::shared_ptr SamplerCopy() override { auto sampler = std::make_shared(indices_, num_samples_); for (auto child : children_) { - sampler->AddChild(child); + sampler->AddChildSampler(child); } return sampler; } @@ -326,12 +326,12 @@ class WeightedRandomSamplerObj : public SamplerObj { ~WeightedRandomSamplerObj() = default; - std::shared_ptr Build() override; + std::shared_ptr SamplerBuild() override; - std::shared_ptr Copy() override { + std::shared_ptr SamplerCopy() override { auto sampler = std::make_shared(weights_, num_samples_, replacement_); for (auto child : children_) { - sampler->AddChild(child); + sampler->AddChildSampler(child); } return sampler; } diff --git a/tests/ut/cpp/dataset/c_api_samplers_test.cc b/tests/ut/cpp/dataset/c_api_samplers_test.cc index 800e567e5c1..3b520e5a0dc 100644 --- a/tests/ut/cpp/dataset/c_api_samplers_test.cc +++ b/tests/ut/cpp/dataset/c_api_samplers_test.cc @@ -87,67 +87,67 @@ TEST_F(MindDataTestPipeline, TestCalculateNumSamples) { int64_t num_rows = 30; // dummy variable for number of rows in the dataset std::shared_ptr sampl = DistributedSampler(2, 1, false, 6); EXPECT_NE(sampl, nullptr); - std::shared_ptr sampler_rt = sampl->Build(); + std::shared_ptr sampler_rt = sampl->SamplerBuild(); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6); sampl = PKSampler(3, false); EXPECT_NE(sampl, nullptr); - sampler_rt = sampl->Build(); + sampler_rt = sampl->SamplerBuild(); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30); sampl = RandomSampler(false, 12); EXPECT_NE(sampl, nullptr); - 
sampler_rt = sampl->Build(); + sampler_rt = sampl->SamplerBuild(); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); sampl = SequentialSampler(0, 10); EXPECT_NE(sampl, nullptr); - sampler_rt = sampl->Build(); + sampler_rt = sampl->SamplerBuild(); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10); std::vector weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; sampl = WeightedRandomSampler(weights, 12); EXPECT_NE(sampl, nullptr); - sampler_rt = sampl->Build(); + sampler_rt = sampl->SamplerBuild(); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21}; sampl = SubsetRandomSampler(indices, 11); EXPECT_NE(sampl, nullptr); - sampler_rt = sampl->Build(); + sampler_rt = sampl->SamplerBuild(); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11); // Testing chains // Parent and child have num_samples std::shared_ptr sampl1 = WeightedRandomSampler(weights, 12); EXPECT_NE(sampl1, nullptr); - std::shared_ptr sampler_rt1 = sampl1->Build(); + std::shared_ptr sampler_rt1 = sampl1->SamplerBuild(); std::shared_ptr sampl2 = SequentialSampler(0, 10); EXPECT_NE(sampl2, nullptr); - std::shared_ptr sampler_rt2 = sampl2->Build(); + std::shared_ptr sampler_rt2 = sampl2->SamplerBuild(); sampler_rt2->AddChild(sampler_rt1); EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10); // Parent doesn't have num_samples std::shared_ptr sampl3 = WeightedRandomSampler(weights, 12); EXPECT_NE(sampl3, nullptr); - std::shared_ptr sampler_rt3 = sampl3->Build(); + std::shared_ptr sampler_rt3 = sampl3->SamplerBuild(); std::shared_ptr sampl4 = SubsetRandomSampler(indices); EXPECT_NE(sampl4, nullptr); - std::shared_ptr sampler_rt4 = sampl4->Build(); + std::shared_ptr sampler_rt4 = sampl4->SamplerBuild(); sampler_rt4->AddChild(sampler_rt3); EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 12); // Child doesn't have num_samples std::shared_ptr sampl5 = RandomSampler(false); EXPECT_NE(sampl5, 
nullptr); - std::shared_ptr sampler_rt5 = sampl5->Build(); + std::shared_ptr sampler_rt5 = sampl5->SamplerBuild(); std::shared_ptr sampl6 = PKSampler(3, false, 7); EXPECT_NE(sampl6, nullptr); - std::shared_ptr sampler_rt6 = sampl6->Build(); + std::shared_ptr sampler_rt6 = sampl6->SamplerBuild(); sampler_rt6->AddChild(sampler_rt5); EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7); } @@ -156,10 +156,10 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; std::shared_ptr sampl1 = SubsetRandomSampler(indices); EXPECT_FALSE(indices.empty()); - EXPECT_NE(sampl1->Build(), nullptr); + EXPECT_NE(sampl1->SamplerBuild(), nullptr); std::shared_ptr sampl2 = SubsetRandomSampler(std::move(indices)); EXPECT_TRUE(indices.empty()); - EXPECT_NE(sampl2->Build(), nullptr); + EXPECT_NE(sampl2->SamplerBuild(), nullptr); } TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) { @@ -216,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestSamplerAddChild) { EXPECT_NE(sampler, nullptr); auto child_sampler = SequentialSampler(); - sampler->AddChild(child_sampler); + sampler->AddChildSampler(child_sampler); EXPECT_NE(child_sampler, nullptr); // Create an ImageFolder Dataset diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 8db7fc5923b..b885bc203c7 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -406,7 +406,7 @@ def test_cache_map_failure5(): num_iter = 0 for _ in data.create_dict_iterator(): num_iter += 1 - assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant" in str(e.value) + assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure5 Ended.\n')