diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt index 1854d3da101..b6eb835538d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -7,6 +7,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES random_sampler.cc sampler.cc sequential_sampler.cc + skip_first_epoch_sampler.cc subset_random_sampler.cc subset_sampler.cc weighted_random_sampler.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h index 71e45614bf5..6dd6828b6e0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,7 +66,7 @@ class SequentialSamplerRT : public SamplerRT { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; - private: + protected: int64_t current_id_; // The id sequencer. Each new id increments from this int64_t start_index_; // The starting id. current_id_ begins from here. int64_t id_count_; // An internal counter that tracks how many ids have been produced diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc new file mode 100644 index 00000000000..56e19d0c51b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h" +#include + +namespace mindspore { +namespace dataset { +Status SkipFirstEpochSamplerRT::ResetSampler() { + if (id_count_ != num_samples_) { + std::string err_msg = + "[Internal ERROR] ResetSampler() called early or late. id_count_: " + std::to_string(id_count_) + + " num_samples_: " + std::to_string(num_samples_); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_UNEXPECTED(err_msg); + } + current_id_ = 0; + id_count_ = 0; + + if (!first_epoch_done_) { + num_samples_ += start_index_; + start_index_ = 0; + samples_per_tensor_ = num_samples_; + first_epoch_done_ = true; + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +int64_t SkipFirstEpochSamplerRT::CalculateNumSamples(const int64_t num_rows) { return -1; } + +void SkipFirstEpochSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { + out << "\nSampler: SkipFirstEpochSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + SamplerRT::SamplerPrint(out, show_all); + // Then add our own info + out << "\nStart index: " << start_index_; + out << "\nFirst epoch done: " << first_epoch_done_; + out << "\nCurrent id: " << current_id_; + out << "\nid count:" << id_count_; + } +} + +Status SkipFirstEpochSamplerRT::to_json(nlohmann::json *out_json) { + RETURN_UNEXPECTED_IF_NULL(out_json); + nlohmann::json args; + RETURN_IF_NOT_OK(SamplerRT::to_json(&args)); + args["sampler_name"] = "SkipFirstEpochSampler"; + args["start_index"] = start_index_; + *out_json = args; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h new file mode 100644 index 00000000000..c4c330ba8b5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_ + +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" + +namespace mindspore { +namespace dataset { +class SkipFirstEpochSamplerRT : public SequentialSamplerRT { + public: + // Constructor + using SequentialSamplerRT::SequentialSamplerRT; + + // Destructor. + ~SkipFirstEpochSamplerRT() = default; + + // for next epoch of sampleIds + // @return Status The status code returned + Status ResetSampler() override; + + /// \brief Gets the number of samples available + /// \note Since this sampler returns different number of samples in the first epoch (compared to other epochs), this + /// function always returns -1 + /// \param[in] num_rows The total number of rows in the dataset + /// \return int64_t Calculated number of samples (always -1) + int64_t CalculateNumSamples(int64_t num_rows) override; + + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void SamplerPrint(std::ostream &out, bool show_all) const override; + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + + private: + bool first_epoch_done_ = false; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SKIP_FIRST_EPOCH_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc index 228a90aee49..7d435311b79 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,8 +21,9 @@ #include #include "minddata/dataset/engine/datasetops/project_op.h" - +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" + namespace mindspore { namespace dataset { @@ -73,5 +74,18 @@ Status ProjectNode::from_json(nlohmann::json json_obj, std::shared_ptr(ds, columns); return Status::OK(); } + +// Visitor accepting method for IRNodePass +Status ProjectNode::Accept(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status ProjectNode::AcceptAfter(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h index 866c511a211..1d72d3522ce 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -70,6 +70,18 @@ class ProjectNode : public DatasetNode { static Status from_json(nlohmann::json json_obj, std::shared_ptr ds, std::shared_ptr *result); + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *const p, bool *const modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + private: std::vector columns_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc index 736bf6a932a..1514f68c702 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,8 +21,9 @@ #include #include "minddata/dataset/engine/datasetops/rename_op.h" - +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" + namespace mindspore { namespace dataset { @@ -82,5 +83,17 @@ Status RenameNode::from_json(nlohmann::json json_obj, std::shared_ptrVisit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status RenameNode::AcceptAfter(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h index b2248442bb0..fecf788b76d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,18 @@ class RenameNode : public DatasetNode { /// \return Name of the current node std::string Name() const override { return kRenameNode; } + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *const p, bool *const modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + /// \brief Print the description /// \param out - The output stream to write output to void Print(std::ostream &out) const override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h index 817d2484aa9..cde837bd498 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,6 +99,9 @@ class SkipNode : public DatasetNode { void SetFirstEpochOnly(bool flag) { first_epoch_only_ = flag; } + /// \brief Getter functions + const bool FirstEpochOnly() const { return first_epoch_only_; } + private: int32_t skip_count_; bool first_epoch_only_ = false; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/CMakeLists.txt index b5d10fce188..66a84ea1b64 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/CMakeLists.txt @@ -8,6 +8,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SRC_FILES random_sampler_ir.cc samplers_ir.cc sequential_sampler_ir.cc + skip_first_epoch_sampler_ir.cc subset_random_sampler_ir.cc subset_sampler_ir.cc weighted_random_sampler_ir.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h index b88be53c1f0..727a8eee48d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,7 +65,7 @@ class SequentialSamplerObj : public SamplerObj { Status ValidateParams() override; - private: + protected: int64_t start_index_; int64_t num_samples_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.cc new file mode 100644 index 00000000000..f77244a2235 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h" +#include "minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h" +#include "minddata/dataset/util/validators.h" + +namespace mindspore { +namespace dataset { +// Constructor +SkipFirstEpochSamplerObj::SkipFirstEpochSamplerObj(int64_t start_index) : SequentialSamplerObj(start_index, 0) {} + +// Destructor +SkipFirstEpochSamplerObj::~SkipFirstEpochSamplerObj() = default; + +Status SkipFirstEpochSamplerObj::to_json(nlohmann::json *const out_json) { + nlohmann::json args; + RETURN_IF_NOT_OK(SamplerObj::to_json(&args)); + args["sampler_name"] = "SkipFirstEpochSamplerObj"; + args["start_index"] = start_index_; + *out_json = args; + return Status::OK(); +} + +#ifndef ENABLE_ANDROID +Status SkipFirstEpochSamplerObj::from_json(nlohmann::json json_obj, std::shared_ptr *sampler) { + RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "start_index", "SkipFirstEpochSamplerObj")); + int64_t start_index = json_obj["start_index"]; + *sampler = std::make_shared(start_index); + // Run common code in super class to add children samplers + RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler)); + return Status::OK(); +} +#endif + +Status SkipFirstEpochSamplerObj::SamplerBuild(std::shared_ptr *sampler) { + // runtime sampler object + *sampler = std::make_shared(start_index_, 0); + Status s = BuildChildren(sampler); + sampler = s.IsOk() ? sampler : nullptr; + return s; +} + +std::shared_ptr SkipFirstEpochSamplerObj::SamplerCopy() { + auto sampler = std::make_shared(start_index_); + for (const auto &child : children_) { + Status rc = sampler->AddChildSampler(child); + if (rc.IsError()) { + MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc; + } + } + return sampler; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h new file mode 100644 index 00000000000..b538ca97e7f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_ + +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h" +#include "include/api/status.h" + +namespace mindspore { +namespace dataset { +// Internal Sampler class forward declaration +class SamplerRT; + +class SkipFirstEpochSamplerObj : public SequentialSamplerObj { + public: + explicit SkipFirstEpochSamplerObj(int64_t start_index); + + ~SkipFirstEpochSamplerObj() override; + + Status SamplerBuild(std::shared_ptr *sampler) override; + + std::shared_ptr SamplerCopy() override; + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *const out_json) override; + +#ifndef ENABLE_ANDROID + /// \brief Function for read sampler from JSON object + /// \param[in] json_obj JSON object to be read + /// \param[out] sampler Sampler constructed from parameters in JSON object + /// \return Status of the function + static Status from_json(nlohmann::json json_obj, std::shared_ptr *sampler); +#endif +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SAMPLERS_SKIP_FIRST_EPOCH_SAMPLER_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index 407aa889e41..4c654b918ca 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -15,6 +15,7 @@ set(DATASET_ENGINE_OPT_SRC_FILES pre/input_validation_pass.cc pre/node_offload_pass.cc pre/node_removal_pass.cc + pre/skip_pushdown_pass.cc ) if(ENABLE_PYTHON) diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/skip_pushdown_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/skip_pushdown_pass.cc new file mode 100644 index 00000000000..88520df26ac --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/skip_pushdown_pass.cc @@ -0,0 +1,177 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h" +#include "minddata/dataset/engine/ir/datasetops/batch_node.h" +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" +#include "minddata/dataset/engine/ir/datasetops/map_node.h" +#include "minddata/dataset/engine/ir/datasetops/project_node.h" +#include "minddata/dataset/engine/ir/datasetops/rename_node.h" +#include "minddata/dataset/engine/ir/datasetops/skip_node.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.h" + +namespace mindspore { +namespace dataset { +SkipPushdownPass::SkipNodes::SkipNodes() : skip_count_(0) {} + +// activate the optimization steps, and increase skip_count_ (if not the first skip node in the pipeline) +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { + if (node->FirstEpochOnly() == false) { + return Visit(std::static_pointer_cast(node), modified); + } + skip_count_ += node->Count(); + nodes_to_remove_.push_back(node); + return Status::OK(); +} + +Status SkipPushdownPass::SkipNodes::VisitAfter(std::shared_ptr node, bool *const modified) { + if (node->FirstEpochOnly() == false) { + return VisitAfter(std::static_pointer_cast(node), modified); + } + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ == 0, "The skip_count_ cannot be non-zero here."); + return Status::OK(); +} + +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { +#ifdef ENABLE_PYTHON + if (node->BatchSizeFunc()) { + return Visit(std::static_pointer_cast(node), modified); + } +#endif + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative."); + if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow + + // we have an active skip node above. + skip_count_ *= node->BatchSize(); + + return Status::OK(); +} + +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative."); + if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow + + return Status::OK(); +} + +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative."); + if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow + + return Status::OK(); +} + +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative."); + if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow + + // we have an active skip node above. + auto new_sampler = std::make_shared(skip_count_); + MS_LOG(INFO) << "Adding SkipFirstEpochSampler(" << skip_count_ << ")"; + auto sampler = node->Sampler(); + if (sampler != nullptr) { + new_sampler->AddChildSampler(sampler); + } + node->SetSampler(new_sampler); + skip_count_ = 0; + + return Status::OK(); +} + +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative."); + if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow + + // we have an active skip node above. + MS_LOG(WARNING) + << "Pushing down skip node below a map node will result in slightly different outputs for random transformations."; + return Status::OK(); +} + +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative."); + if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow + + // insert a skip node above + insert_skip_above_.emplace_back(node, skip_count_); + skip_count_ = 0; + return Status::OK(); +} + +#ifndef ENABLE_ANDROID +// Since MindDataset requires its own SkipFirstEpochSampler (which is not implemented) we insert the skip node above it. +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative."); + if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow + + // insert a skip node above + insert_skip_above_.emplace_back(node, skip_count_); + skip_count_ = 0; + return Status::OK(); +} +#endif + +// This functions is used for Ops that are random, and the ones in which Visit is Not Implemented yet; +Status SkipPushdownPass::SkipNodes::Visit(std::shared_ptr node, bool *const modified) { + CHECK_FAIL_RETURN_UNEXPECTED(skip_count_ >= 0, "The skip size cannot be negative."); + if (skip_count_ == 0) return Status::OK(); // no active skip node above. normal flow + + // insert a skip node above + insert_skip_above_.emplace_back(node, skip_count_); + skip_count_ = 0; + return Status::OK(); +} + +// constructor +SkipPushdownPass::SkipPushdownPass() {} + +// Walk the tree to push down the skip node inserted when Reset is called. +Status SkipPushdownPass::RunOnTree(std::shared_ptr root_ir, bool *const modified) { + MS_LOG(INFO) << "Pre pass: skip node pushdown pass started."; + // Assumption: The total skip counts in the first_epoch_only skip node is less than the size of the dataset. This + // assumption is not validated here. + // Create the skip node pass which can identify which nodes need to be removed and which ones added. + std::unique_ptr skip_nodes = std::make_unique(); + if (root_ir->IsSizeDefined()) { + RETURN_IF_NOT_OK(skip_nodes->Run(root_ir, modified)); + } + + // Update modified flag if there were any nodes identified to be removed + if (skip_nodes->nodes_to_remove().empty() == false || skip_nodes->insert_skip_above().empty() == false) { + *modified = true; + } + + // Add skip node(s) to the tree (if any) + for (auto iter : skip_nodes->insert_skip_above()) { + MS_LOG(INFO) << "Inserting a Skip(" << iter.second << ") node above this node: " << iter.first->Name(); + auto new_skip_node = std::make_shared(iter.second); + new_skip_node->SetFirstEpochOnly(true); + RETURN_IF_NOT_OK(iter.first->InsertAbove(new_skip_node)); + } + + // Then, execute the removal of any nodes that were set up for removal + for (auto node : skip_nodes->nodes_to_remove()) { + RETURN_IF_NOT_OK(node->Drop()); + } + + MS_LOG(INFO) << "Pre pass: skip node pushdown pass is complete."; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/skip_pushdown_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/skip_pushdown_pass.h new file mode 100644 index 00000000000..229bf2f4d7b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/skip_pushdown_pass.h @@ -0,0 +1,156 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_ + +#include +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +class BatchNode; +class DatasetNode; +class DatasetOp; +class MappableSourceNode; +class MapNode; +#ifndef ENABLE_ANDROID +class MindDataNode; +#endif +class NonMappableSourceNode; +class ProjectNode; +class RenameNode; +class SkipNode; + +/// \class SkipPushdownPass skip_pushdown_pass.h +/// \brief This is a tree pass that will push down a skip node. It uses SkipNodes to first identify if we have a skip +/// node, and then based on the node types we observe in the tree, decide where to place the skip node (or use a +/// SequentialSampler for MappableSource nodes). +class SkipPushdownPass : public IRTreePass { + /// \class SkipNodes + /// \brief This is a NodePass whose job is to handle different nodes accordingly. + /// It works in conjunction with the SkipPushdownPass. + class SkipNodes : public IRNodePass { + public: + /// \brief Constructor + SkipNodes(); + + /// \brief Destructor + ~SkipNodes() = default; + + /// \brief Perform skip node pushdown initiation check on a SkipNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + + /// \brief Perform skip node pushdown completion check on a SkipNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status VisitAfter(std::shared_ptr node, bool *const modified) override; + + /// \brief Perform skip node pushdown check on a BatchNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + + /// \brief Perform skip node pushdown check on a ProjectNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + + /// \brief Perform skip node pushdown check on a RenameNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + + /// \brief Perform skip node pushdown check on a MappableSourceNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + + /// \brief Perform skip node pushdown check on a MapNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + + /// \brief Perform skip node pushdown check on a NonMappableSourceNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + +#ifndef ENABLE_ANDROID + /// \brief Perform skip node pushdown check on a MindDataNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; +#endif + + /// \brief Perform skip node pushdown check on a DatasetNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + + /// \brief Perform skip node pushdown completion check on a DatasetNode + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status VisitAfter(std::shared_ptr node, bool *const modified) override { return Status::OK(); }; + + /// \brief Getter + /// \return All the nodes where a skip node needs to be inserted above (and the skip count). + const std::vector, int64_t>> &insert_skip_above() const { + return insert_skip_above_; + } + + /// \brief Getter + /// \return All the nodes to be removed + const std::vector> &nodes_to_remove() const { return nodes_to_remove_; } + + private: + std::vector, int64_t>> insert_skip_above_; + std::vector> nodes_to_remove_; + int64_t skip_count_; + }; + + public: + /// \brief Constructor + SkipPushdownPass(); + + /// \brief Destructor + ~SkipPushdownPass() = default; + + /// \brief Runs a skip_pushdown pass to push down the skip node found in the tree (for Reset scenario). + /// \param[in, out] tree The tree to operate on. + /// \param[in, out] Indicate of the tree was modified. + /// \return Status The status code returned + Status RunOnTree(std::shared_ptr root_ir, bool *const modified) override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_SKIP_PUSHDOWN_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index e9b249d20df..dbd9eefb618 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -36,6 +36,7 @@ #include "minddata/dataset/engine/opt/pre/getter_pass.h" #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" #include "minddata/dataset/engine/opt/pre/node_removal_pass.h" +#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h" namespace mindspore { namespace dataset { @@ -57,8 +58,11 @@ Status TreeAdapter::PrePass(std::shared_ptr ir) { MS_LOG(INFO) << "Running pre pass loops."; actions.emplace_back(std::make_unique()); actions.emplace_back(std::make_unique()); + if (usage_ == kDeReset) { + actions.emplace_back(std::make_unique()); + actions.emplace_back(std::make_unique()); + } actions.emplace_back(std::make_unique()); - if (usage_ == kDeReset) actions.emplace_back(std::make_unique()); actions.emplace_back(std::make_unique()); if (usage_ == kDeGetter) actions.emplace_back(std::make_unique()); #ifndef ENABLE_ANDROID diff --git a/mindspore/lite/minddata/CMakeLists.txt b/mindspore/lite/minddata/CMakeLists.txt index c4351546eaf..4c74683fc59 100644 --- a/mindspore/lite/minddata/CMakeLists.txt +++ b/mindspore/lite/minddata/CMakeLists.txt @@ -153,6 +153,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") ${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/random_sampler_ir.cc ${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/samplers_ir.cc ${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/sequential_sampler_ir.cc + ${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/skip_first_epoch_sampler_ir.cc ${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/subset_random_sampler_ir.cc ${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/subset_sampler_ir.cc ${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/weighted_random_sampler_ir.cc @@ -179,6 +180,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") ${MINDDATA_DIR}/engine/opt/pre/node_removal_pass.cc ${MINDDATA_DIR}/engine/opt/pre/epoch_ctrl_pass.cc ${MINDDATA_DIR}/engine/opt/pre/deep_copy_pass.cc + ${MINDDATA_DIR}/engine/opt/pre/skip_pushdown_pass.cc ${MINDDATA_DIR}/engine/opt/post/auto_worker_pass.cc ${MINDDATA_DIR}/engine/opt/pass.cc ${MINDDATA_DIR}/engine/perf/auto_tune.cc @@ -193,6 +195,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") ${MINDDATA_DIR}/engine/datasetops/source/sampler/pk_sampler.cc ${MINDDATA_DIR}/engine/datasetops/source/sampler/random_sampler.cc ${MINDDATA_DIR}/engine/datasetops/source/sampler/sequential_sampler.cc + ${MINDDATA_DIR}/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc ${MINDDATA_DIR}/engine/datasetops/source/sampler/subset_random_sampler.cc ${MINDDATA_DIR}/engine/datasetops/source/sampler/weighted_random_sampler.cc ${MINDDATA_DIR}/engine/runtime_context.cc diff --git a/mindspore/python/mindspore/dataset/engine/datasets.py b/mindspore/python/mindspore/dataset/engine/datasets.py index e3299c785bc..3299de0d001 100644 --- a/mindspore/python/mindspore/dataset/engine/datasets.py +++ b/mindspore/python/mindspore/dataset/engine/datasets.py @@ -80,7 +80,6 @@ except ModuleNotFoundError: if platform.system().lower() == "darwin" and multiprocessing.get_start_method() != "fork": multiprocessing.set_start_method("fork", True) - OffloadToManualOffloadMode = { None: cde.ManualOffloadMode.UNSPECIFIED, False: cde.ManualOffloadMode.DISABLED, @@ -120,7 +119,7 @@ def _reset_training_dataset(step): """ dataset = _get_training_dataset() if dataset is not None: - dataset.reset(step) + dataset._reset(step) # pylint: disable=W0212 else: raise RuntimeError("Training dataset is not set.") @@ -3528,7 +3527,7 @@ class _ToDevice: def send(self): self._to_device.Send() - def reset(self, step): + def _reset(self, step): self._to_device.Reset(step) def stop_send(self): @@ -3638,10 +3637,10 @@ class TransferDataset(Dataset): if self._to_device is not None: self._to_device.continue_send() - def reset(self, step): + def _reset(self, step): if self._to_device is not None: logger.info("Reset the dataset pipeline to step " + str(step)) - self._to_device.reset(step) + self._to_device._reset(step) # pylint: disable=W0212 def get_data_info(self): """ diff --git a/mindspore/python/mindspore/dataset/engine/iterators.py b/mindspore/python/mindspore/dataset/engine/iterators.py index 665ba64789d..a0245f829b4 100644 --- a/mindspore/python/mindspore/dataset/engine/iterators.py +++ b/mindspore/python/mindspore/dataset/engine/iterators.py @@ -178,7 +178,7 @@ class Iterator: self._getters() return self._col_names - def reset(self, step): + def _reset(self, step): """ Reset the iterator to the given step number. diff --git a/mindspore/python/mindspore/train/dataset_helper.py b/mindspore/python/mindspore/train/dataset_helper.py index b4349f14a6d..a98c224f70f 100644 --- a/mindspore/python/mindspore/train/dataset_helper.py +++ b/mindspore/python/mindspore/train/dataset_helper.py @@ -330,9 +330,9 @@ class DatasetHelper: """Continue to send data to device at the beginning of epoch.""" self.iter.continue_send() - def reset(self, step): + def _reset(self, step): """Reset the dataset to the provided step.""" - self.iter.reset(step) + self.iter._reset(step) # pylint: disable=W0212 def get_data_info(self): """ @@ -387,10 +387,11 @@ class _DatasetIter: self.stop_send = dataset.__transfer_dataset__.stop_send self.release = dataset.__transfer_dataset__.release self.continue_send = dataset.__transfer_dataset__.continue_send - self.reset = dataset.__transfer_dataset__.reset self.get_data_info = dataset.__transfer_dataset__.get_data_info self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) + if hasattr(dataset.__transfer_dataset__, "_reset"): + self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212 def __iter__(self): self.index = 0 diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index e9763df5ee3..3af08f27a00 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -162,6 +162,7 @@ SET(DE_UT_SRCS rgba_to_bgr_op_test.cc rgba_to_rgb_op_test.cc schema_test.cc + skip_pushdown_optimization_pass_test.cc slice_op_test.cc sliding_window_op_test.cc solarize_op_test.cc diff --git a/tests/ut/cpp/dataset/skip_pushdown_optimization_pass_test.cc b/tests/ut/cpp/dataset/skip_pushdown_optimization_pass_test.cc new file mode 100644 index 00000000000..e589565f391 --- /dev/null +++ b/tests/ut/cpp/dataset/skip_pushdown_optimization_pass_test.cc @@ -0,0 +1,676 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "common/common.h" +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" +#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h" +#include "minddata/dataset/include/dataset/samplers.h" +#include "minddata/dataset/include/dataset/vision.h" + +using namespace mindspore::dataset; + +class MindDataSkipPushdownTestOptimizationPass : public UT::DatasetOpTesting { + protected: + MindDataSkipPushdownTestOptimizationPass() {} + + /// \brief Compile and compare two datasets + /// \param[in] root_original Original dataset to be added the skip step + /// \param[in] root_target Target dataset for compare + /// \param[in] step Skip step + /// \return Status of the function + Status prepare_trees(std::shared_ptr root_original, std::shared_ptr root_target, int64_t step = 0) { + auto ir_tree = std::make_shared(TreeAdapter::UsageFlag::kDeReset); + // Compile adds a new RootNode to the top of the tree + RETURN_IF_NOT_OK(ir_tree->Compile(root_original->IRNode(), 1, step)); + + auto ir_tree_target = std::make_shared(); + // Compile adds a new RootNode to the top of the tree + RETURN_IF_NOT_OK(ir_tree_target->Compile(root_target->IRNode(), 1, + 0)); // Step is 0 for target node tree + + if (step != 0) { + RETURN_IF_NOT_OK(compare_pass(ir_tree_target->RootIRNode(), ir_tree->RootIRNode())); + } + RETURN_IF_NOT_OK(compare_pass_row(ir_tree_target, ir_tree)); + return Status::OK(); + } + + /// \brief Compare two dataset node trees + /// \param[in] expect Expected node tree for compare + /// \param[in] root Root node tree for compare + /// \return Status of the function + Status compare_pass(std::shared_ptr expect, std::shared_ptr root) { + if (expect->Children().size() == root->Children().size() && expect->Children().size() == 0) { + return Status::OK(); + } + if (expect->Children().size() == root->Children().size() && expect->Children().size() != 0) { + for (int i = 0; i < expect->Children().size(); i++) { + std::string expect_name = expect->Children()[i]->Name(); + std::string root_name = root->Children()[i]->Name(); + CHECK_FAIL_RETURN_UNEXPECTED(expect_name == root_name, + "Expect child is " + expect_name + ", but got " + root_name); + RETURN_IF_NOT_OK(compare_pass(expect->Children()[i], root->Children()[i])); + } + } else { + return Status(StatusCode::kMDUnexpectedError, "Skip Optimization is not working as expected, expect to have " + + std::to_string(expect->Children().size()) + + " operation, but got " + std::to_string(root->Children().size())); + } + return Status::OK(); + } + + /// \brief Compare each row of two dataset node trees + /// \param[in] expect Expected tree for compare + /// \param[in] root Root tree for compare + /// \return Status of the function + Status compare_pass_row(std::shared_ptr expect, std::shared_ptr root) { + TensorRow row_expect; + TensorRow row_root; + RETURN_IF_NOT_OK(expect->GetNext(&row_expect)); + RETURN_IF_NOT_OK(root->GetNext(&row_root)); + while (row_expect.size() != 0 && row_root.size() != 0) { + std::vector> e = row_expect.getRow(); + std::vector> r = row_root.getRow(); + for (int i = 0; i < e.size(); i++) { + nlohmann::json out_json; + RETURN_IF_NOT_OK(e[i]->to_json(&out_json)); + std::stringstream json_ss; + json_ss << out_json; + + nlohmann::json out_json1; + RETURN_IF_NOT_OK(r[i]->to_json(&out_json1)); + std::stringstream json_ss1; + json_ss1 << out_json1; + EXPECT_EQ(json_ss.str(), json_ss1.str()); + } + RETURN_IF_NOT_OK(expect->GetNext(&row_expect)); + RETURN_IF_NOT_OK(root->GetNext(&row_root)); + } + EXPECT_EQ(row_expect.size(), row_root.size()); + return Status::OK(); + } +}; + +TensorRow VecToRow(const MSTensorVec &v); + +MSTensorVec RowToVec(const TensorRow &v); + +MSTensorVec Predicate1(MSTensorVec in); + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Sampler in MappableSourceNode +/// Expectation: Skip node is pushed down and removed after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownMappableSourceNode) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownMappableSourceNode."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + auto root = ImageFolder(folder_path, false, std::make_shared()); + auto root_target = ImageFolder(folder_path, false, std::make_shared(2)); + EXPECT_OK(prepare_trees(root, root_target, 2)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Batch Operation +/// Expectation: Skip node is pushed down and removed after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownBatch) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownBatch."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Batch(5)->Skip(5); + root_target = ImageFolder(folder_path, false, std::make_shared(25))->Batch(5); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Batch(5); + root_target = ImageFolder(folder_path, false, std::make_shared(25))->Batch(5); + EXPECT_OK(prepare_trees(root, root_target, 5)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Rename Operation +/// Expectation: Skip node is pushed down and removed after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRename) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRename."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = + ImageFolder(folder_path, false, std::make_shared())->Rename({"label"}, {"fake_label"})->Skip(5); + root_target = + ImageFolder(folder_path, false, std::make_shared(5))->Rename({"label"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Rename({"label"}, {"fake_label"}); + root_target = + ImageFolder(folder_path, false, std::make_shared(5))->Rename({"label"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 5)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Project Operation +/// Expectation: Skip node is pushed down and removed for Project after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownProject) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownProject."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Project({"label", "image"})->Skip(10); + root_target = ImageFolder(folder_path, false, std::make_shared(10))->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Project({"label", "image"}); + root_target = ImageFolder(folder_path, false, std::make_shared(10))->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 10)); + + root = ImageFolder(folder_path, false, std::make_shared()) + ->Skip(1) + ->Project({"label", "image"}) + ->Skip(10); + root_target = ImageFolder(folder_path, false, std::make_shared(11))->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 0)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Concat Operation +/// Expectation: Skip node cannot be pushed down for Concat after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownConcat) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownConcat."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + std::vector> datasets = { + ImageFolder(folder_path, false, std::make_shared())}; + root = ImageFolder(folder_path, false, std::make_shared())->Concat(datasets)->Skip(10); + root_target = ImageFolder(folder_path, false, std::make_shared())->Concat(datasets)->Skip(10); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Concat(datasets); + root_target = ImageFolder(folder_path, false, std::make_shared())->Concat(datasets)->Skip(10); + EXPECT_OK(prepare_trees(root, root_target, 10)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Zip Operation +/// Expectation: Skip node cannot be pushed down for Zip after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownZip) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownZip."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + std::vector> datasets = { + ImageFolder(folder_path, false, std::make_shared())->Project({"label"})}; + root = ImageFolder(folder_path, false, std::make_shared())->Project({"image"})->Zip(datasets); + root_target = + ImageFolder(folder_path, false, std::make_shared())->Project({"image"})->Zip(datasets)->Skip(10); + EXPECT_OK(prepare_trees(root, root_target, 10)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Repeat Operation +/// Expectation: Skip operation cannot be pushed down for Repeat after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRepeat) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRepeat."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Repeat(5); + root_target = ImageFolder(folder_path, false, std::make_shared())->Repeat(5)->Skip(11); + EXPECT_OK(prepare_trees(root, root_target, 11)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Take Operation +/// Expectation: Skip operation cannot be pushed down for Take after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownTake) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownTake."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Take(20); + root_target = ImageFolder(folder_path, false, std::make_shared())->Take(20)->Skip(10); + EXPECT_OK(prepare_trees(root, root_target, 10)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Skip Operation +/// Expectation: Skip node cannot be pushed down after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSkip) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSkip."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(2)->Skip(3); + root_target = ImageFolder(folder_path, false, std::make_shared())->Skip(2)->Skip(3)->Skip(5); + EXPECT_OK(prepare_trees(root, root_target, 5)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with NonMappableSourceNode(CSV) +/// Expectation: Skip node is pushed down after optimization pass, but cannot be removed +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownNonMappableSourceNode) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownNonMappableSourceNode."; + std::string folder_path = datasets_root_path_ + "/testCSV/append.csv"; + + std::shared_ptr root; + std::shared_ptr root_target; + + std::vector column_names = {"col1", "col2", "col3", "col4"}; + root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Rename({"col1"}, {"fake_label"}); + root_target = + CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Skip(1)->Rename({"col1"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); + + root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Rename({"col1"}, {"fake_label"})->Skip(1); + root_target = + CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Skip(1)->Rename({"col1"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 0)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Batch and Rename Operations +/// Expectation: Skip node is pushed down after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations1) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations1."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared()) + ->Batch(5) + ->Skip(2) + ->Rename({"label"}, {"fake_label"}); + root_target = ImageFolder(folder_path, false, std::make_shared()) + ->Batch(5) + ->Skip(2) + ->Skip(2) + ->Rename({"label"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 2)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Batch and Skip Operations and Sampler +/// Expectation: Skip node is pushed down after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations2) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations2."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(2)->Batch(5); + root_target = ImageFolder(folder_path, false, std::make_shared())->Skip(2)->Skip(10)->Batch(5); + EXPECT_OK(prepare_trees(root, root_target, 2)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Take and Project Operations +/// Expectation: Skip node is pushed down for Project but not for Take +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations3) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations3."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Take(20)->Project({"label", "image"}); + root_target = ImageFolder(folder_path, false, std::make_shared()) + ->Take(20) + ->Skip(2) + ->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 2)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with no Skip/ Skip(0) Operation +/// Expectation: Skip(0) shows the same result as no Skip operation +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSkip0) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSkip0."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Project({"label", "image"})->Take(5); + root_target = + ImageFolder(folder_path, false, std::make_shared())->Project({"label", "image"})->Take(5); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Project({"label", "image"})->Skip(0); + root_target = ImageFolder(folder_path, false, std::make_shared())->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared()) + ->Skip(0) + ->Project({"label", "image"}) + ->Skip(0); + root_target = ImageFolder(folder_path, false, std::make_shared())->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(0)->Project({"label", "image"}); + root_target = + ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); + + root = ImageFolder(folder_path, false, std::make_shared()) + ->Skip(2) + ->Skip(1) + ->Project({"label", "image"}); + root_target = ImageFolder(folder_path, false, std::make_shared()) + ->Skip(2) + ->Skip(1) + ->Skip(1) + ->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Repeat Operation +/// Expectation: Skip operation cannot be pushed down for Repeat operation +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownRepeat2) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownRepeat2."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Repeat(3)->Skip(1); + root_target = ImageFolder(folder_path, false, std::make_shared())->Repeat(3)->Skip(1); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Repeat(3); + root_target = ImageFolder(folder_path, false, std::make_shared())->Repeat(3)->Skip(50); + EXPECT_OK(prepare_trees(root, root_target, 50)); + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Repeat(3); + root_target = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Repeat(3)->Skip(50); + EXPECT_OK(prepare_trees(root, root_target, 50)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Take Operation +/// Expectation: Skip operation cannot be pushed down for Take operation +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownTake2) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownTake2."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Take(3); + root_target = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Take(3); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Take(3); + root_target = ImageFolder(folder_path, false, std::make_shared())->Take(3)->Skip(1); + EXPECT_OK(prepare_trees(root, root_target, 1)); + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Take(5); + root_target = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Take(5)->Skip(1); + EXPECT_OK(prepare_trees(root, root_target, 1)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Concat/Zip Operation +/// Expectation: Skip node cannot be removed for Concat/Zip after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownUnsupported) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownUnsupported."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + std::vector> datasets = { + ImageFolder(folder_path, false, std::make_shared())->Project({"label"})}; + root = ImageFolder(folder_path, false, std::make_shared())->Project({"label"})->Concat(datasets); + root_target = ImageFolder(folder_path, false, std::make_shared()) + ->Project({"label"}) + ->Concat(datasets) + ->Skip(2); + EXPECT_OK(prepare_trees(root, root_target, 2)); + + root = + ImageFolder(folder_path, false, std::make_shared())->Project({"image"})->Zip(datasets)->Skip(1); + root_target = + ImageFolder(folder_path, false, std::make_shared())->Project({"image"})->Zip(datasets)->Skip(1); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Project({"image"})->Zip(datasets); + root_target = + ImageFolder(folder_path, false, std::make_shared())->Project({"image"})->Zip(datasets)->Skip(1); + EXPECT_OK(prepare_trees(root, root_target, 1)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Filter Operation +/// Expectation: Skip node cannot be pushed down for Filter after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownUnsupported_Filter) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownUnsupported_Filter."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Filter(Predicate1, {"label"}); + root_target = + ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Filter(Predicate1, {"label"}); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Filter(Predicate1, {"label"})->Skip(1); + root_target = + ImageFolder(folder_path, false, std::make_shared())->Filter(Predicate1, {"label"})->Skip(1); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Filter(Predicate1, {"label"}); + root_target = ImageFolder(folder_path, false, std::make_shared()) + ->Skip(1) + ->Filter(Predicate1, {"label"}) + ->Skip(1); + EXPECT_OK(prepare_trees(root, root_target, 1)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with SubsetSampler as child +/// Expectation: Skip node is removed for Rename/Project/Map after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownSubsetSampler) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownSubsetSampler."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + std::vector indices = {0, 1, 2, 3, 4, 5}; + root = ImageFolder(folder_path, false, std::make_shared(indices, 3))->Skip(1); + auto sampler = std::make_shared(1); + sampler->AddChild(std::make_shared(indices, 3)); + root_target = ImageFolder(folder_path, false, sampler); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared(indices, 4)) + ->Rename({"label"}, {"fake_label"}); + sampler = std::make_shared(1); + sampler->AddChild(std::make_shared(indices, 4)); + root_target = ImageFolder(folder_path, false, sampler)->Rename({"label"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); + + root = + ImageFolder(folder_path, false, std::make_shared(indices, 10))->Project({"label", "image"}); + sampler = std::make_shared(1); + sampler->AddChild(std::make_shared(indices, 10)); + root_target = ImageFolder(folder_path, false, sampler)->Project({"label", "image"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); + + std::vector> transforms; + std::vector size = {80, 80}; + std::vector ignore = {20, 20, 20, 20}; + std::shared_ptr operation1 = std::make_shared(0.5, ignore); + std::shared_ptr operation2 = std::make_shared(size); + transforms.push_back(operation1); + transforms.push_back(operation2); + root = ImageFolder(folder_path, true, std::make_shared(indices, 3))->Map(transforms); + sampler = std::make_shared(1); + sampler->AddChild(std::make_shared(indices, 3)); + root_target = ImageFolder(folder_path, true, sampler)->Map(transforms); + EXPECT_OK(prepare_trees(root, root_target, 1)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with Batch Operation +/// Expectation: Skip node is pushed down for Batch after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownBatch2) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownBatch2."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Batch(5); + root_target = ImageFolder(folder_path, false, std::make_shared())->Skip(1)->Batch(5); + EXPECT_OK(prepare_trees(root, root_target, 0)); + + root = ImageFolder(folder_path, false, std::make_shared())->Skip(3)->Batch(5); + root_target = ImageFolder(folder_path, false, std::make_shared())->Skip(3)->Skip(15)->Batch(5); + EXPECT_OK(prepare_trees(root, root_target, 3)); + + root = ImageFolder(folder_path, false, std::make_shared())->Batch(5); + root_target = ImageFolder(folder_path, false, std::make_shared(15))->Batch(5); + EXPECT_OK(prepare_trees(root, root_target, 3)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with SkipPushdownNonMappableSourceNode(CSV) +/// Expectation: Skip node is pushed down for Rename/Batch/Project after optimization pass, but cannot be removed for +/// NonMappableSourceNode(CSV) +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownNonMappableSourceNode2) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownNonMappableSourceNode2."; + std::string folder_path = datasets_root_path_ + "/testCSV/append.csv"; + std::vector column_names = {"col1", "col2", "col3", "col4"}; + + std::shared_ptr root; + std::shared_ptr root_target; + + root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Rename({"col1"}, {"fake_label"}); + root_target = + CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Skip(1)->Rename({"col1"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); + + root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse)->Skip(1)->Rename({"col1"}, {"fake_label"}); + root_target = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse) + ->Skip(1) + ->Skip(1) + ->Rename({"col1"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); + + root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse) + ->Repeat(5) + ->Skip(1) + ->Batch(2) + ->Project({"col1", "col2"}) + ->Rename({"col1"}, {"fake_label"}); + root_target = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse) + ->Repeat(5) + ->Skip(1) + ->Skip(2) + ->Batch(2) + ->Project({"col1", "col2"}) + ->Rename({"col1"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); +} + +/// Feature: MindData Skip Pushdown Optimization Pass Test +/// Description: Test MindData Skip Pushdown Optimization Pass with combined Operations +/// Expectation: Skip node is removed/reduced for Rename/Batch after optimization pass +TEST_F(MindDataSkipPushdownTestOptimizationPass, SkipPushdownCombineOperations4) { + MS_LOG(INFO) << "Doing MindDataSkipPushdownTestOptimizationPass-SkipPushdownCombineOperations4."; + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + + std::shared_ptr root; + std::shared_ptr root_target; + + std::vector indices = {0, 1, 2, 3, 4, 5}; + root = ImageFolder(folder_path, false, std::make_shared(indices, 6)) + ->Repeat(4) + ->Skip(4) + ->Take(10) + ->Skip(2) + ->Take(2) + ->Rename({"label"}, {"fake_label"}); + auto sampler = std::make_shared(1); + sampler->AddChild(std::make_shared(indices, 6)); + root_target = ImageFolder(folder_path, false, sampler) + ->Repeat(4) + ->Skip(4) + ->Take(10) + ->Skip(2) + ->Take(2) + ->Skip(1) + ->Rename({"label"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 1)); + + root = ImageFolder(folder_path, false, std::make_shared(indices, 5)) + ->Repeat(8) + ->Skip(4) + ->Skip(2) + ->Take(10) + ->Rename({"label"}, {"fake_label"}) + ->Batch(4); + sampler = std::make_shared(1); + sampler->AddChild(std::make_shared(indices, 5)); + root_target = ImageFolder(folder_path, false, std::make_shared(indices, 5)) + ->Repeat(8) + ->Skip(4) + ->Skip(2) + ->Take(10) + ->Skip(8) + ->Rename({"label"}, {"fake_label"}) + ->Batch(4); + EXPECT_OK(prepare_trees(root, root_target, 2)); + + folder_path = datasets_root_path_ + "/testCSV/append.csv"; + std::vector column_names = {"col1", "col2", "col3", "col4"}; + + root = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse) + ->Skip(1) + ->Repeat(4) + ->Batch(3) + ->Skip(1) + ->Rename({"col1"}, {"fake_label"}); + root_target = CSV({folder_path}, ',', {}, column_names, 0, ShuffleMode::kFalse) + ->Skip(1) + ->Repeat(4) + ->Skip(3) + ->Batch(3) + ->Rename({"col1"}, {"fake_label"}); + EXPECT_OK(prepare_trees(root, root_target, 0)); +}