From 57f3732ac324774868b9c682aedd65e0a80bb31d Mon Sep 17 00:00:00 2001 From: Junhan Hu Date: Wed, 17 Jun 2020 23:34:50 -0400 Subject: [PATCH] Move python tree transformation into cpp pass. --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 8 ++ .../dataset/engine/datasetops/dataset_op.cc | 40 +++++++- .../dataset/engine/datasetops/dataset_op.h | 18 +++- .../ccsrc/dataset/engine/datasetops/map_op.cc | 9 +- .../ccsrc/dataset/engine/datasetops/map_op.h | 20 +++- .../engine/datasetops/source/clue_op.cc | 13 ++- .../engine/datasetops/source/clue_op.h | 20 +++- .../engine/datasetops/source/text_file_op.cc | 12 ++- .../engine/datasetops/source/text_file_op.h | 20 +++- .../engine/datasetops/source/tf_reader_op.cc | 13 ++- .../engine/datasetops/source/tf_reader_op.h | 20 +++- .../ccsrc/dataset/engine/execution_tree.cc | 23 ++++- .../ccsrc/dataset/engine/opt/CMakeLists.txt | 2 + mindspore/ccsrc/dataset/engine/opt/pass.cc | 10 +- mindspore/ccsrc/dataset/engine/opt/pass.h | 4 +- .../dataset/engine/opt/pre/global_shuffle.cc | 98 +++++++++++++++++++ .../dataset/engine/opt/pre/global_shuffle.h | 35 +++++++ .../engine/opt/pre/map_column_reorder.cc | 51 ++++++++++ .../engine/opt/pre/map_column_reorder.h | 35 +++++++ mindspore/dataset/engine/datasets.py | 4 + mindspore/dataset/engine/iterators.py | 24 +---- tests/ut/cpp/dataset/repeat_op_test.cc | 1 + tests/ut/python/dataset/test_opt_pass.py | 90 +++++++++++++++++ 23 files changed, 517 insertions(+), 53 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc create mode 100644 mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h create mode 100644 mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc create mode 100644 mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h create mode 100644 tests/ut/python/dataset/test_opt_pass.py diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index 5ff8151c0ec..e14d8d1d504 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -487,6 +487,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * (void)builder->SetInColNames(in_col_names); } else if (key == "output_columns") { (void)builder->SetOutColNames(ToStringVector(value)); + } else if (key == "columns_order") { + (void)builder->SetColOrder(ToStringVector(value)); } else if (key == "num_parallel_workers") { (void)builder->SetNumWorkers(ToInt(value)); } else if (key == "prefetch_size") { @@ -835,6 +837,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptrSetColumnsToLoad(columns_to_load); } else if (key == "shuffle_files") { (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + (void)builder->SetShuffleGlobal(ToBool(value)); } else if (key == "schema_file_path" || key == "schema_json_string") { schema_exists = true; } else if (key == "num_samples") { @@ -1225,6 +1229,8 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptrSetNumWorkers(ToInt(value)); } else if (key == "shuffle_files") { (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + (void)builder->SetShuffleGlobal(ToBool(value)); } else if (key == "num_samples") { (void)builder->SetTotalRows(ToInt(value)); } else if (key == "num_shards") { @@ -1314,6 +1320,8 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr (void)builder->SetNumWorkers(ToInt(value)); } else if (key == "shuffle_files") { 
(void)builder->SetShuffleFiles(ToBool(value));
+ } else if (key == "shuffle_global") {
+ (void)builder->SetShuffleGlobal(ToBool(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_shards") {
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
index e105c554f2d..bf991ea7d9d 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
@@ -20,6 +20,7 @@
#include
#include
#include
+#include <algorithm>
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/device_queue_op.h"
@@ -68,8 +69,45 @@ Status DatasetOp::AddChild(std::shared_ptr<DatasetOp> child) {
  return Status::OK();
}
+Status DatasetOp::RemoveChild(std::shared_ptr<DatasetOp> child) {
+  if (operator_id_ == kInvalidOperatorId) {
+    std::string err_msg(
+      "Cannot remove child node. Tree node connections can only "
+      "be made if the node belongs to a tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // disallow relationships with other trees
+  if (tree_ != child->tree_) {
+    std::string err_msg(
+      "Cannot remove child node. Tree node connections can only be made if both nodes belong to the same tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end());
+  child->RemoveParent(this);
+  return Status::OK();
+}
+
+Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
+  for (auto &prev_parent : this->parent_) {
+    RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this()));
+    RETURN_IF_NOT_OK(prev_parent->AddChild(to_add));
+  }
+  RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this()));
+  if (tree_->root()->id() == this->id()) {
+    tree_->AssignRoot(to_add);
+  }
+  return Status::OK();
+}
+
// Adds a parent operator to this operator
-void DatasetOp::AddParent(const DatasetOp *parent) { parent_.push_back(parent); }
+void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
+
+// Removes a parent operator from this operator
+void DatasetOp::RemoveParent(DatasetOp *parent) {
+  parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
+}
// Getter function to get a shared pointer to our child
std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
index 21b5cbf7ce3..955c2c486d8 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
@@ -64,10 +64,19 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// @param child - shared pointer to the child to add.
Status AddChild(std::shared_ptr<DatasetOp> child);
+ // Removes an operator from our children.
+ // @param child - shared pointer to the child to remove.
+ Status RemoveChild(std::shared_ptr<DatasetOp> child);
+
// Getter function to get a shared pointer to our child
// @param child_index - An operator can have n children. Indicates which child to return.
std::shared_ptr<DatasetOp> child(int32_t child_index) const;
+ // Inserts an operator as the parent of the current op.
+ // The inserted op will become the sole parent of the current op.
+ // The existing parents of the current op will be transferred to the inserted op.
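+ // Example: in a chain "TFReader -> Batch" (Batch is TFReader's parent), calling
+ // tfreader->InsertAsParent(shuffle) yields "TFReader -> Shuffle -> Batch"; if the
+ // current op was the tree root, the inserted op becomes the new root.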
+ Status InsertAsParent(std::shared_ptr to_add); + // Creates the connector within this operator // @param num_producers - number of threads that write into this connector // @param num_consumers - number of threads that read from this connector @@ -261,7 +270,12 @@ class DatasetOp : public std::enable_shared_from_this { // Adds a parent operator to this operator // @notes External callers do not have access to this function. // @param parent - The parent node to add - void AddParent(const DatasetOp *parent); + void AddParent(DatasetOp *parent); + + // Removes a parent operator from this operator + // @notes External callers do not have access to this function. + // @param parent - The parent node to remove + void RemoveParent(DatasetOp *parent); // A helper function for providing an assignment of the column name map. // This grabs the map from child 0 and assigns it into this op. @@ -270,7 +284,7 @@ class DatasetOp : public std::enable_shared_from_this { Status AssignColMapFromChild(); std::vector> child_; // Child nodes - std::vector parent_; // Parent nodes. No ownership and read-only + std::vector parent_; // Parent nodes. No ownership int32_t oc_queue_size_; // Capacity for each out_connector_ int32_t operator_id_; // Generated id for the node ExecutionTree *tree_; // Back pointer to our tree. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc index 008ff09c99a..99182602017 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc @@ -54,19 +54,20 @@ Status MapOp::Builder::sanityCheck() const { Status MapOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(sanityCheck()); *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), - std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_, - build_perf_mode_); + std::move(build_tensor_funcs_), std::move(build_col_order_), build_num_workers_, + build_op_connector_size_, build_perf_mode_); return Status::OK(); } // Constructor of MapOp MapOp::MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, - bool perf_mode) + std::vector> tensor_funcs, const std::vector &columns_order, + int32_t num_workers, int32_t op_connector_size, bool perf_mode) : ParallelOp(num_workers, op_connector_size), tfuncs_(std::move(tensor_funcs)), in_columns_(in_col_names), out_columns_(out_col_names), + columns_order_(columns_order), perf_mode_(perf_mode) { // If caller didn't specify the out_col_names, assume they are same as the in_columns. if (out_columns_.empty() || out_columns_[0].empty()) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h index 8bec6179e30..4d7ffd12047 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h @@ -93,6 +93,13 @@ class MapOp : public ParallelOp { return *this; } + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetColOrder(const std::vector &col_order_) { + build_col_order_ = col_order_; + return *this; + } + // Setter method. // @return Builder setter method returns reference to the builder. 
Builder &SetNumWorkers(int32_t num_workers) {
@@ -123,6 +130,7 @@ class MapOp : public ParallelOp {
std::vector<std::string> build_in_col_names_;
std::vector<std::string> build_out_col_names_;
std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
+ std::vector<std::string> build_col_order_;
int32_t build_num_workers_;
int32_t build_op_connector_size_;
bool build_perf_mode_; // Default true.
@@ -137,11 +145,12 @@
// @param in_col_names A list of input column names (should match the input/output \p tensorFuncs).
// @param out_col_names A list of output column names (should match the input/output \p tensorFuncs).
// @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each piece of data.
+ // @param columns_order A full list of column names (should match the whole dataset view post \p tensorFuncs).
// @param num_workers The number of worker threads.
// @param op_connector_size The size of each queue in the connector.
MapOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
- std::vector<std::shared_ptr<TensorOp>> tensor_funcs, int32_t num_workers, int32_t op_connector_size,
- bool perf_mode);
+ std::vector<std::shared_ptr<TensorOp>> tensor_funcs, const std::vector<std::string> &columns_order,
+ int32_t num_workers, int32_t op_connector_size, bool perf_mode);
// Destructor
~MapOp() = default;
@@ -181,6 +190,10 @@ class MapOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "MapOp"; }
+ // Columns order getter
+ // @return The post-map columns order
+ std::vector<std::string> const &ColumnsOrder() const { return columns_order_; }
+
private:
// Local queues where worker threads can pop from.
// Popping directly from the Connector can block if the previously designated threads haven't popped.
@@ -202,6 +215,9 @@
// Indices of the columns to process.
std::vector<size_t> to_process_indices_;
+ // Variable to store the column_order of all columns post tensorOps
+ std::vector<std::string> columns_order_;
+
// Performance mode is when the main thread creates local queues, pulls databuffers from the previous
// op's Connector and distributes them to the local queues. Workers pull from the local queues.
// If this flag is false, each worker pulls directly from the Connector. This uses fewer resources.
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
index efe2d8d0613..c047bbc164c 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
@@ -32,7 +32,11 @@ namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
- : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+ : builder_device_id_(0),
+ builder_num_devices_(1),
+ builder_num_samples_(0),
+ builder_shuffle_files_(false),
+ builder_shuffle_global_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@@ -63,8 +67,8 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
- builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
- builder_device_id_);
+ builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_shuffle_global_,
+ builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(clue_op->Init());
*op = std::move(clue_op);
@@ -84,7 +88,7 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
- bool shuffle_files, int32_t num_device, int32_t device_id)
+ bool shuffle_files, bool shuffle_global, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
@@ -95,6 +99,7 @@
load_jagged_connector_(true),
cols_to_keyword_(cols_to_keyword),
shuffle_files_(shuffle_files),
+ shuffle_global_(shuffle_global),
finished_reading_dataset_(false),
num_devices_(num_device),
device_id_(device_id),
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h
index 1b8f23c97be..15961ffd622 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h
@@ -104,6 +104,13 @@ class ClueOp : public ParallelOp {
return *this;
}
+ // Setter method.
+ // @return Builder - setter method returns reference to the builder.
+ Builder &SetShuffleGlobal(bool shuffle_global) {
+ builder_shuffle_global_ = shuffle_global;
+ return *this;
+ }
+
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
@@ -132,13 +139,15 @@
int32_t builder_worker_connector_size_;
std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_;
+ bool builder_shuffle_global_;
std::map<std::string, std::string> builder_cols_to_keyword_;
};
// Constructor of ClueOp
+ // @param shuffle_global - whether or not to shuffle the entire dataset.
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
- bool shuffle_files, int32_t num_devices, int32_t device_id);
+ bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id);
// Default destructor
~ClueOp() = default;
@@ -169,6 +178,14 @@
// @return Status - the error code returned.
static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
+ // File names getter
+ // @return Vector of the input file names
+ std::vector<std::string> FileNames() { return clue_files_list_; }
+
+ // Global shuffle flag getter
+ // @return Bool - whether this Op requires global shuffle
+ bool RequireGlobalShuffle() { return shuffle_global_; }
+
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
@@ -248,6 +265,7 @@
int32_t device_id_;
bool shuffle_files_;
+ bool shuffle_global_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
index 8e22c102dd9..26058cc8b8b 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
@@ -33,7 +33,11 @@ namespace mindspore {
namespace dataset {
TextFileOp::Builder::Builder()
- : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
+ : builder_device_id_(0),
+ builder_num_devices_(1),
+ builder_total_rows_(0),
+ builder_shuffle_files_(false),
+ builder_shuffle_global_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@@ -64,7 +68,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
- builder_num_devices_, builder_device_id_);
+ builder_shuffle_global_, builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(text_file_op->Init());
*op = std::move(text_file_op);
@@ -73,7 +77,8 @@
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
- int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
+ int32_t op_connector_size, bool shuffle_files, bool shuffle_global, int32_t num_device,
+ int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
@@ -81,6 +86,7 @@
total_rows_(total_rows),
text_files_list_(std::move(text_files_list)),
shuffle_files_(shuffle_files),
+ shuffle_global_(shuffle_global),
data_schema_(std::move(schema)),
all_num_rows_(0),
num_rows_per_shard_(0),
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h
index
2445c5c6be2..dd258d914e4 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h @@ -105,6 +105,13 @@ class TextFileOp : public ParallelOp { return *this; } + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleGlobal(bool shuffle_global) { + builder_shuffle_global_ = shuffle_global; + return *this; + } + // Setter method. // @return Builder - setter method returns reference to the builder. Builder &SetTotalRows(int64_t total_rows) { @@ -122,6 +129,7 @@ class TextFileOp : public ParallelOp { int32_t builder_worker_connector_size_; std::vector builder_text_files_list_; bool builder_shuffle_files_; + bool builder_shuffle_global_; std::unique_ptr builder_schema_; }; @@ -135,10 +143,11 @@ class TextFileOp : public ParallelOp { // @param op_connector_size - size of each queue in the connector that the child operator pulls from. // @param columns_to_load - the names of the columns to load data from. // @param shuffle_files - whether or not to shuffle the files before reading data. + // @param shuffle_global - whether or not to shuffle the entire dataset. // @param equal_rows_per_shard - whether or not to get equal rows for each process. TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id); + bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id); // Default destructor ~TextFileOp() = default; @@ -173,6 +182,14 @@ class TextFileOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "TextFileOp"; } + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return text_files_list_; } + + // Global shuffle flag getter + // @return Bool - whether this Op requires global shuffle + bool RequireGlobalShuffle() { return shuffle_global_; } + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. 
@@ -253,6 +270,7 @@ class TextFileOp : public ParallelOp { int64_t total_rows_; std::vector text_files_list_; bool shuffle_files_; + bool shuffle_global_; std::unique_ptr data_schema_; int64_t all_num_rows_; int64_t num_rows_per_shard_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 4d3851488a7..b56eecc5110 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -56,6 +56,7 @@ TFReaderOp::Builder::Builder() builder_op_connector_size_ = config_manager->op_connector_size(); builder_rows_per_buffer_ = config_manager->rows_per_buffer(); builder_shuffle_files_ = false; + builder_shuffle_global_ = false; builder_data_schema_ = std::make_unique(); } @@ -126,7 +127,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) std::shared_ptr new_tf_reader_op = std::make_shared( builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, - builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_); + builder_shuffle_files_, builder_shuffle_global_, builder_num_devices_, builder_device_id_, + builder_equal_rows_per_shard_); RETURN_IF_NOT_OK(new_tf_reader_op->Init()); *out_tf_reader_op = std::move(new_tf_reader_op); @@ -136,8 +138,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, - std::vector columns_to_load, bool shuffle_files, int32_t num_device, - int32_t device_id, bool equal_rows_per_shard) + std::vector columns_to_load, bool shuffle_files, bool shuffle_global, + int32_t num_device, int32_t device_id, bool equal_rows_per_shard) : ParallelOp(num_workers, op_connector_size), device_id_(device_id), num_devices_(num_device), @@ -147,6 +149,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 columns_to_load_(std::move(columns_to_load)), finished_reading_dataset_(false), shuffle_files_(shuffle_files), + shuffle_global_(shuffle_global), data_schema_(std::move(data_schema)), filename_index_(std::make_unique()), load_io_block_queue_(true), @@ -172,7 +175,8 @@ void TFReaderOp::Print(std::ostream &out, bool show_all) const { // Then show any custom derived-internal stuff out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") - << "\nDataset files list:\n"; + << "\nShuffle global: " << ((shuffle_global_) ? 
"yes" : "no") + << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n"; for (int i = 0; i < dataset_files_list_.size(); ++i) { out << " " << dataset_files_list_[i]; } @@ -217,7 +221,6 @@ Status TFReaderOp::Init() { // temporary: make size large enough to hold all files + EOE to avoid hangs int32_t safe_queue_size = static_cast(std::ceil(dataset_files_list_.size() / num_workers_)) + 1; io_block_queues_.Init(num_workers_, safe_queue_size); - dataset_files_list_.clear(); // no longer need the original list of files return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h index 79131c026b2..9c92d6d4be6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h @@ -146,6 +146,13 @@ class TFReaderOp : public ParallelOp { return *this; } + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleGlobal(bool shuffle_global) { + builder_shuffle_global_ = shuffle_global; + return *this; + } + // Setter method. // @return Builder - setter method returns reference to the builder. Builder &SetShardEqualRows(bool shard_equal_rows) { @@ -165,6 +172,7 @@ class TFReaderOp : public ParallelOp { std::vector builder_dataset_files_list_; std::vector builder_columns_to_load_; bool builder_shuffle_files_; + bool builder_shuffle_global_; bool builder_equal_rows_per_shard_; }; @@ -179,11 +187,12 @@ class TFReaderOp : public ParallelOp { // @param op_connector_size - size of each queue in the connector that the child operator pulls from. // @param columns_to_load - the names of the columns to load data from. // @param shuffle_files - whether or not to shuffle the files before reading data. + // @param shuffle_global - whether or not to shuffle the entire dataset. // @param equal_rows_per_shard - whether or not to get equal rows for each process. TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, - int32_t num_devices, int32_t device_id, bool equal_rows_per_shard); + bool shuffle_global, int32_t num_devices, int32_t device_id, bool equal_rows_per_shard); // Default destructor ~TFReaderOp() = default; @@ -232,6 +241,14 @@ class TFReaderOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "TFReaderOp"; } + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return dataset_files_list_; } + + // Global shuffle flag getter + // @return Bool - whether this Op requires global shuffle + bool RequireGlobalShuffle() { return shuffle_global_; } + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. 
@@ -372,6 +389,7 @@ class TFReaderOp : public ParallelOp { std::vector columns_to_load_; bool finished_reading_dataset_; bool shuffle_files_; + bool shuffle_global_; std::unique_ptr data_schema_; std::unique_ptr filename_index_; bool load_io_block_queue_; diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc index 345f9758890..5c921bba84c 100644 --- a/mindspore/ccsrc/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/dataset/engine/execution_tree.cc @@ -19,6 +19,8 @@ #include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/datasetops/shuffle_op.h" #include "dataset/util/task_manager.h" +#include "dataset/engine/opt/pre/map_column_reorder.h" +#include "dataset/engine/opt/pre/global_shuffle.h" #include "dataset/engine/perf/profiling.h" #include "dataset/engine/perf/monitor.h" @@ -79,8 +81,6 @@ Status ExecutionTree::AssignRoot(const std::shared_ptr &op) { // Then add it as the root. root_ = op; - // The tree has an assigned root now and it's ready to be prepared. - tree_state_ = kDeTStatePrepare; return Status::OK(); } @@ -207,9 +207,24 @@ Status ExecutionTree::Prepare() { return Status::OK(); } -Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); } +Status ExecutionTree::PrepareTreePreAction() { + bool modified = false; + std::vector pre_actions; + // Construct pre actions + pre_actions.push_back(new MapColumnReorder()); + pre_actions.push_back(new GlobalShufflePass()); + // Apply pre action passes + for (auto &pass : pre_actions) { + RETURN_IF_NOT_OK(pass->Run(this, &modified)); + } + return Status::OK(); +} -Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); } +Status ExecutionTree::PrepareTreePostAction() { + // The tree is ready to be prepared. 
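+ // Pre-action passes (MapColumnReorder, GlobalShufflePass) have already rewritten the
+ // tree by the time post actions run; further compulsory post passes can be added here.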
+ tree_state_ = kDeTStatePrepare; + return Status::OK(); +} Status ExecutionTree::Optimize() { // auto pp = new PrinterPass(); diff --git a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt index 9804b85d3ad..170cbb55e53 100644 --- a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt @@ -2,5 +2,7 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(engine-opt OBJECT pass.cc + pre/map_column_reorder.cc + pre/global_shuffle.cc util/printer_pass.cc ) \ No newline at end of file diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.cc b/mindspore/ccsrc/dataset/engine/opt/pass.cc index e6bd9fe247f..4f0cfe4a42d 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/dataset/engine/opt/pass.cc @@ -37,10 +37,18 @@ namespace mindspore { namespace dataset { // Driver method for TreePass -Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); } +Status TreePass::Run(ExecutionTree *tree, bool *modified) { + if (!tree || !modified) { + return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); + } + return this->RunOnTree(tree, modified); +} // Driver method for NodePass Status NodePass::Run(ExecutionTree *tree, bool *modified) { + if (!tree || !modified) { + return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); + } std::shared_ptr root = tree->root(); if (traversalOrder_ == Order::DFS) { // DFS diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.h b/mindspore/ccsrc/dataset/engine/opt/pass.h index bac464f4016..39682b22f70 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/dataset/engine/opt/pass.h @@ -57,10 +57,10 @@ class ImageFolderOp; // The actual implementation of the passes will be derived from here. class Pass : public std::enable_shared_from_this { public: - // Run the transformation pass again the execution tree. + // Run the transformation pass against the execution tree. // @param tree - Pointer to the execution tree to be transformed. // @param modified - Pointer to the modified flag, - virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); } + virtual Status Run(ExecutionTree *tree, bool *modified) = 0; }; // TreePass is a basic Pass class which performs transformation on ExecutionTree directly. diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc b/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc new file mode 100644 index 00000000000..2adf734a6c4 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <memory>
+#include <vector>
+#include "dataset/engine/opt/pre/global_shuffle.h"
+#include "dataset/engine/execution_tree.h"
+#include "dataset/engine/datasetops/shuffle_op.h"
+#include "dataset/engine/datasetops/source/tf_reader_op.h"
+#include "dataset/engine/datasetops/source/text_file_op.h"
+#include "dataset/engine/datasetops/source/clue_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status GlobalShufflePass::RunOnTree(ExecutionTree *tree, bool *modified) {
+  std::vector<std::shared_ptr<TFReaderOp>> tf_readers;
+  std::vector<std::shared_ptr<TextFileOp>> text_files;
+  std::vector<std::shared_ptr<ClueOp>> clues;
+
+  // Pass 1, search for all sources which require global shuffle
+  for (auto &op : *tree) {
+    if (auto ptr = std::dynamic_pointer_cast<TFReaderOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        tf_readers.push_back(ptr);
+        continue;
+      }
+    }
+    if (auto ptr = std::dynamic_pointer_cast<TextFileOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        text_files.push_back(ptr);
+        continue;
+      }
+    }
+    if (auto ptr = std::dynamic_pointer_cast<ClueOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        clues.push_back(ptr);
+        continue;
+      }
+    }
+  }
+
+  // Pass 2, insert shuffle nodes
+  // The following blocks could be implemented with a template if we unify CountTotalRows across all source nodes.
+  for (auto node : tf_readers) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&total_rows, node->FileNames(), 8, true));
+    int32_t avg_rows_per_file = total_rows / (node->FileNames().size());
+    builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  for (auto node : text_files) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(node->FileNames(), &total_rows));
+    int32_t avg_rows_per_file = total_rows / (node->FileNames().size());
+    builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  for (auto node : clues) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(node->FileNames(), &total_rows));
+    int32_t avg_rows_per_file = total_rows / (node->FileNames().size());
+    builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h b/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h
new file mode 100644
index 00000000000..6865ac93911
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
+#define DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
+
+#include <memory>
+#include "dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+// Global Shuffle Pass will insert a ShuffleOp when a leaf node requires global shuffle.
+// Example:
+// Input Tree: TFReader(GLOBAL_SHUFFLE) -> Batch
+// Output Tree: TFReader -> Shuffle -> Batch
+class GlobalShufflePass : public TreePass {
+  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc b/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc
new file mode 100644
index 00000000000..a3dbbfcc545
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <vector>
+#include "dataset/engine/opt/pre/map_column_reorder.h"
+#include "dataset/engine/execution_tree.h"
+#include "dataset/engine/datasetops/map_op.h"
+#include "dataset/engine/datasetops/project_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status MapColumnReorder::RunOnTree(ExecutionTree *tree, bool *modified) {
+  std::vector<std::shared_ptr<MapOp>> to_process;
+
+  // Pass 1, search for all MapOp with column orders
+  for (auto &op : *tree) {
+    if (auto mapOp = std::dynamic_pointer_cast<MapOp>(op.shared_from_this())) {
+      if (mapOp->ColumnsOrder().size() != 0) {
+        to_process.push_back(mapOp);
+      }
+    }
+  }
+
+  // Pass 2, insert a ProjectOp above each such MapOp
+  for (auto node : to_process) {
+    std::shared_ptr<ProjectOp::Builder> builder = std::make_shared<ProjectOp::Builder>(node->ColumnsOrder());
+    std::shared_ptr<ProjectOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h b/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h
new file mode 100644
index 00000000000..84274db3d51
--- /dev/null
+++ b/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
+#define DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
+
+#include <memory>
+#include "dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+// Map Column Reorder Pass will insert a ProjectOp when a MapOp requires a full reorder of the output columns.
+// Example:
+// Input Tree: TFReader -> MapOp(with col_order) -> Batch
+// Output Tree: TFReader -> MapOp -> ProjectOp(col_order) -> Batch
+class MapColumnReorder : public TreePass {
+  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 91785d15c17..ed842c1c4e4 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1910,6 +1910,7 @@ class MapDataset(DatasetOp):
args["input_columns"] = self.input_columns
args["operations"] = self.operations
args["output_columns"] = self.output_columns
+ args["columns_order"] = self.columns_order
return args
def get_dataset_size(self):
@@ -3299,6 +3300,7 @@ class TFRecordDataset(SourceDataset):
args["num_samples"] = self.num_samples
if self.shuffle_files is not None:
args["shuffle_files"] = self.shuffle_files
+ args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
@@ -4607,6 +4609,7 @@ class CLUEDataset(SourceDataset):
args["num_samples"] = self.num_samples
if self.shuffle_files is not None:
args["shuffle_files"] = self.shuffle_files
+ args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
@@ -4697,6 +4700,7 @@ class TextFileDataset(SourceDataset):
args["num_samples"] = self.num_samples
if self.shuffle_files is not None:
args["shuffle_files"] = self.shuffle_files
+ args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py
index d8cd53982d1..89d8165b1e8 100644
--- a/mindspore/dataset/engine/iterators.py
+++ b/mindspore/dataset/engine/iterators.py
@@ -49,33 +49,13 @@ def alter_tree(node):
def _alter_node(node):
- """Performing some alteration to a dataset node. A common alteration is to insert a node."""
- if isinstance(node, (de.TFRecordDataset, de.TextFileDataset, de.CLUEDataset)) \
- and node.shuffle_level == de.Shuffle.GLOBAL:
- # Remove the connection between the parent's node to the current node because we are inserting a node.
- if node.output:
- node.output.pop()
- # Perform a fast scan for average rows per file
- if isinstance(node, de.TFRecordDataset):
- avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files)
- else:
- avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files)
-
- # Shuffle between 4 files with a minimum size of 10000 rows
- new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000))
- return new_shuffle
-
+ """DEPRECATED"""
+ # Please check ccsrc/dataset/engine/opt for tree transformation.
if isinstance(node, de.MapDataset):
if node.python_multiprocessing:
# Bootstrap can only be performed on a copy of the original dataset node.
# Bootstrap on original dataset node will make all iterators share the same process pool
node.iterator_bootstrap()
- if node.columns_order is not None:
- # Remove the connection between the parent's node to the current node because we are inserting a node.
- if node.output:
- node.output.pop()
-
- return node.project(node.columns_order)
return node
diff --git a/tests/ut/cpp/dataset/repeat_op_test.cc b/tests/ut/cpp/dataset/repeat_op_test.cc
index e32e98cbd79..42549546ba3 100644
--- a/tests/ut/cpp/dataset/repeat_op_test.cc
+++ b/tests/ut/cpp/dataset/repeat_op_test.cc
@@ -51,6 +51,7 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
ASSERT_NE(my_tfreader_op, nullptr);
parent_op->AddChild(std::move(my_tfreader_op));
MS_LOG(INFO) << parent_op;
+ my_tree->AssignRoot(parent_op);
my_tree->Prepare();
RepeatOp RepeatOpOp();
diff --git a/tests/ut/python/dataset/test_opt_pass.py b/tests/ut/python/dataset/test_opt_pass.py
new file mode 100644
index 00000000000..bab881e283b
--- /dev/null
+++ b/tests/ut/python/dataset/test_opt_pass.py
@@ -0,0 +1,90 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+
+import mindspore.dataset as ds
+
+
+def test_map_reorder_pass_0():
+    def generator_mc(maxid=1):
+        for _ in range(maxid):
+            yield (np.array([0]), np.array([1]))
+
+    # Generator -> Map
+    data0 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])
+
+    data0 = data0.map(input_columns="col0", output_columns="out", columns_order=["col1", "out"],
+                      operations=(lambda x: x))
+
+    for item in data0.create_tuple_iterator():  # each item is a list of ndarrays
+        assert item == [np.array(1), np.array(0)]
+
+
+def test_map_reorder_pass_1():
+    def generator_mc(maxid=1):
+        for _ in range(maxid):
+            yield (np.array([0]), np.array([1]), np.array([2]))
+
+    # Three maps and a zip
+    data0 = ds.GeneratorDataset(generator_mc, ["a0", "a1", "a2"])
+    data0 = data0.map(input_columns="a0", columns_order=["a2", "a1", "a0"], operations=(lambda x: x))
+    data1 = ds.GeneratorDataset(generator_mc, ["b0", "b1", "b2"])
+    data1 = data1.map(input_columns="b0", columns_order=["b1", "b2", "b0"], operations=(lambda x: x))
+    data2 = ds.zip((data0, data1))
+    data2 = data2.map(input_columns="a0", columns_order=["b2", "a2", "b1", "a1", "b0", "a0"], operations=(lambda x: x))
+
+    for item in data2.create_tuple_iterator():
+        assert item == [np.array(2), np.array(2), np.array(1), np.array(1), np.array(0), np.array(0)]
+
+
+def test_global_shuffle_pass():
+
+    FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
+    SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+
+    ds.config.set_seed(1)
+    data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+
+    ds.config.set_seed(1)
+    DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
+    data1 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+    ds.config.set_seed(1)
+    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
+    data1 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+
+if __name__ == "__main__":
+    test_map_reorder_pass_0()
+    test_map_reorder_pass_1()
+    test_global_shuffle_pass()
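For reference, any further pre-pass would follow the same recipe as MapColumnReorder and GlobalShufflePass in this patch: derive from TreePass, walk the ExecutionTree to find the ops of interest, graft new nodes with AssociateNode() plus InsertAsParent(), and register the pass in ExecutionTree::PrepareTreePreAction(). The skeleton below is an illustrative sketch only, not part of this patch: the pass name, log message, and logging include path are hypothetical.

#include "dataset/engine/opt/pass.h"
#include "dataset/engine/execution_tree.h"
#include "utils/log_adapter.h"  // assumed location of MS_LOG in this tree

namespace mindspore {
namespace dataset {
// Hypothetical example pass: visits every op in the tree and counts them without
// modifying anything. A real pass would match concrete op types with
// std::dynamic_pointer_cast (as GlobalShufflePass does) and insert nodes.
class TreeStatsPass : public TreePass {
  Status RunOnTree(ExecutionTree *tree, bool *modified) override {
    int32_t num_ops = 0;
    for (auto &op : *tree) {  // ExecutionTree is iterable over its DatasetOps
      (void)op;               // inspection only
      ++num_ops;
    }
    MS_LOG(INFO) << "Execution tree contains " << num_ops << " ops.";
    *modified = false;  // the tree was left untouched
    return Status::OK();
  }
};
}  // namespace dataset
}  // namespace mindspore

Hooking such a pass in would amount to one more pre_actions.push_back(...) next to MapColumnReorder and GlobalShufflePass in PrepareTreePreAction().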