From d69a29a44e7b416a29be7964b6b2a1a92966c172 Mon Sep 17 00:00:00 2001 From: Nat Sutyanyong Date: Fri, 27 Nov 2020 16:45:00 -0500 Subject: [PATCH] Migrate 3 pre passes to IR optimizer, namely, cache_error_pass, epoch_injection, and removal_pass --- .../dataset/engine/consumers/tree_consumer.cc | 2 +- .../minddata/dataset/engine/execution_tree.cc | 40 +++-- .../minddata/dataset/engine/execution_tree.h | 11 +- .../engine/ir/datasetops/batch_node.cc | 12 ++ .../dataset/engine/ir/datasetops/batch_node.h | 12 ++ .../datasetops/bucket_batch_by_length_node.cc | 32 +++- .../build_sentence_piece_vocab_node.cc | 8 +- .../build_sentence_piece_vocab_node.h | 8 +- .../engine/ir/datasetops/build_vocab_node.cc | 8 +- .../engine/ir/datasetops/build_vocab_node.h | 8 +- .../engine/ir/datasetops/concat_node.cc | 12 +- .../engine/ir/datasetops/concat_node.h | 8 +- .../engine/ir/datasetops/dataset_node.cc | 55 ++++-- .../engine/ir/datasetops/dataset_node.h | 113 ++++-------- .../engine/ir/datasetops/epoch_ctrl_node.cc | 3 +- .../engine/ir/datasetops/epoch_ctrl_node.h | 5 +- .../engine/ir/datasetops/filter_node.cc | 8 +- .../engine/ir/datasetops/filter_node.h | 8 +- .../dataset/engine/ir/datasetops/map_node.cc | 14 +- .../dataset/engine/ir/datasetops/map_node.h | 8 +- .../engine/ir/datasetops/repeat_node.cc | 8 +- .../engine/ir/datasetops/repeat_node.h | 8 +- .../dataset/engine/ir/datasetops/root_node.cc | 8 +- .../dataset/engine/ir/datasetops/root_node.h | 8 +- .../dataset/engine/ir/datasetops/skip_node.cc | 12 ++ .../dataset/engine/ir/datasetops/skip_node.h | 12 ++ .../engine/ir/datasetops/source/album_node.cc | 2 +- .../ir/datasetops/source/celeba_node.cc | 2 +- .../ir/datasetops/source/cifar100_node.cc | 2 +- .../ir/datasetops/source/cifar10_node.cc | 2 +- .../engine/ir/datasetops/source/clue_node.cc | 2 +- .../engine/ir/datasetops/source/coco_node.cc | 2 +- .../engine/ir/datasetops/source/csv_node.cc | 2 +- .../ir/datasetops/source/generator_node.cc | 13 +- .../ir/datasetops/source/image_folder_node.cc | 2 +- .../ir/datasetops/source/manifest_node.cc | 2 +- .../ir/datasetops/source/minddata_node.cc | 3 +- .../engine/ir/datasetops/source/mnist_node.cc | 2 +- .../ir/datasetops/source/text_file_node.cc | 2 +- .../ir/datasetops/source/tf_record_node.cc | 2 +- .../engine/ir/datasetops/source/voc_node.cc | 2 +- .../dataset/engine/ir/datasetops/take_node.cc | 12 ++ .../dataset/engine/ir/datasetops/take_node.h | 12 ++ .../engine/ir/datasetops/transfer_node.cc | 8 +- .../engine/ir/datasetops/transfer_node.h | 8 +- .../dataset/engine/ir/datasetops/zip_node.cc | 8 +- .../dataset/engine/ir/datasetops/zip_node.h | 23 +-- .../dataset/engine/opt/CMakeLists.txt | 3 + .../ccsrc/minddata/dataset/engine/opt/pass.cc | 84 ++++----- .../ccsrc/minddata/dataset/engine/opt/pass.h | 126 ++++++++------ .../engine/opt/pre/cache_validation_pass.cc | 163 ++++++++++++++++++ .../engine/opt/pre/cache_validation_pass.h | 105 +++++++++++ .../dataset/engine/opt/pre/epoch_ctrl_pass.cc | 85 +++++++++ .../dataset/engine/opt/pre/epoch_ctrl_pass.h | 98 +++++++++++ .../engine/opt/pre/input_validation_pass.h | 2 +- .../engine/opt/pre/node_removal_pass.cc | 81 +++++++++ .../engine/opt/pre/node_removal_pass.h | 88 ++++++++++ .../minddata/dataset/engine/tree_adapter.cc | 49 ++---- .../dataset/kernels/image/normalize_op.cc | 2 +- .../ut/cpp/dataset/optimization_pass_test.cc | 4 +- tests/ut/python/dataset/test_cache_map.py | 20 +-- 61 files changed, 1095 insertions(+), 359 deletions(-) create mode 100644 
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 9b0b5fbec28..101f9d836e3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -574,7 +574,7 @@ Status DatasetSizeGetter::DryRun(std::shared_ptr ir_node, int64_t * std::make_unique(static_cast(GetterPass::GetterType::kDatasetSize))); return pre; }); - RETURN_IF_NOT_OK(tree_adapter->Compile(std::move(ir_node), 1)); + RETURN_IF_NOT_OK(tree_adapter->Compile(ir_node, 1)); TensorRow row; RETURN_IF_NOT_OK(GetRow(tree_adapter, &row)); int64_t row_cnt = 0; diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc index 7edd85c7c31..97d44941207 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc @@ -214,7 +214,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::functionPrepareTreePreAction()); + RETURN_IF_NOT_OK(this->PreAction()); // If optional optimizations are enabled if (optimize_) { RETURN_IF_NOT_OK(this->Optimize()); } // Post optimization compulsory transformation - RETURN_IF_NOT_OK(this->PrepareTreePostAction()); + RETURN_IF_NOT_OK(this->PostAction()); + + // The tree is ready to be prepared. + tree_state_ = kDeTStatePrepare; // Existing transformation implementation, will be removed later RETURN_IF_NOT_OK(this->PrepareDeprecated()); return Status::OK(); } -Status ExecutionTree::PrepareTreePreAction() { +Status ExecutionTree::PreAction() { bool modified = false; std::vector> pre_actions; // Construct pre actions + if (!partially_prepare_) { #ifndef ENABLE_ANDROID - pre_actions.push_back(std::make_unique()); -#endif - pre_actions.push_back(std::make_unique()); - pre_actions.push_back(std::make_unique()); -#ifndef ENABLE_ANDROID - pre_actions.push_back(std::make_unique()); + pre_actions.push_back(std::make_unique()); #endif + pre_actions.push_back(std::make_unique()); + pre_actions.push_back(std::make_unique()); + } // this offers a way to override the preset optimization pass with customized ones // this is used when certain nodes are removed for tree getters @@ -276,15 +279,17 @@ Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); } -Status ExecutionTree::PrepareTreePostAction() { - // The tree is ready to be prepared. - tree_state_ = kDeTStatePrepare; - +Status ExecutionTree::PostAction() { bool modified = false; OptPass post_actions; // Construct pre actions MS_LOG(INFO) << "Running post pass loops."; #ifndef ENABLE_ANDROID + // Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind. + // The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API. + // This is because Python API binding to TensorOperation is still in progress. 
+ post_actions.push_back(std::make_unique()); + post_actions.push_back(std::make_unique()); post_actions.push_back(std::make_unique()); #endif @@ -340,9 +345,6 @@ Status ExecutionTree::PrepareDeprecated() { // Recursive function used during prepare phase to visit a node and drive any pre- and post- // node actions during a tree walk. Status ExecutionTree::PrepareNode(const std::shared_ptr &dataset_op) { - // execute PreAction - RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction()); - // Before going down into children, make any prepare flags updates based on this operator. uint32_t op_prep_flags = dataset_op->PrepareFlags(); BitSet(&prepare_flags_, op_prep_flags); diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h index 61bddbf4d4f..0bcfdcba59b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h @@ -169,7 +169,7 @@ class ExecutionTree { // The driver of the prepare phase of the execution tree. // Prepare phase consists of three sub phases // - // 1. PrepareTreePreAction() + // 1. PreAction() // Compulsory transformation/action pre optimization. // For example, CacheOp Insertion // @@ -177,20 +177,20 @@ class ExecutionTree { // Optimization transformation/action, optional // For example, MapOp Fusion // - // 3. PrepareTreePostAction() + // 3. PostAction() // Compulsory transformation/action post optimization. // For example, repeatOp inlining // // @return Status - The error code return - Status Prepare(int num_epochs = -1); + Status Prepare(int num_epochs = -1, bool partial = false); // Compulsory transformation/action pre optimization. // @return Status - The error code return - Status PrepareTreePreAction(); + Status PreAction(); // Compulsory transformation/action post optimization. // @return Status - The error code return - Status PrepareTreePostAction(); + Status PostAction(); // Optimization transformation/action, optional. // @return Status - The error code return @@ -281,6 +281,7 @@ class ExecutionTree { std::unique_ptr profiling_manager_; // Profiling manager bool optimize_; // Flag to enable optional optimizations std::function pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare() + bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes. 
}; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc index 3e493d6d597..40db6edb9fa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.cc @@ -23,6 +23,7 @@ #include #include "minddata/dataset/engine/datasetops/batch_op.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { @@ -139,5 +140,16 @@ Status BatchNode::GetDatasetSize(const std::shared_ptr &size_ return Status::OK(); } +// Visitor accepting method for IRNodePass +Status BatchNode::Accept(IRNodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status BatchNode::AcceptAfter(IRNodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h index 4b63d047ef0..d91efa3e006 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/batch_node.h @@ -74,6 +74,18 @@ class BatchNode : public DatasetNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *p, bool *modified) override; + private: int32_t batch_size_; bool drop_remainder_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc index 497050763f7..4bbd4316a74 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc @@ -46,12 +46,40 @@ BucketBatchByLengthNode::BucketBatchByLengthNode( std::shared_ptr BucketBatchByLengthNode::Copy() { auto node = std::make_shared(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_, - element_length_function_, pad_info_, pad_to_bucket_boundary_); + element_length_function_, pad_info_, pad_to_bucket_boundary_, + drop_remainder_); return node; } void BucketBatchByLengthNode::Print(std::ostream &out) const { - out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)"; + out << Name() + "(columns:" + PrintColumns(column_names_); + int i = 0; + for (auto it : bucket_boundaries_) { + if (i == 0) { + out << ",bucket_boundaries:{"; + } + out << it; + if (i < bucket_boundaries_.size() - 1) { + out << ","; + } else { + out << "}"; + } + i++; + } + i = 0; + for (auto it : bucket_batch_sizes_) { + if (i == 0) { + out << ",bucket_batch_sizes:{"; + } + out 
<< it; + if (i < bucket_batch_sizes_.size() - 1) { + out << ","; + } else { + out << "}"; + } + i++; + } + out << ")"; } Status BucketBatchByLengthNode::Build(std::vector> *node_ops) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc index ab89669e6b4..b012e722230 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.cc @@ -90,14 +90,14 @@ Status BuildSentenceVocabNode::ValidateParams() { return Status::OK(); } -// Visitor accepting method for NodePass -Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status BuildSentenceVocabNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status BuildSentenceVocabNode::AcceptAfter(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h index d0624bd35d8..2426201fc50 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h @@ -59,17 +59,17 @@ class BuildSentenceVocabNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; + Status Accept(IRNodePass *p, bool *modified) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; + Status AcceptAfter(IRNodePass *p, bool *modified) override; private: std::shared_ptr vocab_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc index ade1ebd6c99..714f967d75a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc @@ -85,14 +85,14 @@ Status BuildVocabNode::ValidateParams() { return Status::OK(); } -// Visitor accepting method for NodePass -Status BuildVocabNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status BuildVocabNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status 
BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status BuildVocabNode::AcceptAfter(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h index bb2c9bfd4b0..77193faf73a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h @@ -58,17 +58,17 @@ class BuildVocabNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; + Status Accept(IRNodePass *p, bool *modified) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; + Status AcceptAfter(IRNodePass *p, bool *modified) override; private: std::shared_ptr vocab_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc index db594227b58..8ff227fa5d4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc @@ -39,8 +39,10 @@ ConcatNode::ConcatNode(const std::vector> &datasets } std::shared_ptr ConcatNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->Copy(); // create an empty vector to copy a concat - auto node = std::make_shared(std::vector>()); + auto node = std::make_shared(std::vector>(), sampler, + children_flag_and_nums_, children_start_end_index_); return node; } @@ -80,14 +82,14 @@ Status ConcatNode::Build(std::vector> *node_ops) { return Status::OK(); } -// Visitor accepting method for NodePass -Status ConcatNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status ConcatNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status ConcatNode::AcceptAfter(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h index d53e0c0ff5a..694445939de 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h @@ -66,17 +66,17 @@ class ConcatNode : public DatasetNode { std::vector> children_flag_and_nums_; std::vector> children_start_end_index_; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; + Status Accept(IRNodePass *p, bool *modified) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; + Status AcceptAfter(IRNodePass *p, bool *modified) override; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index f6b945d4053..2b44015cc6f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -242,9 +242,27 @@ DatasetNode::DatasetNode() : cache_(nullptr), parent_({}), children_({}) { worker_connector_size_ = cfg->worker_connector_size(); } +const bool DatasetNode::IsTree() const { + bool is_tree = true; + if (this->parent_.size() > 1) { + MS_LOG(WARNING) << Name() << " has more than one parent."; + return false; + } + for (const auto &child : children_) { + is_tree = child->IsTree(); + if (!is_tree) { + MS_LOG(WARNING) << Name() << " has more than one parent."; + break; + } + } + return is_tree; +} + // this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied std::shared_ptr DatasetNode::DeepCopy() { std::shared_ptr new_node = this->Copy(); + // temporary fix to set the num_workers to the new node. 
+ new_node->SetNumWorkers(this->num_workers_); for (const auto &child : children_) { new_node->AddChild(child->DeepCopy()); } @@ -298,12 +316,31 @@ void DatasetNode::AddChild(std::shared_ptr child) { children_.push_back(child); child->parent_.push_back(this); } else if (child != nullptr) { - MS_LOG(WARNING) << "DatasetNode::AddChild() failed: " + child->Name() + "'s parent isn't a nullptr."; + MS_LOG(WARNING) << "Adding " + child->Name() + " to " + Name() + " but it already has a parent"; children_.push_back(child); child->parent_.push_back(this); } } +// Insert a node as a child of this node. This node's children becomes the children of the inserted node. +Status DatasetNode::InsertBelow(std::shared_ptr node) { + CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Inserted node must not be a null pointer."); + CHECK_FAIL_RETURN_UNEXPECTED(node->children_.empty(), "Inserted node must not have any children."); + CHECK_FAIL_RETURN_UNEXPECTED(node->parent_.empty(), "Inserted node must not have a parent."); + + for (auto child : children_) { + node->children_.push_back(child); + child->parent_.clear(); + child->parent_.push_back(node.get()); + } + // Then establish the new parent-child relationship with the new parent. + children_.clear(); + children_.push_back(node); + node->parent_.clear(); + node->parent_.push_back(this); + return Status::OK(); +} + // Remove this node from its parent. Add the child of this node to its parent. // for now, this remove is limited to node with a single child or no child Status DatasetNode::Remove() { @@ -325,14 +362,14 @@ Status DatasetNode::Remove() { } // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. -Status DatasetNode::Accept(NodePass *p, bool *modified) { +Status DatasetNode::Accept(IRNodePass *p, bool *modified) { // This method will only be called if its derived class does not implement one. return p->Visit(shared_from_this(), modified); } // In DFS tree traversal, each node is visited twice. AcceptAfter is called on the second visit // after all child nodes are visited. -Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { +Status DatasetNode::AcceptAfter(IRNodePass *p, bool *modified) { // This method will only be called if its derived class does not implement one. 
return p->VisitAfter(shared_from_this(), modified); } @@ -369,17 +406,5 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr &siz RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); } } - -// Visitor accepting method for NodePass -Status SourceNode::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->Visit(shared_from_base(), modified); -} - -// Visitor accepting method for NodePass -Status SourceNode::AcceptAfter(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->VisitAfter(shared_from_base(), modified); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 601e06171d7..f037a460ba8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -32,7 +32,7 @@ namespace dataset { class Dataset; class SamplerObj; -class NodePass; +class IRNodePass; class DatasetSizeGetter; // Names for non-leaf IR node @@ -182,6 +182,9 @@ class DatasetNode : public std::enable_shared_from_this { /// \brief Establish the parent-child relationship between this node and its child. void AddChild(std::shared_ptr child); + /// \brief Insert the input node below this node. This node's children becomes the children of the inserted node. + Status InsertBelow(std::shared_ptr node); + /// \brief detach this node from its parent, add its child (if any) to its parent /// \return error code, return error if node has more than 1 children Status Remove(); @@ -190,6 +193,25 @@ class DatasetNode : public std::enable_shared_from_this { /// \return True if the data of this node will be cached const bool IsCached() const { return (cache_ != nullptr); } + /// \brief Check if this node is a tree + /// \return True if the structure is indeed a tree, i.e., no node has more than one parent + const bool IsTree() const; + + /// \brief Check if this node is a leaf node. + /// \return True if this is a leaf node. + const bool IsLeaf() const { return children_.empty(); } + + /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes + /// \return True if the dataset represented by this node is a mappable dataset + const bool IsMappable() const { return mappable_; } + + /// \brief Check if this node is a descendant of an operator with cache. Currently used in leaf nodes + /// \return True if a cache-enabled operator is an ancestor of this node + const bool IsDescendantOfCache() const { return descendant_of_cache_; } + + /// \brief Mark to indicate this node is a descendant of an operator with cache. Currently used in leaf nodes + void HasCacheAbove() { descendant_of_cache_ = true; } + /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object @@ -203,7 +225,7 @@ class DatasetNode : public std::enable_shared_from_this { return std::static_pointer_cast(shared_from_this()); } - /// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up + /// \brief Base method for IRNodePass visit. A tree walk consists of walking down the tree and also walking back up /// in a depth-first order. 
Accept is the node visit on the way down, whereas AcceptAfter is the node /// visit on the way back up the tree after its descendants are visited. /// \notes Subclass needs to override this if it requires special node visit access. @@ -211,15 +233,15 @@ class DatasetNode : public std::enable_shared_from_this { /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - virtual Status Accept(NodePass *p, bool *modified); + virtual Status Accept(IRNodePass *p, bool *modified); - /// \brief Base method for NodePass visit on the way back up the tree after its descendants are visited. + /// \brief Base method for IRNodePass visit on the way back up the tree after its descendants are visited. /// \notes Subclass needs to override this if it requires special node visit access. /// Check "dataset/engine/opt/pass.h" for more details. /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - virtual Status AcceptAfter(NodePass *p, bool *modified); + virtual Status AcceptAfter(IRNodePass *p, bool *modified); virtual bool IsSizeDefined() { return true; } @@ -235,55 +257,22 @@ class DatasetNode : public std::enable_shared_from_this { std::string PrintColumns(const std::vector &columns) const; Status AddCacheOp(std::vector> *node_ops); void PrintNode(std::ostream &out, int *level) const; -}; - -// SourceNode represents the leaf nodes of a pipeline where the data is pulled into. -class SourceNode : public DatasetNode { - public: - /// \brief Constructor - SourceNode() : DatasetNode() {} - - /// \brief Constructor that initializes the cache - /// \param dataset_cache DatasetCache - explicit SourceNode(const std::shared_ptr &dataset_cache) : DatasetNode(dataset_cache) {} - - /// \brief Destructor - ~SourceNode() = default; - - /// \brief Node name getter - /// \return Name of the current node - virtual std::string Name() const = 0; - - /// \brief Base-class override for accepting NodePass visitor - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - /// \brief Base-class override for accepting NodePass visitor - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; - - /// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes - /// \return True if the dataset represented by this node is a mappable dataset - const bool IsMappable() const { return mappable_; } - - protected: bool mappable_; + bool descendant_of_cache_; }; // MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes. -class MappableSourceNode : public SourceNode { +class MappableSourceNode : public DatasetNode { public: /// \brief Constructor - MappableSourceNode() : SourceNode() { mappable_ = true; } + MappableSourceNode() : DatasetNode() { mappable_ = true; } /// \brief Constructor that initializes the cache /// \param dataset_cache DatasetCache - explicit MappableSourceNode(const std::shared_ptr &dataset_cache) : SourceNode(dataset_cache) { + explicit MappableSourceNode(const std::shared_ptr &dataset_cache) : DatasetNode(dataset_cache) { mappable_ = true; + // Initially set to false, and set to true by the optimizer when conditions are met. 
+ descendant_of_cache_ = false; } /// \brief Destructor @@ -295,15 +284,17 @@ class MappableSourceNode : public SourceNode { }; // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. -class NonMappableSourceNode : public SourceNode { +class NonMappableSourceNode : public DatasetNode { public: /// \brief Constructor - NonMappableSourceNode() : SourceNode() { mappable_ = false; } + NonMappableSourceNode() : DatasetNode() { mappable_ = false; } /// \brief Constructor that initializes the cache /// \param dataset_cache DatasetCache - explicit NonMappableSourceNode(const std::shared_ptr &dataset_cache) : SourceNode(dataset_cache) { + explicit NonMappableSourceNode(const std::shared_ptr &dataset_cache) : DatasetNode(dataset_cache) { mappable_ = false; + // Initially set to false, and set to true by the optimizer when conditions are met. + descendant_of_cache_ = false; } /// \brief Destructor @@ -313,34 +304,6 @@ class NonMappableSourceNode : public SourceNode { /// \return Name of the current node virtual std::string Name() const = 0; }; - -// NonLeafNode represents operations over data in a pipeline. -class NonLeafNode : public DatasetNode { - public: - /// \brief Constructor - NonLeafNode() = default; - - /// \brief Destructor - ~NonLeafNode() = default; - - /// \brief Node name getter - /// \return Name of the current node - virtual std::string Name() const = 0; -}; - -// SinkNode represents the end node of a pipeline where the data is pushed out -class SinkNode : public DatasetNode { - public: - /// \brief Constructor - SinkNode() = default; - - /// \brief Destructor - ~SinkNode() = default; - - /// \brief Node name getter - /// \return Name of the current node - virtual std::string Name() const = 0; -}; } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc index adddad477e9..2e3f473ceae 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.cc @@ -32,8 +32,9 @@ EpochCtrlNode::EpochCtrlNode(std::shared_ptr child, int32_t num_epo // The root node's parent must set to null pointer. 
this->AddChild(child); } + std::shared_ptr EpochCtrlNode::Copy() { - auto node = std::make_shared(nullptr, this->num_epochs_); + auto node = std::make_shared(num_epochs_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h index e0250e668df..046561fb12c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h @@ -29,7 +29,10 @@ namespace dataset { class EpochCtrlNode : public DatasetNode { public: /// \brief Constructor - explicit EpochCtrlNode(std::shared_ptr child, int32_t num_epochs); + explicit EpochCtrlNode(int32_t num_epochs) : num_epochs_(num_epochs) {} + + /// \brief Constructor + EpochCtrlNode(std::shared_ptr child, int32_t num_epochs); /// \brief Destructor ~EpochCtrlNode() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc index 8bc967fa451..449371798a8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc @@ -60,14 +60,14 @@ Status FilterNode::ValidateParams() { return Status::OK(); } -// Visitor accepting method for NodePass -Status FilterNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status FilterNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status FilterNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status FilterNode::AcceptAfter(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h index aa04ef7ed19..88cf0323f2f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h @@ -58,17 +58,17 @@ class FilterNode : public DatasetNode { bool IsSizeDefined() override { return false; }; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; + Status Accept(IRNodePass *p, bool *modified) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; + Status AcceptAfter(IRNodePass *p, bool *modified) override; private: std::shared_ptr predicate_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index 18452e880a6..71532b4590b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -42,14 +42,16 
@@ MapNode::MapNode(std::shared_ptr child, std::vector MapNode::Copy() { - auto node = std::make_shared(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_, + std::vector> operations = operations_; + auto node = std::make_shared(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_, callbacks_); return node; } void MapNode::Print(std::ostream &out) const { out << Name() + "(" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + - "," + ",...)"; + "," + ",num_tensor_ops:" + << operations_.size() << ",...)"; } Status MapNode::Build(std::vector> *node_ops) { @@ -101,14 +103,14 @@ Status MapNode::ValidateParams() { return Status::OK(); } -// Visitor accepting method for NodePass -Status MapNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status MapNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status MapNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status MapNode::AcceptAfter(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h index 07f2588dde9..9a23d583705 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h @@ -63,17 +63,17 @@ class MapNode : public DatasetNode { const auto &TensorOperations() const { return operations_; } auto &TensorOperations() { return operations_; } - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; + Status Accept(IRNodePass *p, bool *modified) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; + Status AcceptAfter(IRNodePass *p, bool *modified) override; private: std::vector> operations_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc index e5cfb662681..8da4ee63921 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc @@ -70,14 +70,14 @@ Status RepeatNode::GetDatasetSize(const std::shared_ptr &size return Status::OK(); } -// Visitor accepting method for NodePass -Status RepeatNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status RepeatNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status RepeatNode::AcceptAfter(IRNodePass 
*p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h index b212fbb834a..fcb9790cf48 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h @@ -66,17 +66,17 @@ class RepeatNode : public DatasetNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; + Status Accept(IRNodePass *p, bool *modified) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; + Status AcceptAfter(IRNodePass *p, bool *modified) override; private: int32_t repeat_count_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc index 2d90d93cf85..55ed794e10a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.cc @@ -72,14 +72,14 @@ Status RootNode::ValidateParams() { return Status::OK(); } -// Visitor accepting method for NodePass -Status RootNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status RootNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status RootNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status RootNode::AcceptAfter(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h index b6554e88585..9dbdfecf820 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h @@ -58,17 +58,17 @@ class RootNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; + Status Accept(IRNodePass *p, bool *modified) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - 
Status AcceptAfter(NodePass *p, bool *modified) override; + Status AcceptAfter(IRNodePass *p, bool *modified) override; private: int32_t num_epochs_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc index 432737f8145..32100c3ca1f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc @@ -21,6 +21,7 @@ #include #include "minddata/dataset/engine/datasetops/skip_op.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { @@ -70,5 +71,16 @@ Status SkipNode::GetDatasetSize(const std::shared_ptr &size_g return Status::OK(); } +// Visitor accepting method for IRNodePass +Status SkipNode::Accept(IRNodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status SkipNode::AcceptAfter(IRNodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h index d7cd434e129..d9203fe800b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h @@ -64,6 +64,18 @@ class SkipNode : public DatasetNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *p, bool *modified) override; + private: int32_t skip_count_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc index e2efafc5298..be8ee972861 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc @@ -40,7 +40,7 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch sampler_(sampler) {} std::shared_ptr AlbumNode::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index 8cda9a8b002..8a099de335d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -40,7 +40,7 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, extensions_(extensions) {} std::shared_ptr CelebANode::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_dir_, usage_, sampler, decode_, extensions_, cache_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index 6dd41713218..c901e3a8320 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -33,7 +33,7 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} std::shared_ptr Cifar100Node::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 9498e26e685..1e4f45dd42a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -33,7 +33,7 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} std::shared_ptr Cifar10Node::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index 9ef925609b1..6f956fe960e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -208,7 +208,7 @@ Status CLUENode::Build(std::vector> *node_ops) { RETURN_IF_NOT_OK(clue_op->Init()); - if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { + if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { // Inject ShuffleOp std::shared_ptr shuffle_op = nullptr; int64_t num_rows = 0; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index 1c6443ab5b2..49b9c6fea76 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -38,7 +38,7 @@ CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation sampler_(sampler) {} std::shared_ptr CocoNode::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index 5e3d8aa0116..3826356c63c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -119,7 +119,7 @@ Status CSVNode::Build(std::vector> *node_ops) { RETURN_IF_NOT_OK(csv_op->Init()); - if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { + if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { // Inject ShuffleOp std::shared_ptr shuffle_op = nullptr; int64_t num_rows = 0; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc index 7172a126a23..9c8a7ad86aa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc @@ -33,8 +33,16 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector< column_names_(column_names), column_types_(column_types) {} +GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr &schema) + : generator_function_(generator_function), schema_(schema) {} + std::shared_ptr GeneratorNode::Copy() { - auto node = std::make_shared(generator_function_, column_names_, column_types_); + std::shared_ptr node; + if (schema_ == nullptr) { + node = std::make_shared(generator_function_, column_names_, column_types_); + } else { + node = std::make_shared(generator_function_, schema_); + } return node; } @@ -42,9 +50,6 @@ void GeneratorNode::Print(std::ostream &out) const { out << Name() + "(:" + ",columns:" + PrintColumns(column_names_) + ",)"; } -GeneratorNode::GeneratorNode(py::function generator_function, const 
std::shared_ptr &schema) - : generator_function_(generator_function), schema_(schema) {} - Status GeneratorNode::Build(std::vector> *node_ops) { std::unique_ptr data_schema = std::make_unique(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index c1f8700f0a9..fa6e8287d3b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -42,7 +42,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar exts_(extensions) {} std::shared_ptr ImageFolderNode::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_); return node; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index d296023bb7e..c63052a71fb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -40,7 +40,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u sampler_(sampler) {} std::shared_ptr ManifestNode::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_file_, usage_, sampler, class_index_, decode_, cache_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index 6c47d3af6d5..96ffac63893 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -54,12 +54,13 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector MindDataNode::Copy() { std::shared_ptr node; - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy(); if (dataset_files_.empty()) { node = std::make_shared(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_); } else { node = std::make_shared(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_); } + node->SetSampleBytes(&sample_bytes_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 343102827f8..7766561df22 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -32,7 +32,7 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} std::shared_ptr MnistNode::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_dir_, usage_, sampler, cache_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index a97b26380d7..4c83c4880cb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -86,7 +86,7 @@ Status TextFileNode::Build(std::vector> *node_ops) { connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); RETURN_IF_NOT_OK(text_file_op->Init()); - if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { + if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { // Inject ShuffleOp std::shared_ptr shuffle_op = nullptr; int64_t num_rows = 0; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index 094b54cb28d..66558bb2850 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -134,7 +134,7 @@ Status TFRecordNode::Build(std::vector> *node_ops) { RETURN_IF_NOT_OK(tf_reader_op->Init()); - if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { + if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { // Inject ShuffleOp std::shared_ptr shuffle_op = nullptr; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index 2bbacd243a3..7a8fd8f2bf4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -41,7 +41,7 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const sampler_(sampler) {} std::shared_ptr VOCNode::Copy() { - std::shared_ptr sampler = sampler_ == nullptr ? nullptr : sampler_->Copy(); + std::shared_ptr sampler = (sampler_ == nullptr) ? 
nullptr : sampler_->Copy(); auto node = std::make_shared(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_); return node; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc index 7196ec87e3e..5d6e12bd181 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc @@ -22,6 +22,7 @@ #include #include "minddata/dataset/engine/datasetops/take_op.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" namespace mindspore { @@ -68,5 +69,16 @@ Status TakeNode::GetDatasetSize(const std::shared_ptr &size_g return Status::OK(); } +// Visitor accepting method for IRNodePass +Status TakeNode::Accept(IRNodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status TakeNode::AcceptAfter(IRNodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h index 5136185c0a1..0fffe514ca4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h @@ -64,6 +64,18 @@ class TakeNode : public DatasetNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *p, bool *modified) override; + private: int32_t take_count_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc index e66202b4232..1003f7de138 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc @@ -104,14 +104,14 @@ Status TransferNode::Build(std::vector> *node_ops) { return Status::OK(); } -// Visitor accepting method for NodePass -Status TransferNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status TransferNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status TransferNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status TransferNode::AcceptAfter(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h index f43e33abda3..0fac8bcf1a5 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h @@ -58,17 +58,17 @@ class TransferNode : public DatasetNode { static Status get_distribution(std::shared_ptr ds, int32_t *device_id); - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; + Status Accept(IRNodePass *p, bool *modified) override; - /// \brief Base-class override for accepting NodePass visitor + /// \brief Base-class override for accepting IRNodePass visitor /// \param[in] p The node to visit /// \param[out] modified Indicator if the node was modified /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; + Status AcceptAfter(IRNodePass *p, bool *modified) override; private: std::string queue_name_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc index 082116dca48..03f89a4b446 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.cc @@ -79,14 +79,14 @@ Status ZipNode::GetDatasetSize(const std::shared_ptr &size_ge return Status::OK(); } -// Visitor accepting method for NodePass -Status ZipNode::Accept(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status ZipNode::Accept(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->Visit(shared_from_base(), modified); } -// Visitor accepting method for NodePass -Status ZipNode::AcceptAfter(NodePass *p, bool *modified) { +// Visitor accepting method for IRNodePass +Status ZipNode::AcceptAfter(IRNodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h index c346945aa72..a3e64a9401d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/zip_node.h @@ -64,19 +64,20 @@ class ZipNode : public DatasetNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *p, bool *modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *p, bool *modified) override; + private: std::vector> datasets_; - /// \brief Base-class override for accepting NodePass visitor - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - /// \brief Base-class override for accepting NodePass visitor - /// \param[in] p The node to visit - /// \param[out] modified 
Indicator if the node was modified - /// \return Status of the node visit - Status AcceptAfter(NodePass *p, bool *modified) override; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index 7ad4da248f8..44847ade49b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -6,9 +6,12 @@ add_library(engine-opt OBJECT post/repeat_pass.cc pre/cache_error_pass.cc pre/cache_transform_pass.cc + pre/cache_validation_pass.cc + pre/epoch_ctrl_pass.cc pre/epoch_injection_pass.cc pre/getter_pass.cc pre/input_validation_pass.cc + pre/node_removal_pass.cc pre/removal_pass.cc util/printer_pass.cc ) diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc index 3401bb64192..97f21f1358f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -87,7 +87,7 @@ namespace mindspore { namespace dataset { // Driver method for TreePass -Status TreePass::Run(std::shared_ptr root_ir, bool *modified) { +Status IRTreePass::Run(std::shared_ptr root_ir, bool *modified) { if (root_ir == nullptr || modified == nullptr) { return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); } @@ -95,7 +95,7 @@ Status TreePass::Run(std::shared_ptr root_ir, bool *modified) { } // Driver method for NodePass -Status NodePass::Run(std::shared_ptr root_ir, bool *modified) { +Status IRNodePass::Run(std::shared_ptr root_ir, bool *modified) { if (root_ir == nullptr || modified == nullptr) { return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); } @@ -110,7 +110,7 @@ Status NodePass::Run(std::shared_ptr root_ir, bool *modified) { } // Helper function to perform DFS visit -Status NodePass::DFSNodeVisit(std::shared_ptr node_ir, bool *modified) { +Status IRNodePass::DFSNodeVisit(std::shared_ptr node_ir, bool *modified) { bool m = false; RETURN_IF_NOT_OK(node_ir->Accept(this, &m)); @@ -125,7 +125,7 @@ Status NodePass::DFSNodeVisit(std::shared_ptr node_ir, bool *modifi } // Helper function to perform BFS visit -Status NodePass::BFSNodeVisit(std::shared_ptr node_ir, bool *modified) { +Status IRNodePass::BFSNodeVisit(std::shared_ptr node_ir, bool *modified) { bool m = false; // Initialize bfs queue with root @@ -151,121 +151,113 @@ Status NodePass::BFSNodeVisit(std::shared_ptr node_ir, bool *modifi } // For non-leaf IR node -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status 
NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status 
NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } #ifdef ENABLE_PYTHON -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } #endif #ifndef ENABLE_ANDROID -Status NodePass::Visit(std::shared_ptr node, bool *modified) { +Status IRNodePass::Visit(std::shared_ptr node, bool *modified) { return Visit(std::static_pointer_cast(node), modified); } -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *modified) { return VisitAfter(std::static_pointer_cast(node), modified); } #endif -// For leaf IR Node -Status NodePass::Visit(std::shared_ptr node, bool *modified) { - return Visit(std::static_pointer_cast(node), modified); -} -Status NodePass::VisitAfter(std::shared_ptr node, bool *modified) { - return VisitAfter(std::static_pointer_cast(node), modified); -} - ////////////////////////////////// // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // Driver method for TreePass diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h index de1bea1cb14..b34c3f5736e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -113,26 +113,18 @@ class GeneratorOp; // The base class Pass is the basic unit of tree transformation. // The actual implementation of the passes will be derived from here. -class Pass : public std::enable_shared_from_this { +class IRPass : public std::enable_shared_from_this { public: // Run the transformation pass against the IR tree. // @param root_ir - Pointer to the IR tree to be transformed. 
// @param modified - Pointer to the modified flag, virtual Status Run(std::shared_ptr root_ir, bool *modified) = 0; - ////////////////////////////////// - // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. - // Run the transformation pass against the execution tree. - // @param tree - Pointer to the execution tree to be transformed. - // @param modified - Pointer to the modified flag, - virtual Status Run(ExecutionTree *tree, bool *modified) = 0; - ////////////////////////////////// - - virtual ~Pass() = default; + virtual ~IRPass() = default; }; -// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. -class TreePass : public Pass { +// IRTreePass is a basic Pass class which performs transformation on IR tree directly. +class IRTreePass : public IRPass { public: /// \brief Run the transformation pass against the IR tree. /// \param[inout] root_ir Pointer to the IR tree to be transformed. @@ -145,44 +137,29 @@ class TreePass : public Pass { /// \param[inout] Indicate if the tree was modified. /// \return Status The error code return virtual Status RunOnTree(std::shared_ptr root_ir, bool *modified) { return Status::OK(); } - - ////////////////////////////////// - // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. - /// \brief Run the transformation pass against the execution tree. - /// \param[inout] tree Pointer to the execution tree to be transformed. - /// \param[inout] modified Indicate if the tree was modified - Status Run(ExecutionTree *tree, bool *modified) final; - - /// \brief Derived classes may implement the runOnTree function to implement tree transformation. - /// "modified" flag needs to be set to true if tree is modified during the pass execution. - /// \param[inout] tree The tree to operate on. - /// \param[inout] Indicate of the tree was modified. - /// \return Status The error code return - virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } - ////////////////////////////////// }; -// NodePass is a base Pass class which performs transformation on node visiting. -// NodePass implements Visitor design pattern. +// IRNodePass is a base Pass class which performs transformation on node visiting. +// IRNodePass implements Visitor design pattern. // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, // and the other when all the descending nodes are visited. -// Actual transformation is done by implementing a new derived class of NodePass. +// Actual transformation is done by implementing a new derived class of IRNodePass. // The derived class will implement the method Visit()/VisitAfter() passing specified node types -// it wants to action on them, overriding the ones defined in NodePass. +// it wants to action on them, overriding the ones defined in IRNodePass. // If the derived class wants to perform the same action on all node types, // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back -// to call the Visit()/VisitAfter() in this parent NodePass class. -class NodePass : public Pass { +// to call the Visit()/VisitAfter() in this parent IRNodePass class. 
+class IRNodePass : public IRPass { public: // Tree traversal order enum Order { DFS, BFS }; // Constructor // Default DFS traversal - explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } + explicit IRNodePass(Order order = Order::DFS) { traversalOrder_ = order; } - ~NodePass() = default; + ~IRNodePass() = default; /// \brief Run the transformation pass against the IR tree /// \param[inout] root_ir Pointer to the IR tree to be transformed @@ -251,12 +228,70 @@ class NodePass : public Pass { virtual Status Visit(std::shared_ptr node, bool *modified); virtual Status VisitAfter(std::shared_ptr node, bool *modified); #endif - // Leaf IR node - virtual Status Visit(std::shared_ptr node, bool *modified); - virtual Status VisitAfter(std::shared_ptr node, bool *modified); - ////////////////////////////////// - // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. + private: + // Helper function to perform DFS visit + Status DFSNodeVisit(std::shared_ptr node_ir, bool *modified); + + // Helper function to perform BFS visit + Status BFSNodeVisit(std::shared_ptr node_ir, bool *modified); + + // Tree traversal order of the NodePass + Order traversalOrder_; +}; + +////////////////////////////////// +// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. +// The base class Pass is the basic unit of tree transformation. +// The actual implementation of the passes will be derived from here. +class Pass : public std::enable_shared_from_this { + public: + // Run the transformation pass against the execution tree. + // @param tree - Pointer to the execution tree to be transformed. + // @param modified - Pointer to the modified flag, + virtual Status Run(ExecutionTree *tree, bool *modified) = 0; + + virtual ~Pass() = default; +}; + +// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. +class TreePass : public Pass { + public: + /// \brief Run the transformation pass against the execution tree. + /// \param[inout] tree Pointer to the execution tree to be transformed. + /// \param[inout] modified Indicate if the tree was modified + Status Run(ExecutionTree *tree, bool *modified) final; + + /// \brief Derived classes may implement the runOnTree function to implement tree transformation. + /// "modified" flag needs to be set to true if tree is modified during the pass execution. + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } +}; + +// NodePass is a base Pass class which performs transformation on node visiting. +// NodePass implements Visitor design pattern. +// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal, +// and the other when all the descending nodes are visited. +// Actual transformation is done by implementing a new derived class of NodePass. +// The derived class will implement the method Visit()/VisitAfter() passing specified node types +// it wants to action on them, overriding the ones defined in NodePass. +// If the derived class wants to perform the same action on all node types, +// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode. 
+// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back +// to call the Visit()/VisitAfter() in this parent NodePass class. +class NodePass : public Pass { + public: + // Tree traversal order + enum Order { DFS, BFS }; + + // Constructor + // Default DFS traversal + explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } + + ~NodePass() = default; + /// \brief Run the transformation pass against the execution tree /// \param[inout] tree Pointer to the execution tree to be transformed /// \param[inout] modified Indicator if the tree was changed @@ -326,27 +361,18 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); #endif - ////////////////////////////////// private: - // Helper function to perform DFS visit - Status DFSNodeVisit(std::shared_ptr node_ir, bool *modified); - - // Helper function to perform BFS visit - Status BFSNodeVisit(std::shared_ptr node_ir, bool *modified); - - ////////////////////////////////// - // This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done. // Helper function to perform DFS visit Status DFSNodeVisit(std::shared_ptr node, bool *modified); // Helper function to perform BFS visit Status BFSNodeVisit(std::shared_ptr root, bool *modified); - ////////////////////////////////// // Tree traversal order of the NodePass Order traversalOrder_; }; +////////////////////////////////// } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc new file mode 100644 index 00000000000..cd41ec918c4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" + +#include "minddata/dataset/engine/ir/datasetops/batch_node.h" +#include "minddata/dataset/engine/ir/datasetops/concat_node.h" +#include "minddata/dataset/engine/ir/datasetops/filter_node.h" +#include "minddata/dataset/engine/ir/datasetops/map_node.h" +#include "minddata/dataset/engine/ir/datasetops/repeat_node.h" +#include "minddata/dataset/engine/ir/datasetops/skip_node.h" +#include "minddata/dataset/engine/ir/datasetops/take_node.h" +#include "minddata/dataset/engine/ir/datasetops/zip_node.h" +#include "minddata/dataset/include/transforms.h" + +namespace mindspore { +namespace dataset { + +// Constructor +CacheValidationPass::CacheValidationPass() : is_cached_(false), is_mappable_(false) {} + +// Returns an error if BatchNode exists under a cache +Status CacheValidationPass::Visit(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::Visit(): visiting " << node->Name() << "."; + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("BatchNode is not supported as a descendant operator under a cache."); + } + if (node->IsCached()) { + RETURN_STATUS_UNEXPECTED("BatchNode cannot be cached."); + } + return Status::OK(); +} + +// Returns an error if ConcatNode exists under a cache +Status CacheValidationPass::Visit(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::Visit(): visiting " << node->Name() << "."; + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("ConcatNode is not supported as a descendant operator under a cache."); + } + if (node->IsCached()) { + RETURN_STATUS_UNEXPECTED("ConcatNode cannot be cached."); + } + return Status::OK(); +} + +// Returns an error if FilterNode exists under a cache +Status CacheValidationPass::Visit(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::Visit(): visiting " << node->Name() << "."; + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("FilterNode is not supported as a descendant operator under a cache."); + } + if (node->IsCached()) { + RETURN_STATUS_UNEXPECTED("FilterNode cannot be cached."); + } + return Status::OK(); +} + +// Returns an error if SkipNode exists under a cache +Status CacheValidationPass::Visit(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::Visit(): visiting " << node->Name() << "."; + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("SkipNode is not supported as a descendant operator under a cache."); + } + if (node->IsCached()) { + RETURN_STATUS_UNEXPECTED("SkipNode cannot be cached."); + } + return Status::OK(); +} + +// Returns an error if TakeNode exists under a cache +Status CacheValidationPass::Visit(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::Visit(): visiting " << node->Name() << "."; + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("TakeNode (possibly from Split) is not supported as a descendant operator under a cache."); + } + if (node->IsCached()) { + RETURN_STATUS_UNEXPECTED("TakeNode cannot be cached."); + } + return Status::OK(); +} + +// Returns an error if ZipNode exists under a cache +Status CacheValidationPass::Visit(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::Visit(): visiting " << node->Name() << "."; + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("ZipNode is not supported as a descendant operator under a cache."); + } + if (node->IsCached()) { + RETURN_STATUS_UNEXPECTED("ZipNode cannot be cached."); + } + return Status::OK(); +} + +// Returns an error 
if MapNode with non-deterministic tensor operations exists under a cache +Status CacheValidationPass::Visit(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::Visit(): visiting " << node->Name() << "."; + if (node->IsCached()) { + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("Nested cache operations over MapNode is not supported."); + } + // If Map is created to be cached, set the flag indicating we found an operation with a cache. + is_cached_ = true; + auto tfuncs = node->TensorOperations(); + for (size_t i = 0; i < tfuncs.size(); i++) { + if (tfuncs[i]->IsRandomOp()) { + RETURN_STATUS_UNEXPECTED( + "MapNode with non-deterministic operations is not supported as a descendant of cache."); + } + } + } + return Status::OK(); +} + +// Flag an error if we have a cache over another cache +Status CacheValidationPass::Visit(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::Visit(): visiting " << node->Name() << "."; + if (node->IsCached()) { + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("Nested cache operations over " + node->Name() + " is not supported."); + } + // If this node is created to be cached, set the flag. + is_cached_ = true; + } + if (node->IsLeaf() && node->IsMappable()) { + is_mappable_ = true; + } + return Status::OK(); +} + +// Returns an error if MappableSource <- Repeat <- Node with a cache +// Because there is no operator in the cache hit stream to consume EoEs, caching above repeat causes problem. +Status CacheValidationPass::VisitAfter(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(): visiting " << node->Name() << "."; + if (is_cached_ && is_mappable_) { + RETURN_STATUS_UNEXPECTED("A cache over a RepeatNode of a mappable dataset is not supported."); + } + return Status::OK(); +} + +Status CacheValidationPass::VisitAfter(std::shared_ptr node, bool *modified) { + MS_LOG(DEBUG) << "CacheValidationPass::VisitAfter(): visiting " << node->Name() << "."; + // Reset the flag when all descendants are visited + if (node->IsCached()) { + is_cached_ = false; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.h new file mode 100644 index 00000000000..d7f1f930fc4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.h @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_ + +#include +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +/// \class CacheValidationPass cache_validation_pass.h +/// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures. +class CacheValidationPass : public IRNodePass { + public: + /// \brief Constructor + CacheValidationPass(); + + /// \brief Destructor + ~CacheValidationPass() = default; + + /// \brief Returns an error if BatchNode exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if ConcatNode exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if FilterNode exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if SkipNode exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if TakeNode exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if ZipNode exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if there is a cache over another cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies and block repeat under cache scenarios + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status VisitAfter(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the subtree above this node as not being cached + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status VisitAfter(std::shared_ptr node, bool 
*modified) override; + + private: + bool is_cached_; + bool is_mappable_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_VALIDATION_PASS_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc new file mode 100644 index 00000000000..cbfe087bddd --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" +#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h" +#include "minddata/dataset/engine/ir/datasetops/root_node.h" +#include "minddata/dataset/engine/ir/datasetops/transfer_node.h" + +namespace mindspore { +namespace dataset { + +// constructor +EpochCtrlPass::InjectionFinder::InjectionFinder(std::shared_ptr node) + : injection_point_(nullptr), num_epochs_(-1) {} + +// Performs finder work for RootNode: records the injection point and the number of epochs +Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr node, bool *modified) { + // The injection is at the child of the root node + injection_point_ = node; + num_epochs_ = node->num_epochs(); + return Status::OK(); +} + +// Performs finder work for BuildVocabNode that has special rules about epoch control injection +Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr node, bool *modified) { + injection_point_ = nullptr; + return Status::OK(); +} + +#ifndef ENABLE_ANDROID +// Performs finder work for BuildSentencePieceVocabNode that has special rules about epoch control injection +Status EpochCtrlPass::InjectionFinder::Visit(std::shared_ptr node, bool *modified) { + injection_point_ = nullptr; + return Status::OK(); +} +#endif + +Status EpochCtrlPass::InjectionFinder::VisitAfter(std::shared_ptr node, bool *modified) { + // Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here. + // Move the injection point to the child of this node. + injection_point_ = node; + return Status::OK(); +} + +// constructor +EpochCtrlPass::EpochCtrlPass() {} + +// Runs an injection pass to inject operators needed at the pre pass stage +Status EpochCtrlPass::RunOnTree(std::shared_ptr root_ir, bool *modified) { + MS_LOG(INFO) << "Pre pass: Injection pass started."; + + // First, run the finder to gather any injection info before we go ahead to drive the op injection work. + // The finder can make updates to the EpochCtrlPass object. + EpochCtrlPass::InjectionFinder finder(root_ir); + RETURN_IF_NOT_OK(finder.Run(root_ir, modified)); + + // The first injection logic is to check if we should inject the epoch control op as the root node. + // Do not inject the op if the number of epochs is 1. 
+ std::shared_ptr parent = finder.injection_point(); + int32_t num_epochs = finder.num_epochs(); + if (num_epochs != 1 && parent != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(parent->Children().size() == 1, "EpochCtrl must be injected on only one child."); + auto epoch_ctrl_node = std::make_shared(num_epochs); + RETURN_IF_NOT_OK(parent->InsertBelow(epoch_ctrl_node)); + } + MS_LOG(INFO) << "Pre pass: Injection pass complete."; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h new file mode 100644 index 00000000000..fbc4a1d547b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ + +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class DatasetOp; + +/// \class EpochCtrlPass epoch_ctrl_pass.h +/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the API +/// parsing. +class EpochCtrlPass : public IRTreePass { + /// \class InjectionFinder + /// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for + /// operators that need to be injected. It is run first by the main injection pass to find out what operators + /// it may need to inject. + class InjectionFinder : public IRNodePass { + public: + /// \brief Constructor + explicit InjectionFinder(std::shared_ptr node); + + /// \brief Destructor + ~InjectionFinder() = default; + + /// \brief Performs finder work for RootNode: records the injection point and the number of epochs. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + +#ifndef ENABLE_ANDROID + /// \brief Performs finder work for BuildSentencePieceVocabNode that has special rules about epoch control injection. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; +#endif + + /// \brief Register the TransferNode for further action. 
+ /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status VisitAfter(std::shared_ptr node, bool *modified) override; + + /// \brief Getter + std::shared_ptr injection_point() { return injection_point_; } + + /// \brief Getter + int32_t num_epochs() { return num_epochs_; } + + private: + std::shared_ptr injection_point_; + int32_t num_epochs_; + }; + + public: + /// \brief Constructor + EpochCtrlPass(); + + /// \brief Destructor + ~EpochCtrlPass() = default; + + /// \brief Runs an injection pass to inject in operators needed at the pre pass stage + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + Status RunOnTree(std::shared_ptr root_ir, bool *modified) override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.h index b6a18aa7917..304d8269e7c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/input_validation_pass.h @@ -26,7 +26,7 @@ namespace dataset { /// \class InputValidationPass /// \brief This is a parse pass that validates input parameters of the IR tree. -class InputValidationPass : public NodePass { +class InputValidationPass : public IRNodePass { /// \brief Runs a validatation pass to check input parameters /// \param[in] node The node being visited /// \param[inout] *modified indicates whether the node has been visited diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc new file mode 100644 index 00000000000..6aec472188d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/engine/opt/pre/node_removal_pass.h" +#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" + +namespace mindspore { +namespace dataset { + +NodeRemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {} + +// Identifies the subtree below this node as a cached descendant tree. 
+Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Node removal pass: Operation with cache found, identified descendant tree."; + if (node->IsCached()) { + is_caching_ = true; + } + return Status::OK(); +} + +// Resets the tracking of the cache within the tree +Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: Descendant walk is complete."; + if (is_caching_ && node->IsLeaf()) { + // Mark this leaf node to indicate it is a descendant of an operator with cache. + // This is currently used in non-mappable dataset (leaf) nodes to not add a ShuffleOp in DatasetNode::Build(). + node->HasCacheAbove(); + } + is_caching_ = false; + return Status::OK(); +} + +// Perform ShuffleOp removal check. +Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr node, bool *modified) { + *modified = false; +#if 0 + // If we are in a cache descendant tree, then this shuffle op needs to be removed + if (is_caching_) { + MS_LOG(INFO) << "Shuffle under an operation with cache is identified for removal."; + nodes_to_remove_.push_back(std::static_pointer_cast(node)); + } +#endif + return Status::OK(); +} + +// constructor +NodeRemovalPass::NodeRemovalPass() {} + +// Walk the tree to collect the nodes to remove, then removes them. +Status NodeRemovalPass::RunOnTree(std::shared_ptr root_ir, bool *modified) { + MS_LOG(INFO) << "Pre pass: node removal pass started."; + // Create the removal node pass which can identify which nodes need to be removed. + std::unique_ptr removal_nodes = std::make_unique(); + RETURN_IF_NOT_OK(removal_nodes->Run(root_ir, modified)); + + // Then, execute the removal of any nodes that were set up for removal + for (auto node : removal_nodes->nodes_to_remove()) { + RETURN_IF_NOT_OK(node->Remove()); + } + MS_LOG(INFO) << "Pre pass: node removal pass complete."; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h new file mode 100644 index 00000000000..292d46e0b53 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ + +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class DatasetOp; + +/// \class RemovalPass removal_pass.h +/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which +/// nodes should be removed, and then removes them. 
+class NodeRemovalPass : public IRTreePass { + /// \class RemovalNodes + /// \brief This is a NodePass who's job is to identify which nodes should be removed. + /// It works in conjunction with the removal_pass. + class RemovalNodes : public IRNodePass { + public: + /// \brief Constructor + /// \param[in] removal_pass Raw pointer back to controlling tree pass + RemovalNodes(); + + /// \brief Destructor + ~RemovalNodes() = default; + + /// \brief Identifies the subtree below this node as a cached descendant tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status VisitAfter(std::shared_ptr node, bool *modified) override; + + /// \brief Perform ShuffleNode removal check + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status Visit(std::shared_ptr node, bool *modified) override; + + /// \brief Getter + /// \return All the nodes to be removed + std::vector> nodes_to_remove() { return nodes_to_remove_; } + + private: + bool is_caching_; + std::vector> nodes_to_remove_; + }; + + public: + /// \brief Constructor + NodeRemovalPass(); + + /// \brief Destructor + ~NodeRemovalPass() = default; + + /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + Status RunOnTree(std::shared_ptr root_ir, bool *modified) override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_REMOVAL_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index f1713134970..18940026e7d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -17,34 +17,25 @@ #include "minddata/dataset/engine/tree_adapter.h" #include "minddata/dataset/core/client.h" -#include "minddata/dataset/include/datasets.h" #include "minddata/dataset/engine/ir/datasetops/root_node.h" #include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" +#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" #include "minddata/dataset/engine/opt/pre/input_validation_pass.h" +#include "minddata/dataset/engine/opt/pre/node_removal_pass.h" namespace mindspore { namespace dataset { Status TreeAdapter::PrePass(std::shared_ptr ir) { - // Vector of actions in validation pass - std::vector> validations; + // Vector of actions in pre-pass phase + std::vector> actions; MS_LOG(INFO) << "Running pre pass loops."; - validations.push_back(std::make_unique()); - - // Vector of flags for each action - // Apply validation actions - for (auto i = 0; i < validations.size(); i++) { - auto modified = false; - // InputValidationPass does not change the IR tree. We don't need to capture the "modified" value. 
- RETURN_IF_NOT_OK(validations[i]->Run(ir, &modified)); - } - - // Vector of actions in pre-pass phase - std::vector> actions; - - // We will gradually move CacheErrorPass, EpochInjectionPass, CacheTransformPass - // from ExecutionTree::PrepareTreePreAction to here. + actions.push_back(std::make_unique()); + actions.push_back(std::make_unique()); + actions.push_back(std::make_unique()); + actions.push_back(std::make_unique()); // Vector of flags for each action std::vector modified(actions.size(), false); @@ -60,7 +51,7 @@ Status TreeAdapter::PrePass(std::shared_ptr ir) { Status TreeAdapter::Optimize(std::shared_ptr ir) { // Vector of optimizations - std::vector> optimizations; + std::vector> optimizations; MS_LOG(INFO) << "Running optimization pass loops"; // We will gradually move TensorOpFusionPass from ExecutionTree::Optimize to here. @@ -79,7 +70,7 @@ Status TreeAdapter::Optimize(std::shared_ptr ir) { Status TreeAdapter::PostPass(std::shared_ptr ir) { // Vector of actions in post-pass phase - std::vector> actions; + std::vector> actions; MS_LOG(INFO) << "Running post pass loops."; // We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here. @@ -96,10 +87,6 @@ Status TreeAdapter::PostPass(std::shared_ptr ir) { } Status TreeAdapter::BuildExecutionTree(std::shared_ptr ir, std::shared_ptr *op) { - // Check if pipeline is valid or not - CHECK_FAIL_RETURN_UNEXPECTED(ir->Parent().size() <= 1, - "The data pipeline is not a tree (i.e. one node has two consumers)"); - // Build the DatasetOp ExecutionTree from the optimized IR tree std::vector> ops; RETURN_IF_NOT_OK(ir->Build(&ops)); @@ -130,8 +117,12 @@ Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_e RETURN_UNEXPECTED_IF_NULL(input_ir); MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n'; + // We will first walk the input tree to sanity check this is not a graph + // Flag an error when it is not a tree + CHECK_FAIL_RETURN_UNEXPECTED(input_ir->IsTree(), "The data pipeline is not a tree (i.e. one node has two consumers)"); + // Copy the input IR tree and insert under the root node - // Create a root node to host the input IR tree, the deepcopied tree will be passed to optimization pass + // Create a root node to host the new copy of the input IR tree to pass to the optimizer auto root_ir = std::make_shared(input_ir->DeepCopy(), num_epochs); MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n'; @@ -151,11 +142,9 @@ Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_e // This will evolve in the long run tree_ = std::make_unique(); - // Build the Execution tree from the child of the root node + // Build the Execution tree from the child of the IR root node, which represent the root of the input IR tree std::shared_ptr root_op; - // input_ir is the ir node before the deepcopy. - // We will replace input_ir with root_ir->Children()[0] once IR optimizer is in - RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op)); + RETURN_IF_NOT_OK(BuildExecutionTree(root_ir->Children()[0], &root_op)); RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); @@ -163,7 +152,7 @@ Status TreeAdapter::Compile(std::shared_ptr input_ir, int32_t num_e // Note: We will gradually move the pre pass, optimizer pass, and post pass // on ExecutionTree to perform on IR tree. 
// Prepare the tree - RETURN_IF_NOT_OK(tree_->Prepare(num_epochs)); + RETURN_IF_NOT_OK(tree_->Prepare(num_epochs, true)); // After the tree is prepared, the col_name_id_map can safely be obtained column_name_map_ = tree_->root()->column_name_id_map(); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc index 7c98a9343fc..24ee3572ef2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc @@ -44,7 +44,7 @@ Status NormalizeOp::Compute(const std::shared_ptr &input, std::shared_pt } void NormalizeOp::Print(std::ostream &out) const { - out << "NormalizeOp, mean: " << mean_ << std::endl << "std: " << std_ << std::endl; + out << "NormalizeOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl; } } // namespace dataset } // namespace mindspore diff --git a/tests/ut/cpp/dataset/optimization_pass_test.cc b/tests/ut/cpp/dataset/optimization_pass_test.cc index ed5802d89d3..ad4aa961ee4 100644 --- a/tests/ut/cpp/dataset/optimization_pass_test.cc +++ b/tests/ut/cpp/dataset/optimization_pass_test.cc @@ -83,7 +83,7 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) { }; exe_tree->SetPrePassOverride(pass); - ASSERT_OK(exe_tree->PrepareTreePreAction()); + ASSERT_OK(exe_tree->PreAction()); std::stringstream ss; // print the tree in std::string as a way to verify that nodes are indeed removed @@ -124,7 +124,7 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) { }; exe_tree->SetPrePassOverride(pass); - ASSERT_OK(exe_tree->PrepareTreePreAction()); + ASSERT_OK(exe_tree->PreAction()); std::stringstream ss; // print the tree in std::string as a way to verify that nodes are indeed removed exe_tree->Print(ss); diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 87d99fc069d..e6d7d023bcd 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -237,7 +237,7 @@ def test_cache_map_failure1(): num_iter = 0 for _ in ds1.create_dict_iterator(num_epochs=1): num_iter += 1 - assert "Nested cache operations is not supported!" 
in str(e.value) + assert "Nested cache operations" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure1 Ended.\n') @@ -279,7 +279,7 @@ def test_cache_map_failure2(): num_iter = 0 for _ in dsz.create_dict_iterator(): num_iter += 1 - assert "ZipOp is currently not supported as a descendant operator under a cache" in str(e.value) + assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure2 Ended.\n') @@ -319,7 +319,7 @@ def test_cache_map_failure3(): num_iter = 0 for _ in ds1.create_dict_iterator(): num_iter += 1 - assert "BatchOp is currently not supported as a descendant operator under a cache" in str(e.value) + assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure3 Ended.\n') @@ -361,7 +361,7 @@ def test_cache_map_failure4(): num_iter = 0 for _ in ds1.create_dict_iterator(): num_iter += 1 - assert "FilterOp is currently not supported as a descendant operator under a cache" in str(e.value) + assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure4 Ended.\n') @@ -402,7 +402,7 @@ def test_cache_map_failure5(): num_iter = 0 for _ in data.create_dict_iterator(): num_iter += 1 - assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache" in str(e.value) + assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure5 Ended.\n') @@ -522,7 +522,7 @@ def test_cache_map_failure8(): num_iter = 0 for _ in ds1.create_dict_iterator(num_epochs=1): num_iter += 1 - assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value) + assert "A cache over a RepeatNode of a mappable dataset is not supported" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure8 Ended.\n') @@ -564,7 +564,7 @@ def test_cache_map_failure9(): num_iter = 0 for _ in ds1.create_dict_iterator(): num_iter += 1 - assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value) + assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure9 Ended.\n') @@ -606,7 +606,7 @@ def test_cache_map_failure10(): num_iter = 0 for _ in ds1.create_dict_iterator(): num_iter += 1 - assert "SkipOp is currently not supported as a descendant operator under a cache" in str(e.value) + assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure10 Ended.\n') @@ -655,13 +655,13 @@ def test_cache_map_split1(): num_iter = 0 for _ in ds1.create_dict_iterator(): num_iter += 1 - assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value) + assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value) with pytest.raises(RuntimeError) as e: num_iter = 0 for _ in ds2.create_dict_iterator(): num_iter += 1 - assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value) + assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value) logger.info('test_cache_split1 Ended.\n')
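
Note: the reworded assertions reference IR nodes (ZipNode, BatchNode, FilterNode, TakeNode, SkipNode) because cache placement is now checked on the IR tree by the cache validation pre-pass. As a rough illustration only, and not part of this patch, a pipeline of the following shape would trigger the SkipNode check at compile time; the dataset directory, session id, and decode transform are placeholder assumptions, and a cache server is assumed to be running for the given session.

import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision

def sketch_cache_over_skip(data_dir, session_id):
    # Hypothetical cache handle; session_id must belong to an existing cache session.
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    data1 = ds.ImageFolderDataset(dataset_dir=data_dir, num_samples=4)
    data1 = data1.skip(1)
    # Attaching the cache to a map above the skip places a SkipNode under the cache,
    # which the cache validation pre-pass rejects when the pipeline is compiled.
    data1 = data1.map(operations=c_vision.Decode(), input_columns=["image"], cache=some_cache)
    with pytest.raises(RuntimeError) as e:
        for _ in data1.create_dict_iterator():
            pass
    assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value)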