From 9fb1904ed81bab64c5ca2297913e34d4da3b4fed Mon Sep 17 00:00:00 2001 From: Nat Sutyanyong Date: Tue, 21 Jul 2020 22:55:18 -0400 Subject: [PATCH] Refactoring opt/pre --- .../minddata/dataset/engine/execution_tree.cc | 4 +- .../dataset/engine/opt/CMakeLists.txt | 4 +- .../dataset/engine/opt/pre/cache_pass.cc | 181 ------------------ .../dataset/engine/opt/pre/cache_pass.h | 141 -------------- .../engine/opt/pre/cache_transform_pass.cc | 174 ++++++++++++++++- .../engine/opt/pre/cache_transform_pass.h | 125 +++++++++++- ...ection_pass.cc => epoch_injection_pass.cc} | 61 +++--- ...njection_pass.h => epoch_injection_pass.h} | 31 +-- .../dataset/engine/opt/pre/removal_nodes.cc | 58 ------ .../dataset/engine/opt/pre/removal_nodes.h | 64 ------- .../dataset/engine/opt/pre/removal_pass.cc | 40 +++- .../dataset/engine/opt/pre/removal_pass.h | 46 ++++- .../dataset/test_minddataset_exception.py | 10 +- 13 files changed, 407 insertions(+), 532 deletions(-) delete mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc delete mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h rename mindspore/ccsrc/minddata/dataset/engine/opt/pre/{injection_pass.cc => epoch_injection_pass.cc} (54%) rename mindspore/ccsrc/minddata/dataset/engine/opt/pre/{injection_pass.h => epoch_injection_pass.h} (75%) delete mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc delete mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc index 16039012f97..79c954595ec 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc @@ -23,7 +23,7 @@ #include "minddata/dataset/engine/opt/pre/removal_pass.h" #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" #include "minddata/dataset/engine/opt/post/repeat_pass.h" 
-#include "minddata/dataset/engine/opt/pre/injection_pass.h" +#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/engine/perf/monitor.h" @@ -225,7 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() { std::vector> pre_actions; // Construct pre actions MS_LOG(INFO) << "Running pre pass loops."; - pre_actions.push_back(std::make_unique()); + pre_actions.push_back(std::make_unique()); pre_actions.push_back(std::make_unique()); pre_actions.push_back(std::make_unique()); // Apply pre action passes diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index c7f861b75ff..50346ffad81 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -3,10 +3,8 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE add_library(engine-opt OBJECT pass.cc post/repeat_pass.cc - pre/cache_pass.cc pre/cache_transform_pass.cc - pre/injection_pass.cc - pre/removal_nodes.cc + pre/epoch_injection_pass.cc pre/removal_pass.cc optional/tensor_op_fusion_pass.cc util/printer_pass.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc deleted file mode 100644 index f4453f25b1a..00000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "minddata/dataset/engine/opt/pre/cache_pass.h" -#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" -#include "minddata/dataset/engine/datasetops/cache_op.h" -#include "minddata/dataset/engine/datasetops/source/celeba_op.h" -#include "minddata/dataset/engine/datasetops/source/generator_op.h" -#include "minddata/dataset/engine/datasetops/source/manifest_op.h" -#include "minddata/dataset/engine/datasetops/source/mnist_op.h" -#include "minddata/dataset/engine/datasetops/source/voc_op.h" -#include "minddata/dataset/engine/datasetops/source/cifar_op.h" -#include "minddata/dataset/engine/datasetops/source/coco_op.h" -#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" -#include "minddata/dataset/engine/datasetops/source/random_data_op.h" -#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" -#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" - -namespace mindspore { -namespace dataset { - -// Constructor -CachePass::CachePass(CacheTransformPass *transform_pass) - : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {} - -// Identifies the subtree below this node as a cached descendant tree. 
-Status CachePass::PreRunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); - } - is_caching_ = true; - return Status::OK(); -} - -// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache -// transformation -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - is_caching_ = false; // We a no longer in a cache subtree. clear the flag. - if (leaf_op_) { - MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; - // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, - // using base class pointers. - transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node); - } else { - // If there was no leaf_op set, then this is a non-mappable scenario. - - if (sampler_) { - // Grab the sampler that was saved from the leaf and plug it into the cache op - node->SetSampler(std::move(sampler_)); - MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; - } else { - // We're a cache op but no sampler was saved from leaf, so create a default sampler - const int64_t num_samples = 0; - const int64_t start_index = 0; - sampler_ = std::make_shared(num_samples, start_index); - node->SetSampler(std::move(sampler_)); - MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; - } - - // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache - uint32_t cache_crc = DatasetOp::GenerateCRC(node); - RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); - } - - return Status::OK(); -} - -// Common code for mappable leaf setup. 
-Status CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) { - // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. - if (is_caching_ && leaf_op_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); - } - - // If we are a leaf in the caching path, then save this leaf. - if (is_caching_) { - MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; - leaf_op_ = std::move(leaf_op); - } - return Status::OK(); -} - -// Common code for non mappable leaf setup. -Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) { - // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. - if (is_caching_ && leaf_op_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); - } - - // Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf - // as save it for use by cache op in ascendant tree. - if (is_caching_) { - RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); - MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; - } else { - // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can - // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) - std::shared_ptr sampler_from_leaf; - RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); - } - return Status::OK(); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic - // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. 
- node->MakeSimpleProducer(); - } - return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return 
MappableCacheLeafSetup(std::static_pointer_cast(node)); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h deleted file mode 100644 index c08db895632..00000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ - -#include -#include -#include -#include "minddata/dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class CacheTransformPass; - -/// \class CachePass cache_pass.h -/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache -/// transformation. It works in conjunction with the CacheTransformPass -class CachePass : public NodePass { - public: - /// \brief Constructor - /// \param[in] transform_pass Raw pointer back to controlling tree pass - explicit CachePass(CacheTransformPass *transform_pass); - - /// \brief Destructor - ~CachePass() = default; - - /// \brief Identifies the subtree below this node as a cached descendant tree. 
- /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache - /// transformation - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status 
RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache tranform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - private: - /// \brief Common code for mappable leaf setup. - /// \param[in] node The leaf node performing setup work. 
- /// \return Status The error code return - Status MappableCacheLeafSetup(std::shared_ptr leaf_op); - - /// \brief Common code for non-mappable leaf setup. - /// \param[in] node The leaf node performing setup work. - /// \return Status The error code return - Status NonMappableCacheLeafSetup(std::shared_ptr leaf_op); - - bool is_caching_; - std::shared_ptr leaf_op_; - std::shared_ptr sampler_; - CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass -}; - -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc index 033150e8f41..f305cbedb1b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc @@ -15,17 +15,177 @@ */ #include -#include "minddata/dataset/engine/opt/pre/cache_pass.h" #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" #include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" #include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/generator_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include 
"minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" namespace mindspore { namespace dataset { +// Constructor +CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {} + +// Identifies the subtree below this node as a cached descendant tree. +Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); + } + is_caching_ = true; + return Status::OK(); +} + +// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache +// transformation +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + is_caching_ = false; // We a no longer in a cache subtree. clear the flag. + if (leaf_op_) { + MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; + // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, + // using base class pointers. + AddMappableCacheOperators(std::move(leaf_op_), node); + } else { + // If there was no leaf_op set, then this is a non-mappable scenario. 
+ + if (sampler_) { + // Grab the sampler that was saved from the leaf and plug it into the cache op + node->SetSampler(std::move(sampler_)); + MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; + } else { + // We're a cache op but no sampler was saved from leaf, so create a default sampler + int64_t num_samples = 0; + int64_t start_index = 0; + sampler_ = std::make_shared(num_samples, start_index); + node->SetSampler(std::move(sampler_)); + MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; + } + + // Get the computed check sum from all ops in our cache path below us and ask the cache op to create its cache + uint32_t cache_crc = DatasetOp::GenerateCRC(node); + RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); + } + + return Status::OK(); +} + +// Common code for mappable leaf setup. +Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) { + // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. + if (is_caching_ && leaf_op_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); + } + + // If we are a leaf in the caching path, then save this leaf. + if (is_caching_) { + MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; + leaf_op_ = std::move(leaf_op); + } + return Status::OK(); +} + +// Common code for non mappable leaf setup. +Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) { + // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. + if (is_caching_ && leaf_op_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); + } + + // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf + // and save it for use by cache op in ascendant tree. 
+ if (is_caching_) { + RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); + MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; + } else { + // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can + // remove it here. The leaf itself will provide its own methods of fetching the data (not sampler-based) + std::shared_ptr sampler_from_leaf; + RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); + } + return Status::OK(); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_caching_) { + // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic + // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. + node->MakeSimpleProducer(); + } + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool 
*modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identifications +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Assigns the leaf and cache operators that are involved in a cache transformation +void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr leaf_op, + std::shared_ptr cache_op) { + cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); +} + // constructor CacheTransformPass::CacheTransformPass() {} @@ -34,11 +194,11 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { MS_LOG(INFO) << "Pre pass: Cache transform pass started."; // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will // use to execute a transform. 
- std::unique_ptr cache_pass = std::make_unique(this); - RETURN_IF_NOT_OK(cache_pass->Run(tree, modified)); + CachePass cache_pass = CachePass(); + RETURN_IF_NOT_OK(cache_pass.Run(tree, modified)); // Then, execute the transform for each pair - for (auto cache_pair : cache_pairs_) { + for (auto cache_pair : cache_pass.cache_pairs()) { MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); } @@ -98,11 +258,5 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share return Status::OK(); } - -// Assigns the leaf and cache operators that are involved in a cache transformation -void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr leaf_op, - std::shared_ptr cache_op) { - cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h index 05efb5b1267..942ea5fb5c0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h @@ -33,6 +33,123 @@ class CacheClient; /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching /// operations class CacheTransformPass : public TreePass { + /// \class CachePass + /// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache + /// transformation. 
It works in conjunction with the CacheTransformPass + class CachePass : public NodePass { + public: + /// \brief Constructor + /// \note Takes no arguments; leaf/cache pairs found during the pass are exposed via cache_pairs() + CachePass(); + + /// \brief Destructor + ~CachePass() = default; + + /// \brief Identifies the subtree below this node as a cached descendant tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree and assigns the operators that + /// will be involved in a cache transformation + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// 
\return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache transform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node 
was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Getter + std::vector, std::shared_ptr>> cache_pairs() { return cache_pairs_; } + + private: + /// \brief Common code for mappable leaf setup. + /// \param[in] node The leaf node performing setup work. + /// \return Status The error code return + Status MappableCacheLeafSetup(std::shared_ptr leaf_op); + + /// \brief Common code for non-mappable leaf setup. + /// \param[in] node The leaf node performing setup work. + /// \return Status The error code return + Status NonMappableCacheLeafSetup(std::shared_ptr leaf_op); + + /// \brief Assigns the leaf and cache operators that are involved in a cache transformation + /// \param[in] leaf_op The leaf operator involved in the cache transform + /// \param[in] cache_op The cache operator involved in the cache transform + void AddMappableCacheOperators(std::shared_ptr leaf_op, std::shared_ptr cache_op); + + bool is_caching_; + std::shared_ptr leaf_op_; + std::shared_ptr sampler_; + // The two operators that work together to establish the cache transform + std::vector, std::shared_ptr>> cache_pairs_; + }; + public: /// \brief Constructor CacheTransformPass(); @@ -46,11 +163,6 @@ class CacheTransformPass : public TreePass { /// \return Status The error code return Status RunOnTree(ExecutionTree *tree, bool *modified) override; - /// \brief Assigns the leaf and cache operators that are involved in a cache transformation - /// \param[in] leaf_op The leaf operator involved in the cache transform - /// \param[in] cache_op The cache operator involved in the cache transform - void AddMappableCacheOperators(std::shared_ptr leaf_op, std::shared_ptr cache_op); - private: /// \brief Helper function to execute the cache transformation. 
/// @@ -72,9 +184,6 @@ class CacheTransformPass : public TreePass { /// \return Status The error code return Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op, std::shared_ptr cache_op, std::shared_ptr cache_client); - - // The two operators that work together to establish the cache transform - std::vector, std::shared_ptr>> cache_pairs_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc similarity index 54% rename from mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.cc rename to mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc index 0dfe3034ff3..2cd1f740894 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc @@ -16,7 +16,7 @@ #include #include -#include "minddata/dataset/engine/opt/pre/injection_pass.h" +#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" #include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h" @@ -25,64 +25,55 @@ namespace mindspore { namespace dataset { // constructor -InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {} +EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr node) : injection_point_(node) {} // Performs finder work for BuildVocabOp that has special rules about epoch control injection -Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *modified) { - if (injection_pass_) { - injection_pass_->epoch_ctrl_bypass_ = true; - return Status::OK(); - } else { - RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); - } +Status 
EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *modified) { + injection_point_ = nullptr; + return Status::OK(); } // Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection -Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *modified) { - if (injection_pass_) { - injection_pass_->epoch_ctrl_bypass_ = true; - return Status::OK(); - } else { - RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); - } +Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, + bool *modified) { + injection_point_ = nullptr; + return Status::OK(); } // Temporary code to prevent the injection of epoch control when cache op is present // Remove this code in cache op phase 2 -Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *modified) { - if (injection_pass_) { - injection_pass_->epoch_ctrl_bypass_ = true; - return Status::OK(); - } else { - RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!"); - } +Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *modified) { + injection_point_ = nullptr; + return Status::OK(); +} + +Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr node, bool *modified) { + // Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here. 
+ injection_point_ = node->child(0); + return Status::OK(); } // constructor -InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {} +EpochInjectionPass::EpochInjectionPass() {} // Runs an injection pass to inject in operators needed at the pre pass stage -Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) { +Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) { MS_LOG(INFO) << "Pre pass: Injection pass started."; // First, run the finder to perform any injection info before we can go ahead to drive the op injection work. - // The finder can make updates to the InjectionPass object. - InjectionPass::InjectionFinder finder(this); - finder.Run(tree, modified); + // The finder can make updates to the EpochInjectionPass object. + EpochInjectionPass::InjectionFinder finder(tree->root()); + RETURN_IF_NOT_OK(finder.Run(tree, modified)); // The first injection logic is to check if we should inject the epoch control op as the root node. // Do not inject the op if the number of epochs is 1. 
int32_t num_epochs = tree->num_epochs(); - if (num_epochs != 1 && !epoch_ctrl_bypass_) { + std::shared_ptr epoch_inject_node = finder.injection_point(); + if (num_epochs != 1 && epoch_inject_node != nullptr) { std::shared_ptr epoch_ctrl_op; RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op)); RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op)); - std::shared_ptr node = tree->root(); - if (std::dynamic_pointer_cast(node) == nullptr) { - tree->root()->InsertAsParent(epoch_ctrl_op); - } else { - tree->root()->child(0)->InsertAsParent(epoch_ctrl_op); - } + epoch_inject_node->InsertAsParent(epoch_ctrl_op); } MS_LOG(INFO) << "Pre pass: Injection pass complete."; diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h similarity index 75% rename from mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.h rename to mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h index aef29616629..292f411aff4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ +#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ #include #include @@ -26,10 +26,10 @@ namespace dataset { class DatasetOp; -/// \class InjectionPass injection_pass.h +/// \class EpochInjectionPass epoch_injection_pass.h /// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api /// parsing. 
-class InjectionPass : public TreePass { +class EpochInjectionPass : public TreePass { /// \class InjectionFinder /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for /// operators that need to be injected. It is run first by the main injection pass to find out what operators @@ -37,7 +37,10 @@ class InjectionPass : public TreePass { class InjectionFinder : public NodePass { public: /// \brief Constructor - explicit InjectionFinder(InjectionPass *injection_pass); + explicit InjectionFinder(std::shared_ptr node); + + /// \brief Destructor + ~InjectionFinder() = default; /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. /// \param[in] node The node being visited @@ -58,24 +61,30 @@ class InjectionPass : public TreePass { /// \return Status The error code return Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + /// \brief Register the DeviceQueueOp for further action. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Getter + std::shared_ptr injection_point() { return injection_point_; } + private: - InjectionPass *injection_pass_; + std::shared_ptr injection_point_; }; public: /// \brief Constructor - InjectionPass(); + EpochInjectionPass(); /// \brief Runs an injection pass to inject in operators needed at the pre pass stage /// \param[inout] tree The tree to operate on. /// \param[inout] Indicate of the tree was modified. 
/// \return Status The error code return Status RunOnTree(ExecutionTree *tree, bool *modified) override; - - private: - bool epoch_ctrl_bypass_; }; } // namespace dataset } // namespace mindspore -#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_ +#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc deleted file mode 100644 index f04d7bc07d2..00000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "minddata/dataset/engine/opt/pre/removal_nodes.h" -#include "minddata/dataset/engine/opt/pre/removal_pass.h" -#include "minddata/dataset/engine/datasetops/shuffle_op.h" - -namespace mindspore { -namespace dataset { - -RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} - -// Identifies the subtree below this node as a cached descendant tree. 
-Status RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; - is_caching_ = true; - return Status::OK(); -} - -// Resets the tracking of the cache within the tree -Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; - is_caching_ = false; - return Status::OK(); -} - -// Perform ShuffleOp removal check. -Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - // If we are in a cache descendant tree, then this shuffle op needs to be removed - if (is_caching_) { - MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; - if (removal_pass_) { - removal_pass_->AddToRemovalList(std::static_pointer_cast(node)); - } else { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); - } - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h deleted file mode 100644 index 16f623668c0..00000000000 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ - -#include -#include "minddata/dataset/engine/opt/pass.h" -#include "minddata/dataset/engine/opt/pre/removal_pass.h" - -namespace mindspore { -namespace dataset { -/// \class RemovalNodes removal_nodes.h -/// \brief This is a NodePass who's job is to identify which nodes should be removed. -/// It works in conjunction with the removal_pass. -class RemovalNodes : public NodePass { - public: - /// \brief Constructor - /// \param[in] removal_pass Raw pointer back to controlling tree pass - explicit RemovalNodes(RemovalPass *removal_pass); - - /// \brief Identifies the subtree below this node as a cached descendant tree. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Resets the tracking of the cache within the tree - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Destructor - ~RemovalNodes() = default; - - /// \brief Perform ShuffleOp removal check - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - private: - bool is_caching_; - RemovalPass *removal_pass_; // Back pointer to the owning removal pass -}; - -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc index 0db422a7c25..14c00b4a63b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc @@ -16,32 +16,58 @@ #include #include -#include "minddata/dataset/engine/opt/pre/removal_nodes.h" #include "minddata/dataset/engine/opt/pre/removal_pass.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" #include "minddata/dataset/engine/execution_tree.h" namespace mindspore { namespace dataset { +RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} + +// Identifies the subtree below this node as a cached descendant tree. +Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; + is_caching_ = true; + return Status::OK(); +} + +// Resets the tracking of the cache within the tree +Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; + is_caching_ = false; + return Status::OK(); +} + +// Perform ShuffleOp removal check. +Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + // If we are in a cache descendant tree, then this shuffle op needs to be removed + if (is_caching_) { + MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; + nodes_to_remove_.push_back(std::static_pointer_cast(node)); + } + return Status::OK(); +} + // constructor RemovalPass::RemovalPass() {} -// Runs a removal_nodes pass first to find out which nodes to remove, then removes them. +// Walk the tree to collect the nodes to remove, then removes them. 
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { MS_LOG(INFO) << "Pre pass: removal pass started."; // Create the removal node pass which can identify which nodes need to be removed. - std::unique_ptr removal_nodes = std::make_unique(this); + std::unique_ptr removal_nodes = std::make_unique(); RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); // Then, execute the removal of any nodes that were set up for removal - for (auto node : removal_nodes_) { + for (auto node : removal_nodes->nodes_to_remove()) { node->Remove(); } MS_LOG(INFO) << "Pre pass: removal pass complete."; return Status::OK(); } - -// Adds an operator to the list of operators to be removed -void RemovalPass::AddToRemovalList(std::shared_ptr dataset_op) { removal_nodes_.push_back(dataset_op); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h index 59414f153fa..f1e8b794955 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h @@ -30,6 +30,45 @@ class DatasetOp; /// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which /// nodes should be removed, and then removes them. class RemovalPass : public TreePass { + /// \class RemovalNodes + /// \brief This is a NodePass whose job is to identify which nodes should be removed. + /// It works in conjunction with the removal_pass. + class RemovalNodes : public NodePass { + public: + /// \brief Constructor + /// \note Takes no arguments; nodes identified for removal are collected inside this object + RemovalNodes(); + + /// \brief Destructor + ~RemovalNodes() = default; + + /// \brief Identifies the subtree below this node as a cached descendant tree.
+ /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform ShuffleOp removal check + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Getter + /// \return All the nodes to be removed + std::vector> nodes_to_remove() { return nodes_to_remove_; } + + private: + bool is_caching_; + std::vector> nodes_to_remove_; + }; + public: /// \brief Constructor RemovalPass(); @@ -42,13 +81,6 @@ class RemovalPass : public TreePass { /// \param[inout] Indicate of the tree was modified. 
/// \return Status The error code return Status RunOnTree(ExecutionTree *tree, bool *modified) override; - - /// \brief Adds an operator to the list of operators to be removed - /// \param[in] dataset_op The operator to add to the removal list - void AddToRemovalList(std::shared_ptr dataset_op); - - private: - std::vector> removal_nodes_; }; } // namespace dataset } // namespace mindspore diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index 0bfb7a03427..619dff19628 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) + assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) @@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) + assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) @@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard(): num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) + assert 'Input shard_id is not within the required interval of (0 to 1).' 
in str(error_info.value) with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 - assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) + assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) @@ -245,7 +245,7 @@ def test_cv_minddataset_partition_num_samples_equals_0(): num_iter += 1 with pytest.raises(Exception) as error_info: partitions(5) - assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info) + assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME))