diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt index 5e4e06b6382..8b8ade52e38 100644 --- a/mindspore/ccsrc/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/CMakeLists.txt @@ -66,6 +66,7 @@ set(submodules $ $ $ + $ $ ) diff --git a/mindspore/ccsrc/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/CMakeLists.txt index 99a07e8699a..9d01fca9143 100644 --- a/mindspore/ccsrc/dataset/engine/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(datasetops) +add_subdirectory(opt) if (ENABLE_TDTQUE) add_subdirectory(tdt) endif () @@ -14,7 +15,7 @@ add_library(engine OBJECT target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) if (ENABLE_TDTQUE) - add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt) + add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt) else() - add_dependencies(engine engine-datasetops engine-datasetops-source) + add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt) endif () diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc index 13f3d4b2ba7..bfe9079501a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc @@ -22,6 +22,7 @@ #include "dataset/core/pybind_support.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" +#include "dataset/engine/opt/pass.h" using float16 = Eigen::half; @@ -462,5 +463,11 @@ Status BatchOp::PadHelper(std::shared_ptr src, std::shared_ptr d return Status::OK(); } +// Visitor accept method for NodePass +Status BatchOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h index f8faa9562ee..1a862acd0bd 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h @@ -192,6 +192,12 @@ class BatchOp : public ParallelOp { Status PadTensor(std::shared_ptr src, std::shared_ptr *dst, const std::vector &pad_shape, float pad_val); + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // recursive helper function. This function could be very expensive if called on a multi-dimensional tensor // it is only meant to be called by PadTensor. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc index 62a9ede5872..6787b7bbaa6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc @@ -25,6 +25,7 @@ #include "dataset/engine/datasetops/device_queue_op.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" +#include "dataset/engine/opt/pass.h" #include "utils/log_adapter.h" @@ -249,5 +250,11 @@ Status DatasetOp::AssignColMapFromChild() { } return Status::OK(); } + +Status DatasetOp::Accept(NodePass *p, bool *modified) { + // DatasetOp is the base class of visitor target. + // This method will only be called if its derived class does not implement one. + return p->RunOnNode(shared_from_this(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h index cbd3115074b..315dc272194 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h @@ -32,6 +32,8 @@ class ExecutionTree; class DataBuffer; +class NodePass; + // The base class DatasetOp is the main tree node. It is an abstract class, so // the actual implementation of the operators will be derived from here. class DatasetOp : public std::enable_shared_from_this { @@ -209,6 +211,16 @@ class DatasetOp : public std::enable_shared_from_this { // @return - the column name map as a string std::string ColumnNameMapAsString() const; + // Children Getter + // @return Vector or Children + std::vector> Children() const { return child_; } + + // Base method for NodePass visit. + // Subclass needs to override this if it requires special node visit access. + // Check "dataset/engine/opt/pass.h" for more details. + // @return Statue of the node visit + virtual Status Accept(NodePass *p, bool *modified); + protected: // Adds a parent operator to this operator // @notes External callers do not have access to this function. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc index 0815088fa57..f3ab287babc 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc @@ -24,6 +24,7 @@ #include "dataset/engine/dataset_iterator.h" #include "dataset/util/status.h" #include "dataset/util/task_manager.h" +#include "dataset/engine/opt/pass.h" #ifdef ENABLE_TDTQUE #include "tdt/tsd_client.h" @@ -265,5 +266,12 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; } } + +// Visitor accept method for NodePass +Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h index 8856cc4460a..ebbcd16cc3a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h @@ -134,6 +134,12 @@ class DeviceQueueOp : public PipelineOp { Status operator()() override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // Name: checkExceptions(DataBuffer); // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc index 8fe005383fd..26b99080c8d 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc @@ -27,6 +27,7 @@ #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" #include "dataset/kernels/tensor_op.h" #include "utils/log_adapter.h" #include "dataset/util/task_manager.h" @@ -259,5 +260,11 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate } return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); } + +// Visitor accept method for NodePass +Status FilterOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h index a2b5bfa541a..cd6c01da90f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h @@ -121,6 +121,12 @@ class FilterOp : public ParallelOp { // @param show_all A bool to control if you want to show all info or just a summary. void Print(std::ostream &out, bool show_all) const override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // predicate_func python callable which returns a boolean value. py::function predicate_func_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc index 9136da7a9e5..ec3a09154bd 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc @@ -27,6 +27,7 @@ #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" #include "dataset/kernels/tensor_op.h" #include "utils/log_adapter.h" #include "dataset/util/task_manager.h" @@ -370,5 +371,11 @@ void MapOp::CreateFinalColMap(std::unordered_map *col_name column_name_id_map_ = final_col_name_id_map; } } + +// Visitor accept method for NodePass +Status MapOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h index a6a573146e5..f903881ca24 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h @@ -171,6 +171,12 @@ class MapOp : public ParallelOp { // @return the number of threads consuming data from previous op's output Connector. int32_t num_consumers() const override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // Local queues where worker threads can pop from. // Popping directly from the Connector can block if the previous designated threads haven't pop. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc index f99319396ff..73d58d785eb 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc @@ -25,6 +25,7 @@ #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" #include "utils/log_adapter.h" namespace mindspore { @@ -144,5 +145,11 @@ Status ProjectOp::EoeReceived(int32_t worker_id) { } Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } + +// Visitor accept method for NodePass +Status ProjectOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/project_op.h b/mindspore/ccsrc/dataset/engine/datasetops/project_op.h index 25b6cc691e1..3940b9adc71 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/project_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/project_op.h @@ -101,6 +101,12 @@ class ProjectOp : public PipelineOp { // @return Status - The error code returned. Status EofReceived(int32_t worker_id) override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: std::vector columns_to_project_; std::vector projected_column_indices_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc index a940bef0b84..e4715a8fac7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc @@ -24,6 +24,7 @@ #include "dataset/core/global_context.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" +#include "dataset/engine/opt/pass.h" #include "utils/log_adapter.h" namespace mindspore { @@ -170,5 +171,11 @@ Status RenameOp::EoeReceived(int32_t) { state_ = OpState::kDeOpIdle; return Status::OK(); } + +// Visitor accept method for NodePass +Status RenameOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h index f91b6af9310..2bd4875fda7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h @@ -110,6 +110,12 @@ class RenameOp : public PipelineOp { // @return Status - The error code return Status operator()() override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + protected: // Rename core functionality Status RenameColumns(); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc index bcda2dab459..5b1221d72c6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc @@ -21,6 +21,7 @@ #include "dataset/engine/datasetops/repeat_op.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" +#include "dataset/engine/opt/pass.h" #include "utils/log_adapter.h" @@ -187,5 +188,11 @@ int32_t RepeatOp::num_producers() const { return child_[0]->num_producers(); } } + +// Visitor accept method for NodePass +Status RepeatOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h index 8497b4cf3ca..718bc1922bd 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h @@ -118,6 +118,12 @@ class RepeatOp : public PipelineOp { // @param workerId - The worker id int32_t num_producers() const override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: int32_t max_repeats_; // The number of repeats that the user requested int32_t repeat_count_; // A counter for the current number of executed repeats diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc index 0f10d3106ae..ebc69e87fe0 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc @@ -30,6 +30,7 @@ #include "dataset/engine/dataset_iterator.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" +#include "dataset/engine/opt/pass.h" #include "dataset/util/random.h" #include "dataset/util/status.h" @@ -296,5 +297,11 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) { state_ = OpState::kDeOpIdle; return Status::OK(); } + +// Visitor accept method for NodePass +Status ShuffleOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h index 23f6f15c410..baabad758c6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h @@ -155,6 +155,12 @@ class ShuffleOp : public PipelineOp { // @return Status - The error code return Status EoeReceived(int32_t worker_id) override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // Private function to add a new row to the shuffle buffer. // @return Status - The error code return diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc index d3edd98909d..35fae8d0916 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc @@ -22,6 +22,7 @@ #include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" #include "utils/log_adapter.h" @@ -128,5 +129,11 @@ Status SkipOp::EofReceived(int32_t worker_id) { MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; return Status::OK(); } + +// Visitor accept method for NodePass +Status SkipOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h index a16b82ed21a..40db770642c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h @@ -74,6 +74,12 @@ class SkipOp : public PipelineOp { // @param worker_id - The worker id Status EofReceived(int32_t worker_id) override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: int32_t max_skips_; // The number of skips that the user requested int32_t skip_count_; // A counter for the current number of executed skips diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc index bff7a7580ef..fcdb2d987f6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc @@ -20,6 +20,7 @@ #include "dataset/engine/data_buffer.h" #include "dataset/engine/execution_tree.h" #include "dataset/util/task_manager.h" +#include "dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -250,5 +251,11 @@ Status GeneratorOp::Reset() { wp_.Set(); return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); } + +// Visitor accept method for NodePass +Status GeneratorOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h index f8227dafa5f..afeff29b869 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h @@ -121,6 +121,12 @@ class GeneratorOp : public PipelineOp { // @return Status - The error code return Status Reset() override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: py::function generator_function_; std::vector column_names_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc index 77b94fb6ceb..16a9135ade8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc @@ -22,6 +22,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -451,5 +452,11 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1); return Status::OK(); } + +// Visitor accept method for NodePass +Status ImageFolderOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h index d27d220cb57..72d47224fb8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h @@ -225,6 +225,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { const std::set &exts, int64_t *num_rows, int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1); + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index 1644ce1cfc1..991869ac081 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -29,6 +29,7 @@ #include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" #include "utils/log_adapter.h" namespace mindspore { @@ -684,5 +685,11 @@ Status MindRecordOp::CountTotalRows(const std::vector dataset_path, } return Status::OK(); } + +// Visitor accept method for NodePass +Status MindRecordOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h index a3a066bfc25..c8e333d3736 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h @@ -195,6 +195,12 @@ class MindRecordOp : public ParallelOp { Status SetColumnsBlob(); + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: Status GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, int32_t worker_id); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 12b10e69973..181588c470c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -37,6 +37,7 @@ #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" #include "dataset/engine/jagged_connector.h" +#include "dataset/engine/opt/pass.h" #include "dataset/util/path.h" #include "dataset/util/queue.h" #include "dataset/util/random.h" @@ -1037,5 +1038,11 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector &file return rows_read; } + +// Visitor accept method for NodePass +Status TFReaderOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h index f0f08c7971e..3dc5ee932ed 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h @@ -222,6 +222,12 @@ class TFReaderOp : public ParallelOp { static Status CountTotalRows(int64_t *out_total_rows, const std::vector &filenames, int64_t threads = 1, bool estimate = false); + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc index 23bd1b08d3d..f9ff49a5b78 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc @@ -22,6 +22,7 @@ #include "dataset/engine/datasetops/take_op.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/execution_tree.h" +#include "dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -132,5 +133,11 @@ Status TakeOp::PrepareNodePostAction() { tree_->AddToRepeatStack(shared_from_this()); return Status::OK(); } + +// Visitor accept method for NodePass +Status TakeOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h index f70a1e91a37..64ba8e69e00 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h @@ -84,6 +84,12 @@ class TakeOp : public PipelineOp { // before providing their own implementations. Status PrepareNodePostAction() override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: int32_t max_takes_; // The number of takes that the user requested int32_t take_count_; // A counter for the current number of executed takes diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc index fc99be8d88d..1a96e601b78 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc @@ -19,6 +19,7 @@ #include "dataset/core/constants.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" +#include "dataset/engine/opt/pass.h" #include "dataset/core/config_manager.h" #include "dataset/core/global_context.h" #include "utils/log_adapter.h" @@ -250,5 +251,11 @@ Status ZipOp::EoeReceived(int32_t) { state_ = OpState::kDeOpIdle; return Status::OK(); } + +// Visitor accept method for NodePass +Status ZipOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(std::static_pointer_cast(shared_from_this()), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h index fa7f97f3872..1140a98dd73 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h @@ -104,6 +104,12 @@ class ZipOp : public PipelineOp { // @return Status - The error code return Status operator()() override; + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // Handles preprocessing of the main loop, used when starting new epoch Status prepare(TensorQTable *const table); diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc index dbcc201d484..95a6b6f0f53 100644 --- a/mindspore/ccsrc/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/dataset/engine/execution_tree.cc @@ -20,6 +20,8 @@ #include "dataset/engine/datasetops/shuffle_op.h" #include "dataset/util/task_manager.h" +#include "dataset/engine/opt/util/printer_pass.h" + namespace mindspore { namespace dataset { // Constructor @@ -161,10 +163,54 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::functionPrepareTreePreAction()); + + // Optimization transformation + RETURN_IF_NOT_OK(this->Optimize()); + + // Post optimization compulsory transformation + RETURN_IF_NOT_OK(this->PrepareTreePostAction()); + + // Existing transformation implementation, will be removed later + RETURN_IF_NOT_OK(this->PrepareDeprecated()); + return Status::OK(); +} + +Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); } + +Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); } + +Status ExecutionTree::Optimize() { + // auto pp = new PrinterPass(); + // bool modified = false; + // pp->Run(this, &modified); + return Status::OK(); +} + // The driver of the prepare phase of the execution tree. The prepare phase will recursively // walk the tree to perform modifications to the tree or specific nodes within the tree to get // it ready for execution. -Status ExecutionTree::Prepare() { +// +// This driver is deprecated. +Status ExecutionTree::PrepareDeprecated() { // Tree must be in pending prepare state before we can assign root to it if (tree_state_ != kDeTStatePrepare) { std::string err_msg = diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.h b/mindspore/ccsrc/dataset/engine/execution_tree.h index 838eb3a014c..f0c894f05bf 100644 --- a/mindspore/ccsrc/dataset/engine/execution_tree.h +++ b/mindspore/ccsrc/dataset/engine/execution_tree.h @@ -152,11 +152,41 @@ class ExecutionTree { // @return the prepare flags uint32_t PrepareFlags() const { return prepare_flags_; } - // The driver of the prepare phase of the execution tree. The prepare phase will recursively + // The driver of the prepare phase of the execution tree. + // Prepare phase consists of three sub phases + // + // 1. PrepareTreePreAction() + // Compulsory transformation/action pre optimization. + // For example, CacheOp Insertion + // + // 2. Optimize() + // Optimization transformation/action, optional + // For example, MapOp Fusion + // + // 3. PrepareTreePostAction() + // Compulsory transformation/action post optimization. + // For example, repeatOp inlining + // + // @return Status - The error code return + Status Prepare(); + + // Compulsory transformation/action pre optimization. + // @return Status - The error code return + Status PrepareTreePreAction(); + + // Compulsory transformation/action post optimization. + // @return Status - The error code return + Status PrepareTreePostAction(); + + // Optimization transformation/action, optional. + // @return Status - The error code return + Status Optimize(); + + // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively // walk the tree to perform modifications to the tree or specific nodes within the tree to get // it ready for execution. // @return Status - The error code return - Status Prepare(); + Status PrepareDeprecated(); // Recursive function used during prepare phase to visit a node and drive any pre- and post- // node actions during a tree walk. diff --git a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt new file mode 100644 index 00000000000..9804b85d3ad --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_library(engine-opt OBJECT + pass.cc + util/printer_pass.cc + ) \ No newline at end of file diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.cc b/mindspore/ccsrc/dataset/engine/opt/pass.cc new file mode 100644 index 00000000000..e6bd9fe247f --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pass.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "dataset/engine/opt/pass.h" +#include "dataset/engine/datasetops/batch_op.h" +#include "dataset/engine/datasetops/dataset_op.h" +#include "dataset/engine/datasetops/device_queue_op.h" +#include "dataset/engine/datasetops/map_op.h" +#include "dataset/engine/datasetops/project_op.h" +#include "dataset/engine/datasetops/rename_op.h" +#include "dataset/engine/datasetops/filter_op.h" +#include "dataset/engine/datasetops/repeat_op.h" +#include "dataset/engine/datasetops/skip_op.h" +#include "dataset/engine/datasetops/shuffle_op.h" +#include "dataset/engine/datasetops/source/generator_op.h" +#include "dataset/engine/datasetops/source/mindrecord_op.h" +#include "dataset/engine/datasetops/source/storage_op.h" +#include "dataset/engine/datasetops/source/tf_reader_op.h" +#include "dataset/engine/datasetops/source/image_folder_op.h" +#include "dataset/engine/datasetops/take_op.h" +#include "dataset/engine/datasetops/zip_op.h" + +namespace mindspore { +namespace dataset { + +// Driver method for TreePass +Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); } + +// Driver method for NodePass +Status NodePass::Run(ExecutionTree *tree, bool *modified) { + std::shared_ptr root = tree->root(); + if (traversalOrder_ == Order::DFS) { + // DFS + return DFSNodeVisit(root, modified); + } else if (traversalOrder_ == Order::BFS) { + // BFS + return BFSNodeVisit(root, modified); + } + return Status::OK(); +} + +// Helper function to perform DFS visit +Status NodePass::DFSNodeVisit(std::shared_ptr node, bool *modified) { + for (const auto &c : node->Children()) { + RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); + } + return node->Accept(this, modified); +} + +// Helper function to perform BFS visit +Status NodePass::BFSNodeVisit(std::shared_ptr root, bool *modified) { + // Initialize bfs queue with root + std::queue> bfsQueue; + bfsQueue.push(root); + + // BFS loop + while (!bfsQueue.empty()) { + // Pop the front of the bfs queue + auto curNode = bfsQueue.front(); + bfsQueue.pop(); + + // Run node pass + RETURN_IF_NOT_OK(curNode->Accept(this, modified)); + + // Push children into bfs queue + for (const auto &c : curNode->Children()) { + bfsQueue.push(c); + } + } + return Status::OK(); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.h b/mindspore/ccsrc/dataset/engine/opt/pass.h new file mode 100644 index 00000000000..58dfc787f47 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pass.h @@ -0,0 +1,146 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_H_ + +#include +#include + +#include "dataset/engine/execution_tree.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class BatchOp; + +class MapOp; + +class ProjectOp; + +class RenameOp; + +class FilterOp; + +class SkipOp; + +class ShuffleOp; + +class GeneratorOp; + +class MindRecordOp; + +class TFReaderOp; + +class TakeOp; + +class ZipOp; + +class DeviceQueueOp; + +class ImageFolderOp; + +// The base class Pass is the basic unit of tree transformation. +// The actual implementation of the passes will be derived from here. +class Pass : public std::enable_shared_from_this { + public: + // Run the transformation pass again the execution tree. + // @param tree - Pointer to the execution tree to be transformed. + // @param modified - Pointer to the modified flag, + virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); } +}; + +// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. +class TreePass : public Pass { + public: + // Run the transformation pass against the execution tree. + // @param tree - Pointer to the execution tree to be transformed. + // @param modified - Pointer to the modified flag, + Status Run(ExecutionTree *tree, bool *modified) final; + + // Derived classes may implement the runOnTree function to implement tree transformation. + // "modified" flag needs to be set to true if tree is modified during the pass execution. + // @return Status - The error code return + virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } +}; + +// NodePass is a basic Pass class which performs transformation on Node visiting. +// NodePass implements Visitor design pattern. +class NodePass : public Pass { + public: + // Tree traversal order + enum Order { DFS, BFS }; + + // Constructor + // Default DFS traversal + explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } + + // Run the transformation pass against the execution tree. + // @param tree - Pointer to the execution tree to be transformed. + // @param modified - Pointer to the modified flag, + Status Run(ExecutionTree *tree, bool *modified) final; + + // Derived classes may implement the runOnNode function to implement node level tree transformation. + // "modified" flag needs to be set to true if tree is modified during the pass execution. + // @return Status - The error code return + virtual Status RunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } + + // Visit methods to be overridden. + // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode + // of its own type and override "Accept" from DatasetOp. + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + private: + // Helper function to perform DFS visit + Status DFSNodeVisit(std::shared_ptr node, bool *modified); + + // Helper function to perform BFS visit + Status BFSNodeVisit(std::shared_ptr root, bool *modified); + + // Tree traversal order of the NodePass + Order traversalOrder_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc b/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc new file mode 100644 index 00000000000..852bc018b20 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "dataset/engine/opt/util/printer_pass.h" + +namespace mindspore { +namespace dataset { + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting DatasetOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting BatchOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting MapOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting ProjectOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting RenameOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting FilterOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting SkipOp" << '\n'; + return Status::OK(); +} +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting ShuffleOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting GeneratorOp" << '\n'; + return Status::OK(); +} +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting MindRecordOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting TFReaderOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting TakeOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting ZipOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting DeviceQueueOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting ImageFolderOp" << '\n'; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h b/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h new file mode 100644 index 00000000000..fa04a88277e --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H +#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H + +#include +#include "dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class PrinterPass : public NodePass { + public: + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H diff --git a/tests/ut/python/dataset/test_opt.py b/tests/ut/python/dataset/test_opt.py new file mode 100644 index 00000000000..939a885156e --- /dev/null +++ b/tests/ut/python/dataset/test_opt.py @@ -0,0 +1,46 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +import pytest + +import mindspore.dataset as ds + +# Generate 1d int numpy array from 0 - 63 +def generator_1d(): + for i in range(64): + yield (np.array([i]),) + + +def test_case_0(): + """ + Test 1D Generator + """ + + # apply dataset operations + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + + data1 = data1.shuffle(2) + + data1 = data1.map(["data"], operations=(lambda x : x)) + + data1 = data1.batch(2) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + pass + + +if __name__ == "__main__": + test_case_0() \ No newline at end of file