From ff5999fc2f24dae663bc75fb803dadeefb89777c Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Mon, 16 Nov 2020 16:16:56 -0500 Subject: [PATCH] add removal pass for getters fix CI round I fix ci round II address review cmts fix ci round II --- .../dataset/callback/callback_manager.h | 3 + .../dataset/engine/datasetops/dataset_op.h | 3 + .../minddata/dataset/engine/execution_tree.cc | 19 ++- .../minddata/dataset/engine/execution_tree.h | 9 +- .../dataset/engine/opt/CMakeLists.txt | 3 +- .../dataset/engine/opt/pre/getter_pass.cc | 87 +++++++++++ .../dataset/engine/opt/pre/getter_pass.h | 76 +++++++++ .../minddata/dataset/engine/tree_adapter.cc | 2 +- tests/ut/cpp/dataset/CMakeLists.txt | 145 +++++++++--------- .../ut/cpp/dataset/optimization_pass_test.cc | 137 +++++++++++++++++ 10 files changed, 404 insertions(+), 80 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.h create mode 100644 tests/ut/cpp/dataset/optimization_pass_test.cc diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h index 0fedb97115c..5f36dad1cc6 100644 --- a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h @@ -41,6 +41,9 @@ class CallbackManager { /// \param [in] callbacks list of callbacks to perform void AddCallbacks(std::vector> callbacks); + /// \brief set callbacks to empty + void ClearCallbacks() { callbacks_.clear(); } + /// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true /// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads /// \return Status diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index d310b580218..4aa2118c4e1 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -393,6 +393,9 @@ class DatasetOp : public std::enable_shared_from_this { /// \brief Add callback to DatasetOp, only MapOp supports Callback at the moment void AddCallbacks(std::vector> callbacks) { callback_manager_.AddCallbacks(callbacks); } + /// \brief Remove all callbacks from DatasetOp + void ClearCallbacks() { callback_manager_.ClearCallbacks(); } + protected: /// \brief Removes a parent operator from this operator /// \notes External callers do not have access to this function diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc index 6819dce3893..e3519baf17e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc @@ -16,6 +16,7 @@ #include "minddata/dataset/engine/execution_tree.h" #include #include +#include #include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/engine/datasetops/shuffle_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h" @@ -35,7 +36,7 @@ namespace mindspore { namespace dataset { // Constructor -ExecutionTree::ExecutionTree() : id_count_(0) { +ExecutionTree::ExecutionTree() : id_count_(0), pre_pass_override_(nullptr) { tg_ = std::make_unique(); tree_state_ = kDeTStateInit; prepare_flags_ = kDePrepNone; @@ -234,7 +235,6 @@ Status ExecutionTree::PrepareTreePreAction() { bool modified = false; std::vector> pre_actions; // Construct pre actions - MS_LOG(INFO) << "Running pre pass loops."; #ifndef ENABLE_ANDROID pre_actions.push_back(std::make_unique()); #endif @@ -243,6 +243,17 @@ Status ExecutionTree::PrepareTreePreAction() { #ifndef ENABLE_ANDROID pre_actions.push_back(std::make_unique()); #endif + + // this offers a way to override the preset optimization pass with customized ones + // this is used 
when certain nodes are removed for tree getters + if (pre_pass_override_) { + MS_LOG(INFO) << "Default pre optimization passes is being overridden," + << " number of passes before the override:" << pre_actions.size() << "."; + pre_actions = pre_pass_override_(std::move(pre_actions)); + } + + MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops."; + // Apply pre action passes for (auto &pass : pre_actions) { RETURN_IF_NOT_OK(pass->Run(this, &modified)); @@ -256,7 +267,7 @@ Status ExecutionTree::PrepareTreePostAction() { tree_state_ = kDeTStatePrepare; bool modified = false; - std::vector> post_actions; + OptPass post_actions; // Construct pre actions MS_LOG(INFO) << "Running post pass loops."; #ifndef ENABLE_ANDROID @@ -274,7 +285,7 @@ Status ExecutionTree::PrepareTreePostAction() { Status ExecutionTree::Optimize() { // Vector of optimizations, currently only 1, add more as necessary - std::vector> optimizations; + OptPass optimizations; #ifndef ENABLE_ANDROID optimizations.push_back(std::make_unique()); #endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h index ed58b79a846..aaa47279613 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h @@ -24,13 +24,13 @@ #include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/util/status.h" #include "mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h" - namespace mindspore { namespace dataset { // Forward declares class TaskGroup; class DatasetOp; - +class Pass; +using OptPass = std::vector>; class ExecutionTree { public: // Prepare flags used during tree prepare phase @@ -253,6 +253,10 @@ class ExecutionTree { // @return total number of epochs int32_t num_epochs() { return num_epochs_; } + // set the function ptr that overrides the pre-pass which allows caller to adjust the existing pre_pass and + // introduce new 
passes. E.g. caller can override the num_epoch in EpochInjectionPass + void SetPrePassOverride(std::function pre_pass_override) { pre_pass_override_ = pre_pass_override; } + private: // A helper functions for doing the recursive printing // @param dataset_op - The dataset op to print @@ -270,6 +274,7 @@ class ExecutionTree { int32_t num_epochs_; // Total number of epochs to run for this tree std::unique_ptr profiling_manager_; // Profiling manager bool optimize_; // Flag to enable optional optimizations + std::function pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare() }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index c8c591a4968..7ad4da248f8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -1,13 +1,14 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(engine-opt OBJECT + optional/tensor_op_fusion_pass.cc pass.cc post/repeat_pass.cc pre/cache_error_pass.cc pre/cache_transform_pass.cc pre/epoch_injection_pass.cc + pre/getter_pass.cc pre/input_validation_pass.cc pre/removal_pass.cc - optional/tensor_op_fusion_pass.cc util/printer_pass.cc ) diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.cc new file mode 100644 index 00000000000..c5bab40f952 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/opt/pre/getter_pass.h" +#include "minddata/dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr node, bool *modified) { + nodes_to_remove_.push_back(node); + return Status::OK(); +} + +Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr node, bool *modified) { + if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node); + return Status::OK(); +} + +Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr node, bool *modified) { + if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node); + return Status::OK(); +} + +Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr node, bool *modified) { + if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node); + return Status::OK(); +} + +Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr node, bool *modified) { + if (type_ == kOutputShapeAndType) { + nodes_to_clear_callback_.push_back(node); + } else if (type_ == kDatasetSize) { + nodes_to_remove_.push_back(node); + } + return Status::OK(); +} + +Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr node, bool *modified) { + if (type_ == kDatasetSize) nodes_to_remove_.push_back(node); + return Status::OK(); +} + +Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr node, bool *modified) { + if (type_ == kDatasetSize) nodes_to_remove_.push_back(node); + return Status::OK(); +} + +Status GetterPass::GetterNodes::PreRunOnNode(std::shared_ptr node, bool *modified) { + if (type_ == 
kOutputShapeAndType) nodes_to_remove_.push_back(node); + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr node, bool *modified) { + if (type_ == kOutputShapeAndType) nodes_to_remove_.push_back(node); + return Status::OK(); +} +#endif + +Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) { + RETURN_IF_NOT_OK(pass_.Run(tree, modified)); + + // nested private class variables can be directly accessed by its outer class + for (auto node : pass_.nodes_to_remove_) { + RETURN_IF_NOT_OK(node->Remove()); + } + + // clear the callback for selected ops (map when its GetOutputType/Shape) + for (auto node : pass_.nodes_to_clear_callback_) node->ClearCallbacks(); + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.h new file mode 100644 index 00000000000..ce4d3ea1904 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_ + +#include +#include +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class DatasetOp; + +/// \class GetterPass +/// \brief This is a tree pass that will remove nodes or clear the callback in MapOp +class GetterPass : public TreePass { + public: + enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 }; + /// \brief Constructor + explicit GetterPass(GetterType tp) : pass_(tp) {} + /// \brief Destructor + ~GetterPass() = default; + + Status RunOnTree(ExecutionTree *tree, bool *modified) override; + + private: + /// \class GetterNodes, this is a nested class which is owned via composition by the outer class to identify nodes + /// \brief This is a NodePass whose job is to identify which nodes should be removed. + class GetterNodes : public NodePass { + public: + /// \brief Constructor + explicit GetterNodes(GetterType tp) : type_(tp) {} + + ~GetterNodes() = default; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + Status RunOnNode(std::shared_ptr node, bool *modified) override; + Status RunOnNode(std::shared_ptr node, bool *modified) override; + Status RunOnNode(std::shared_ptr node, bool *modified) override; + Status RunOnNode(std::shared_ptr node, bool *modified) override; + Status RunOnNode(std::shared_ptr node, bool *modified) override; + Status RunOnNode(std::shared_ptr node, bool *modified) override; + // whether this is Run or PreRun does not matter here, however, only Accept() is defined in ConcatOp + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + +#ifdef ENABLE_PYTHON + Status RunOnNode(std::shared_ptr node, bool *modified) override; +#endif + + GetterType type_; + std::list> nodes_to_clear_callback_; + std::list> nodes_to_remove_; + }; + // 
outer class needs only to own the inner class object since it automatically has access to its private variables + GetterNodes pass_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_GETTER_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index 04898d61e33..c94be7e2c7c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -95,7 +95,7 @@ Status TreeAdapter::PostPass(std::shared_ptr ir) { } Status TreeAdapter::BuildExecutionTree(std::shared_ptr ir, std::shared_ptr *op) { - // Build the DatasetOp ExecutionTree from the optmized IR tree + // Build the DatasetOp ExecutionTree from the optimized IR tree std::vector> ops = ir->Build(); CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node."); diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 051bea37882..57be1f4457e 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -1,57 +1,98 @@ include(GoogleTest) SET(DE_UT_SRCS - common/common.cc - common/cvop_common.cc - common/bboxop_common.cc - auto_contrast_op_test.cc album_op_test.cc + arena_test.cc + auto_contrast_op_test.cc batch_op_test.cc bit_functions_test.cc - storage_container_test.cc - treap_test.cc - interrupt_test.cc - image_folder_op_test.cc - buddy_test.cc bounding_box_augment_op_test.cc - arena_test.cc btree_test.cc + buddy_test.cc + build_vocab_test.cc + c_api_cache_test.cc + c_api_dataset_album_test.cc + c_api_dataset_cifar_test.cc + c_api_dataset_clue_test.cc + c_api_dataset_coco_test.cc + c_api_dataset_config_test.cc + c_api_dataset_csv_test.cc + c_api_dataset_iterator_test.cc + c_api_dataset_manifest_test.cc + c_api_dataset_minddata_test.cc + c_api_dataset_ops_test.cc + c_api_dataset_randomdata_test.cc + c_api_dataset_save.cc + 
c_api_dataset_textfile_test.cc + c_api_dataset_tfrecord_test.cc + c_api_dataset_voc_test.cc + c_api_datasets_test.cc + c_api_samplers_test.cc + c_api_text_sentence_piece_vocab_test.cc + c_api_text_vocab_test.cc + c_api_transforms_test.cc + c_api_vision_test.cc callback_test.cc + celeba_op_test.cc center_crop_op_test.cc channel_swap_test.cc + cifar_op_test.cc circular_pool_test.cc client_config_test.cc + clue_op_test.cc + coco_op_test.cc + common/bboxop_common.cc + common/common.cc + common/cvop_common.cc + concat_op_test.cc + concatenate_op_test.cc connector_test.cc - cutmix_batch_op_test.cc + csv_op_test.cc cut_out_op_test.cc + cutmix_batch_op_test.cc + cyclic_array_test.cc + data_helper_test.cc datatype_test.cc decode_op_test.cc + distributed_sampler_test.cc + epoch_ctrl_op_test.cc equalize_op_test.cc execution_tree_test.cc + fill_op_test.cc global_context_test.cc + gnn_graph_test.cc + image_folder_op_test.cc + image_process_test.cc + interrupt_test.cc + jieba_tokenizer_op_test.cc main_test.cc map_op_test.cc + mask_test.cc + memory_pool_test.cc mind_record_op_test.cc mixup_batch_op_test.cc - memory_pool_test.cc + mnist_op_test.cc normalize_op_test.cc one_hot_op_test.cc + optimization_pass_test.cc pad_end_op_test.cc pad_op_test.cc path_test.cc + perf_data_test.cc project_op_test.cc queue_test.cc random_affine_op_test.cc + random_color_adjust_op_test.cc random_color_op_test.cc - random_crop_op_test.cc - random_crop_with_bbox_op_test.cc - random_crop_decode_resize_op_test.cc random_crop_and_resize_op_test.cc random_crop_and_resize_with_bbox_op_test.cc - random_color_adjust_op_test.cc + random_crop_decode_resize_op_test.cc + random_crop_op_test.cc + random_crop_with_bbox_op_test.cc random_horizontal_flip_op_test.cc random_horizontal_flip_with_bbox_test.cc random_resize_op_test.cc + random_resize_op_test.cc random_resize_with_bbox_op_test.cc random_rotation_op_test.cc random_solarize_op_test.cc @@ -65,74 +106,34 @@ SET(DE_UT_SRCS rgba_to_bgr_op_test.cc 
rgba_to_rgb_op_test.cc schema_test.cc - skip_op_test.cc + sentence_piece_vocab_op_test.cc shuffle_op_test.cc + skip_op_test.cc + slice_op_test.cc + sliding_window_op_test.cc + solarize_op_test.cc stand_alone_samplers_test.cc status_test.cc + storage_container_test.cc + subset_random_sampler_test.cc + swap_red_blue_test.cc + take_op_test.cc task_manager_test.cc + tensor_op_fusion_pass_test.cc tensor_row_test.cc tensor_string_test.cc tensor_test.cc tensorshape_test.cc + text_file_op_test.cc tfReader_op_test.cc to_float16_op_test.cc - tree_adapter_test.cc - type_cast_op_test.cc - zip_op_test.cc - random_resize_op_test.cc - subset_random_sampler_test.cc - weighted_random_sampler_test.cc - mnist_op_test.cc - cifar_op_test.cc - celeba_op_test.cc - take_op_test.cc - clue_op_test.cc - csv_op_test.cc - text_file_op_test.cc - concat_op_test.cc - jieba_tokenizer_op_test.cc tokenizer_op_test.cc - gnn_graph_test.cc - coco_op_test.cc - fill_op_test.cc - mask_test.cc + treap_test.cc + tree_adapter_test.cc trucate_pair_test.cc - concatenate_op_test.cc - cyclic_array_test.cc - perf_data_test.cc - build_vocab_test.cc - c_api_samplers_test.cc - c_api_transforms_test.cc - c_api_vision_test.cc - c_api_dataset_ops_test.cc - c_api_dataset_album_test.cc - c_api_dataset_cifar_test.cc - c_api_dataset_clue_test.cc - c_api_dataset_coco_test.cc - c_api_dataset_config_test.cc - c_api_dataset_csv_test.cc - c_api_dataset_manifest_test.cc - c_api_dataset_minddata_test.cc - c_api_dataset_randomdata_test.cc - c_api_dataset_save.cc - c_api_dataset_textfile_test.cc - c_api_dataset_tfrecord_test.cc - c_api_dataset_voc_test.cc - c_api_datasets_test.cc - c_api_dataset_iterator_test.cc - c_api_text_sentence_piece_vocab_test.cc - c_api_text_vocab_test.cc - c_api_cache_test.cc - tensor_op_fusion_pass_test.cc - sliding_window_op_test.cc - epoch_ctrl_op_test.cc - sentence_piece_vocab_op_test.cc - solarize_op_test.cc - swap_red_blue_test.cc - distributed_sampler_test.cc - data_helper_test.cc - 
image_process_test.cc - slice_op_test.cc + type_cast_op_test.cc + weighted_random_sampler_test.cc + zip_op_test.cc ) if (ENABLE_PYTHON) diff --git a/tests/ut/cpp/dataset/optimization_pass_test.cc b/tests/ut/cpp/dataset/optimization_pass_test.cc new file mode 100644 index 00000000000..50b3e31190d --- /dev/null +++ b/tests/ut/cpp/dataset/optimization_pass_test.cc @@ -0,0 +1,137 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/core/client.h" +#include "common/common.h" +#include "gtest/gtest.h" + +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" +#include "minddata/dataset/engine/opt/pre/getter_pass.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::MsLogLevel::INFO; + +class MindDataTestOptimizationPass : public UT::DatasetOpTesting { + public: + MindDataTestOptimizationPass() = default; + void SetUp() override { GlobalInit(); } + + // this recursive function helps build a ExecutionTree from a IR node, it is copied from TreeAdapter + Status DFSBuild(std::shared_ptr ir, std::shared_ptr *op, ExecutionTree *tree) { + std::vector> ops = ir->Build(); + CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty() && tree != nullptr && op != nullptr, "Fail To Build Tree."); + (*op) = ops.front(); + RETURN_IF_NOT_OK(tree->AssociateNode(*op)); + for (size_t i = 1; i < 
ops.size(); i++) { + RETURN_IF_NOT_OK(tree->AssociateNode(ops[i])); + RETURN_IF_NOT_OK(ops[i - 1]->AddChild(ops[i])); + } + for (std::shared_ptr child_ir : ir->Children()) { + std::shared_ptr child_op; + RETURN_IF_NOT_OK(DFSBuild(child_ir, &child_op, tree)); + RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops + } + return Status::OK(); + } + + // this function will build an execution_tree from a root ir node. nullptr will be returned if error occurs + std::unique_ptr BuildTree(std::shared_ptr ir) { + std::unique_ptr tree = std::make_unique(); + std::shared_ptr root; + if (DFSBuild(ir, &root, tree.get()).IsError()) return nullptr; + if (tree->AssignRoot(root).IsError()) return nullptr; + return tree; + } +}; + +TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) { + MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestOutputShapeAndTypePass."; + // config leaf_op, use random_data to avoid I/O + std::shared_ptr schema = std::make_shared(); + ASSERT_TRUE(schema->add_column("label", "uint32", {})); + std::shared_ptr ds = RandomData(44, schema)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2); + + std::unique_ptr exe_tree = BuildTree(ds->IRNode()); + + ASSERT_NE(exe_tree, nullptr); + + // test the optimization pass + // OptPass is supposed to remove concat, filter repeat, shuffle skip, take and set the callback of map to empty + std::function pass = [](OptPass pre) { + // return a new pass, this will override all the existing pre-pass es + pre.clear(); + pre.push_back(std::make_unique(GetterPass::kOutputShapeAndType)); + return pre; + }; + + exe_tree->SetPrePassOverride(pass); + ASSERT_OK(exe_tree->PrepareTreePreAction()); + std::stringstream ss; + + // print the tree in std::string as a way to verify that nodes are indeed removed + exe_tree->Print(ss); + std::string ss_str = ss.str(); + + // ss_str would look like this + // +- ( 0) : [workers: 4] [batch size: 2] + // +- ( 2) : [workers: 0 
(inlined)] + // +- ( 4) : [workers: 4] [total rows: 44] + // + + // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not + EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); + EXPECT_EQ(ss_str.find("RepeatOp"), ss_str.npos); + EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); + EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); +} + +TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) { + MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestDatasetSizePass."; + // config leaf_op, use random_data to avoid I/O + std::shared_ptr schema = std::make_shared(); + ASSERT_TRUE(schema->add_column("label", "uint32", {})); + std::shared_ptr ds = RandomData(44, schema)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2); + + std::unique_ptr exe_tree = BuildTree(ds->IRNode()); + + ASSERT_NE(exe_tree, nullptr); + + // test the optimization pass + // for kDatasetSize this test expects Shuffle and Project to be removed while Repeat and Batch stay in the tree + std::function pass = [](OptPass pre) { + // return a new pass, this will override all the existing pre-passes + pre.clear(); // remove all existing pre-passes + pre.push_back(std::make_unique(GetterPass::kDatasetSize)); + return pre; + }; + + exe_tree->SetPrePassOverride(pass); + ASSERT_OK(exe_tree->PrepareTreePreAction()); + std::stringstream ss; + // print the tree in std::string as a way to verify that nodes are indeed removed + exe_tree->Print(ss); + std::string ss_str = ss.str(); + + // verify that Shuffle and ProjectOp are removed, but Batch and RepeatOp are not + EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); + EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); + EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos); + EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); +}