diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index 25a8c34de76..ae73cc1bdf9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -237,8 +237,11 @@ Status CacheOp::Accept(NodePass *p, bool *const modified) { return p->RunOnNode(shared_from_base(), modified); } -// A public wrapper for creating the cache through the client -Status CacheOp::CreateCache(uint32_t cache_crc) { +Status CacheOp::PrepareNodePostAction() { + // Run any common code from the superclass first before adding our own + RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); + // Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache + uint32_t cache_crc = DatasetOp::GenerateCRC(shared_from_this()); // This is a non-mappable cache op so the id's need to be generated. // Construct the cache const bool generate_ids = true; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h index 26cce24496b..30d7ea6ab6d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h @@ -141,11 +141,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { bool AllowCacheMiss() override { return false; } /// \brief Base-class override for the name of this operator std::string Name() const override { return kCacheOp; } - /// \brief A public wrapper for creating the cache through the client - /// \param[in] cache_crc The crc that identifies the cache - /// \see cache_pass.cc - /// \return Status return code - Status CreateCache(uint32_t cache_crc); + Status PrepareNodePostAction() override; private: WaitPost rows_cache_done_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index b6e992c40d7..8287191d5ad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -33,11 +33,7 @@ namespace mindspore { namespace dataset { ClueOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_num_samples_(0), - builder_shuffle_files_(false), - builder_sampler_(nullptr) { + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -74,7 +70,7 @@ Status ClueOp::Builder::Build(std::shared_ptr *op) { std::shared_ptr clue_op = std::make_shared( builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, - builder_device_id_, std::move(builder_sampler_)); + builder_device_id_); RETURN_IF_NOT_OK(clue_op->Init()); *op = std::move(clue_op); @@ -94,8 +90,8 @@ std::vector ClueOp::Builder::split(const std::string &s, char delim ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files,
int32_t num_device, int32_t device_id, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + bool shuffle_files, int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), rows_per_buffer_(rows_per_buffer), num_rows_per_shard_(0), all_num_rows_(0), @@ -552,16 +548,6 @@ Status ClueOp::ComputeColMap() { return Status::OK(); } -// Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing -// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so -// that this clue op will produce the full set of data into the cache. -void ClueOp::MakeSimpleProducer() { - device_id_ = 0; - num_devices_ = 1; - shuffle_files_ = false; - num_samples_ = 0; -} - // Visitor accept method for NodePass Status ClueOp::Accept(NodePass *p, bool *const modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h index 048db62ad1d..a95b8a16e56 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -122,14 +122,6 @@ class ClueOp : public ParallelOp { // @return - the a string vector std::vector split(const std::string &s, char delim); - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - private: int32_t builder_device_id_; int32_t builder_num_devices_; @@ -141,13 +133,12 @@ class ClueOp : public ParallelOp { std::vector builder_clue_files_list_; bool builder_shuffle_files_; std::map builder_cols_to_keyword_; - std::shared_ptr builder_sampler_; }; // Constructor of ClueOp ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); + bool shuffle_files, int32_t num_devices, int32_t device_id); // Default destructor ~ClueOp() = default; @@ -182,11 +173,6 @@ class ClueOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return clue_files_list_; } - /// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing - /// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so - /// that this clue op will produce the full set of data into the cache. 
- void MakeSimpleProducer(); - // Op name getter // @return Name of the current Op std::string Name() const override { return "ClueOp"; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index ea8aa5a68ce..9938811a5bd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -29,11 +29,7 @@ namespace mindspore { namespace dataset { CsvOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_num_samples_(0), - builder_shuffle_files_(false), - builder_sampler_(nullptr) { + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -65,8 +61,7 @@ Status CsvOp::Builder::Build(std::shared_ptr *op) { std::shared_ptr csv_op = std::make_shared( builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_, builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, - builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_, - std::move(builder_sampler_)); + builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_); RETURN_IF_NOT_OK(csv_op->Init()); *op = std::move(csv_op); @@ -77,8 +72,8 @@ CsvOp::CsvOp(const std::vector &csv_files_list, char field_delim, const std::vector> &column_default, const std::vector &column_name, int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, - int32_t num_device, int32_t device_id, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), csv_files_list_(std::move(csv_files_list)), field_delim_(field_delim), column_default_list_(column_default), @@ -920,16 +915,6 @@ Status CsvOp::ComputeColMap() { return Status::OK(); } -// Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing -// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so -// that this csv op will produce the full set of data into the cache. -void CsvOp::MakeSimpleProducer() { - device_id_ = 0; - num_devices_ = 1; - shuffle_files_ = false; - num_samples_ = 0; -} - // Visitor accept method for NodePass Status CsvOp::Accept(NodePass *p, bool *const modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h index d8e7473ad82..891a0ca96fe 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -241,14 +241,6 @@ class CsvOp : public ParallelOp { return *this; } - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. 
- Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - private: int32_t builder_device_id_; int32_t builder_num_devices_; @@ -262,7 +254,6 @@ class CsvOp : public ParallelOp { char builder_field_delim_; std::vector> builder_column_default_list_; std::vector builder_column_name_list_; - std::shared_ptr builder_sampler_; }; // Constructor of CsvOp @@ -271,8 +262,7 @@ class CsvOp : public ParallelOp { CsvOp(const std::vector &csv_files_list, char field_delim, const std::vector> &column_default, const std::vector &column_name, int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, - int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, - std::shared_ptr sampler); + int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); // Default destructor ~CsvOp() = default; @@ -308,11 +298,6 @@ class CsvOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return csv_files_list_; } - /// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing - /// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so - /// that this csv op will produce the full set of data into the cache. - void MakeSimpleProducer(); - // Base-class override for NodePass visitor acceptor. // @param p - Pointer to the NodePass to be accepted. // @param modified - Whether this node visit modified the pipeline. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index 0ec94b5c893..bfccc97ff38 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -34,8 +34,7 @@ RandomDataOp::Builder::Builder() builder_num_workers_(0), builder_op_connector_size_(0), builder_rows_per_buffer_(0), - builder_total_rows_(0), - builder_sampler_(nullptr) { + builder_total_rows_(0) { // Some arguments to the RandomDataOp have a default argument that is taken from the config. // The user may override these defaults by using the builder set methods. 
std::shared_ptr cfg = GlobalContext::config_manager(); @@ -48,9 +47,8 @@ RandomDataOp::Builder::Builder() Status RandomDataOp::Builder::Build(std::shared_ptr *out_op) { RETURN_IF_NOT_OK(SanityCheck()); - *out_op = - std::make_shared(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, - builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_)); + *out_op = std::make_shared(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, + builder_total_rows_, std::move(builder_data_schema_)); return Status::OK(); } @@ -65,8 +63,8 @@ Status RandomDataOp::Builder::SanityCheck() const { // Constructor for RandomDataOp RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + std::unique_ptr data_schema) + : ParallelOp(num_workers, op_connector_size), buffer_id_(0), rows_per_buffer_(rows_per_buffer), total_rows_(total_rows), @@ -80,8 +78,7 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64 if (total_rows_ == 0) { total_rows_ = GenRandomInt(1, kMaxTotalRows); } - // If the user did not provide a schema, then we will ask the op to generate a pseudo-random - // schema. + // If the user did not provide a schema, then we will ask the op to generate a pseudo-random schema. // See details of generateSchema function to learn what type of schema it will create. if (data_schema_ == nullptr) { GenerateSchema(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h index 39fbb701137..bec8cb2f540 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h @@ -117,14 +117,6 @@ class RandomDataOp : public ParallelOp { return *this; } - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - private: /** * Check if the required parameters are set by the builder. @@ -133,7 +125,6 @@ class RandomDataOp : public ParallelOp { Status SanityCheck() const; std::unique_ptr builder_data_schema_; - std::shared_ptr builder_sampler_; int32_t builder_num_workers_; int32_t builder_op_connector_size_; int64_t builder_rows_per_buffer_; @@ -148,11 +139,10 @@ class RandomDataOp : public ParallelOp { * @param rows_per_buffer - The number of rows in each DataBuffer * @param data_schema - A user-provided schema * @param total_rows - The total number of rows in the dataset - * @param sampler - allow a sampler. 
Only valid if a cache exists in ascendent tree nodes * @return Builder - The modified builder by reference */ RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema, std::shared_ptr sampler); + std::unique_ptr data_schema); /** * Destructor diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 9c7443149dd..5ddc4732ec2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -34,11 +34,7 @@ namespace mindspore { namespace dataset { TextFileOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_total_rows_(0), - builder_shuffle_files_(false), - builder_sampler_(nullptr) { + : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -74,7 +70,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { std::shared_ptr text_file_op = std::make_shared( builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, - builder_num_devices_, builder_device_id_, std::move(builder_sampler_)); + builder_num_devices_, builder_device_id_); RETURN_IF_NOT_OK(text_file_op->Init()); *op = std::move(text_file_op); @@ -83,9 +79,8 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr schema, std::vector text_files_list, - int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, - std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), device_id_(device_id), num_devices_(num_device), rows_per_buffer_(rows_per_buffer), @@ -504,16 +499,6 @@ Status TextFileOp::ComputeColMap() { return Status::OK(); } -// Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing -// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so -// that this text file op will produce the full set of data into the cache. 
-void TextFileOp::MakeSimpleProducer() { - device_id_ = 0; - num_devices_ = 1; - shuffle_files_ = false; - total_rows_ = 0; -} - // Visitor accept method for NodePass Status TextFileOp::Accept(NodePass *p, bool *const modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h index b60ac409009..f253a4e3acd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -112,14 +112,6 @@ class TextFileOp : public ParallelOp { return *this; } - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - private: int32_t builder_device_id_; int32_t builder_num_devices_; @@ -131,7 +123,6 @@ class TextFileOp : public ParallelOp { std::vector builder_text_files_list_; bool builder_shuffle_files_; std::unique_ptr builder_schema_; - std::shared_ptr builder_sampler_; }; // Constructor of TextFileOp @@ -145,10 +136,9 @@ class TextFileOp : public ParallelOp { // @param columns_to_load - the names of the columns to load data from. // @param shuffle_files - whether or not to shuffle the files before reading data. // @param equal_rows_per_shard - whether or not to get equal rows for each process. - // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); + bool shuffle_files, int32_t num_devices, int32_t device_id); // Default destructor ~TextFileOp() = default; @@ -187,11 +177,6 @@ class TextFileOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return text_files_list_; } - /// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing - /// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so - /// that this text file op will produce the full set of data into the cache. - void MakeSimpleProducer(); - // Base-class override for NodePass visitor acceptor. // @param p - Pointer to the NodePass to be accepted. // @param modified - Whether this node visit modified the pipeline. 
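The hunks above delete MakeSimpleProducer() from the runtime readers (ClueOp, CsvOp, TextFileOp); later hunks in this patch re-introduce the same idea as a pure-virtual pair on the IR-side NonMappableSourceNode (SetupSamplerForCache / MakeSimpleProducer). A minimal standalone sketch of the pattern follows; the class names are toy stand-ins rather than MindSpore code, and the real interface returns a Status rather than void:

#include <cstdint>

// Toy model of the pattern: when a cache sits above a non-mappable leaf,
// the leaf hands its sampling duties to the cache and must instead produce
// the full dataset exactly once.
class NonMappableSourceNode {
 public:
  virtual ~NonMappableSourceNode() = default;
  // Reset sharding/shuffling/sampling options to their defaults so the leaf
  // feeds the complete dataset into the cache.
  virtual void MakeSimpleProducer() = 0;
};

class TextFileLikeNode : public NonMappableSourceNode {
 public:
  void MakeSimpleProducer() override {
    shard_id_ = 0;           // no sharding: single shard, id 0
    num_shards_ = 1;
    shuffle_files_ = false;  // any shuffling now happens above the cache
    num_samples_ = 0;        // 0 means "all rows" in this convention
  }

 private:
  int32_t shard_id_ = 0;
  int32_t num_shards_ = 1;
  bool shuffle_files_ = false;
  int64_t num_samples_ = 0;
};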
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index d551a2591a4..4b9db81fedf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -44,11 +44,7 @@ namespace mindspore { namespace dataset { TFReaderOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_total_rows_(0), - builder_equal_rows_per_shard_(false), - builder_sampler_(nullptr) { + : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_worker_connector_size_ = config_manager->worker_connector_size(); @@ -122,8 +118,7 @@ Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) std::shared_ptr new_tf_reader_op = std::make_shared( builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, - builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_, - std::move(builder_sampler_)); + builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_); RETURN_IF_NOT_OK(new_tf_reader_op->Init()); *out_tf_reader_op = std::move(new_tf_reader_op); @@ -134,8 +129,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, int32_t num_device, - int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + int32_t device_id, bool equal_rows_per_shard) + : ParallelOp(num_workers, op_connector_size), device_id_(device_id), num_devices_(num_device), rows_per_buffer_(rows_per_buffer), @@ -1043,17 +1038,6 @@ Status TFReaderOp::ComputeColMap() { return Status::OK(); } -// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing -// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so -// that this tf reader will produce the full set of data into the cache. -void TFReaderOp::MakeSimpleProducer() { - device_id_ = 0; - num_devices_ = 1; - total_rows_ = 0; - shuffle_files_ = false; - equal_rows_per_shard_ = false; -} - // During tree prepare phase, operators may have specific post-operations to perform depending on // their role. Status TFReaderOp::PrepareNodePostAction() { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h index 6fdabf069e3..cf038a392a7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h @@ -153,17 +153,8 @@ class TFReaderOp : public ParallelOp { return *this; } - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. 
- Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - private: std::unique_ptr builder_data_schema_; - std::shared_ptr builder_sampler_; int32_t builder_device_id_; int32_t builder_num_devices_; int32_t builder_num_workers_; @@ -189,11 +180,10 @@ class TFReaderOp : public ParallelOp { // @param columns_to_load - the names of the columns to load data from. // @param shuffle_files - whether or not to shuffle the files before reading data. // @param equal_rows_per_shard - whether or not to get equal rows for each process. - // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, - int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler); + int32_t num_devices, int32_t device_id, bool equal_rows_per_shard); // Default destructor ~TFReaderOp() = default; @@ -246,11 +236,6 @@ class TFReaderOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return dataset_files_list_; } - /// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing - /// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so - /// that this tf reader will produce the full set of data into the cache. - void MakeSimpleProducer(); - // During tree prepare phase, operators may have specific post-operations to perform depending on // their role. // @notes Derived versions of this function should always call it's superclass version first @@ -387,7 +372,7 @@ class TFReaderOp : public ParallelOp { bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, const int64_t &pre_count); - // Caculate number of rows in each shard. + // Calculate number of rows in each shard. // @return Status - the error code returned. Status CalculateNumRowsPerShard(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc index 2aff6dcec47..d10d70a2246 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc @@ -320,7 +320,6 @@ Status ExecutionTree::PostAction() { // The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API. // This is because Python API binding to TensorOperation is still in progress. 
post_actions.push_back(std::make_unique()); - post_actions.push_back(std::make_unique()); post_actions.push_back(std::make_unique()); #endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h index c78560b05c2..f57673a9b27 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h @@ -19,6 +19,7 @@ #include #include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/include/samplers.h" #include "minddata/dataset/util/status.h" namespace mindspore::dataset { @@ -29,6 +30,9 @@ class DatasetCache { virtual Status ValidateParams() = 0; virtual Status CreateCacheOp(int num_workers, std::shared_ptr *ds_op) = 0; virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } + virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, + std::shared_ptr sampler) = 0; + virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) = 0; }; } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc index 648744eb5f6..ac9dc4c0707 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc @@ -16,6 +16,8 @@ #include #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h" namespace mindspore { @@ -44,5 +46,28 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr *ds, + std::shared_ptr sampler) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); + std::shared_ptr lookup_op = nullptr; + RETURN_IF_NOT_OK(CacheLookupOp::Builder() + .SetNumWorkers(num_workers) + .SetClient(cache_client_) + .SetSampler(sampler->SamplerBuild()) + .Build(&lookup_op)); + *ds = lookup_op; + + return Status::OK(); +} + +Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); + std::shared_ptr merge_op = nullptr; + RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op)); + *ds = merge_op; + + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h index 12c9435e392..f1a0e0b768f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h @@ -56,6 +56,11 @@ class DatasetCacheImpl : public DatasetCache { Status CreateCacheOp(int32_t num_workers, std::shared_ptr *ds) override; + Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, + std::shared_ptr sampler) override; + + Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) override; + Status ValidateParams() override { return Status::OK(); } ~DatasetCacheImpl() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc 
b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc index aee0c0b1ab8..4ab5f19ad7c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc @@ -16,6 +16,8 @@ #include #include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h" +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/cache_op.h" namespace mindspore { @@ -46,5 +48,29 @@ Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) { *out_json = args; return Status::OK(); } + +Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, + std::shared_ptr sampler) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); + std::shared_ptr lookup_op = nullptr; + RETURN_IF_NOT_OK(CacheLookupOp::Builder() + .SetNumWorkers(num_workers) + .SetClient(cache_client_) + .SetSampler(sampler->SamplerBuild()) + .Build(&lookup_op)); + *ds = lookup_op; + + return Status::OK(); +} + +Status PreBuiltDatasetCache::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); + std::shared_ptr merge_op = nullptr; + RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op)); + *ds = merge_op; + + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h index fa11d9ee87a..d1588ce19c7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h @@ -40,6 +40,11 @@ class PreBuiltDatasetCache : public DatasetCache { Status CreateCacheOp(int32_t num_workers, std::shared_ptr *const ds) override; + Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, + std::shared_ptr sampler) override; + + Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) override; + Status ValidateParams() override { return Status::OK(); } Status to_json(nlohmann::json *out_json) override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt index 3974ab2a1b6..29018607b56 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt @@ -8,6 +8,9 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES bucket_batch_by_length_node.cc build_sentence_piece_vocab_node.cc build_vocab_node.cc + cache_lookup_node.cc + cache_merge_node.cc + cache_node.cc concat_node.cc epoch_ctrl_node.cc filter_node.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc new file mode 100644 index 00000000000..63217c381dc --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h" + +#include +#include +#include +#include + +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +CacheLookupNode::CacheLookupNode(std::shared_ptr child, std::shared_ptr sampler, + std::shared_ptr cache) + : DatasetNode(std::move(cache)), sampler_(sampler), lookup_op_(nullptr), lookup_node_copy_(nullptr) { + this->AddChild(child); +} + +void CacheLookupNode::Print(std::ostream &out) const { out << Name(); } + +std::shared_ptr CacheLookupNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); + auto node = std::make_shared(nullptr, sampler, cache_); + lookup_node_copy_ = node; + return node; +} + +Status CacheLookupNode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetSampler("CacheLookupNode", sampler_)); + return Status::OK(); +} + +Status CacheLookupNode::Build(std::vector> *node_ops) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr, + "Internal error. Attempt to create a cache lookup node without cache client."); + RETURN_IF_NOT_OK(cache_->Build()); + RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_)); + node_ops->push_back(lookup_op_); + return Status::OK(); +} + +std::shared_ptr CacheLookupNode::SamplerCopy() { + // CacheLookupNode should have already been copied, so we just return it here + return std::static_pointer_cast(lookup_node_copy_); +} + +std::shared_ptr CacheLookupNode::SamplerBuild() { + // The runtime cache lookup op should have already been built, so we just return it here + auto lookup_op = std::dynamic_pointer_cast(lookup_op_); + return std::shared_ptr(lookup_op); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h new file mode 100644 index 00000000000..04a510d34d5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.h @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { + +class CacheLookupNode : public DatasetNode, public SamplerObj { + public: + /// \brief Constructor + CacheLookupNode(std::shared_ptr child, std::shared_ptr sampler, + std::shared_ptr cache); + + /// \brief Destructor + ~CacheLookupNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCacheLookupNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + + /// \brief a base class override function to convert a SamplerObj class into a runtime sampler object + /// \return Shared pointer to the newly created Sampler + std::shared_ptr SamplerBuild() override; + + /// \brief a base class override function to copy a SamplerObj class + /// \return Shared pointer to the newly copied SamplerObj + std::shared_ptr SamplerCopy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create + /// \return Status Status::OK() if build successfully + Status Build(std::vector> *node_ops) override; + + /// \brief Parameters validation + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; + + private: + std::shared_ptr sampler_; + std::shared_ptr lookup_op_; + std::shared_ptr lookup_node_copy_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc new file mode 100644 index 00000000000..66faa998907 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h" + +#include +#include +#include +#include + +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +CacheMergeNode::CacheMergeNode(std::shared_ptr child, std::shared_ptr cache) + : DatasetNode(std::move(cache)) { + nary_op_ = true; + this->AddChild(child); +} + +void CacheMergeNode::Print(std::ostream &out) const { out << Name(); } + +std::shared_ptr CacheMergeNode::Copy() { + auto node = std::make_shared(nullptr, cache_); + return node; +} + +Status CacheMergeNode::ValidateParams() { return Status::OK(); } + +Status CacheMergeNode::Build(std::vector> *node_ops) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr, + "Internal error. Attempt to create a cache merge node without cache client."); + RETURN_IF_NOT_OK(cache_->Build()); + std::shared_ptr merge_op = nullptr; + RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op)); + node_ops->push_back(merge_op); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.h new file mode 100644 index 00000000000..0afcbc19228 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.h @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { + +class CacheMergeNode : public DatasetNode { + public: + /// \brief Constructor + CacheMergeNode(std::shared_ptr child, std::shared_ptr cache); + + /// \brief Destructor + ~CacheMergeNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCacheMergeNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create + /// \return Status Status::OK() if build successfully + Status Build(std::vector> *node_ops) override; + + /// \brief Parameters validation + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc new file mode 100644 index 00000000000..38edbeb4680 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/cache_node.h" + +#include +#include +#include +#include + +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +CacheNode::CacheNode(std::shared_ptr child, std::shared_ptr sampler, + std::shared_ptr cache) + : DatasetNode(std::move(cache)), sampler_(sampler) { + this->AddChild(child); +} + +void CacheNode::Print(std::ostream &out) const { out << Name(); } + +std::shared_ptr CacheNode::Copy() { + std::shared_ptr sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); + auto node = std::make_shared(nullptr, sampler, cache_); + return node; +} + +Status CacheNode::ValidateParams() { + RETURN_IF_NOT_OK(ValidateDatasetSampler("CacheNode", sampler_)); + return Status::OK(); +} + +Status CacheNode::Build(std::vector> *node_ops) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr, + "Internal error. 
Attempt to create a cache node without cache client."); + RETURN_IF_NOT_OK(cache_->Build()); + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); + cache_op->SetSampler(sampler_->SamplerBuild()); + node_ops->push_back(cache_op); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.h new file mode 100644 index 00000000000..25d969a23ff --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.h @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { + +class CacheNode : public DatasetNode { + public: + /// \brief Constructor + CacheNode(std::shared_ptr child, std::shared_ptr sampler, + std::shared_ptr cache); + + /// \brief Destructor + ~CacheNode() = default; + + /// \brief Node name getter + /// \return Name of the current node + std::string Name() const override { return kCacheNode; } + + /// \brief Print the description + /// \param out - The output stream to write output to + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object + /// \return A shared pointer to the new copy + std::shared_ptr Copy() override; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create + /// \return Status Status::OK() if build successfully + Status Build(std::vector> *node_ops) override; + + /// \brief Parameters validation + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; + + private: + std::shared_ptr sampler_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 2315ca2efc0..174a9958592 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -204,15 +204,6 @@ std::shared_ptr SelectSampler(int64_t num_samples, bool shuffle, int return SequentialSampler(0, num_samples); } -Status DatasetNode::AddCacheOp(std::vector> *node_ops) { - if (cache_ != nullptr) { - RETURN_IF_NOT_OK(cache_->Build()); - std::shared_ptr cache_op; - RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); - 
node_ops->push_back(cache_op); - } - return Status::OK(); -} // Constructor to initialize the cache DatasetNode::DatasetNode(const std::shared_ptr &dataset_cache) : DatasetNode() { cache_ = dataset_cache; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 1ac53cd87fb..b96f89523a3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -53,6 +53,9 @@ constexpr char kBatchNode[] = "Batch"; constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength"; constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab"; constexpr char kBuildVocabNode[] = "BuildVocab"; +constexpr char kCacheLookupNode[] = "CacheLookup"; +constexpr char kCacheMergeNode[] = "CacheMerge"; +constexpr char kCacheNode[] = "Cache"; constexpr char kConcatNode[] = "Concat"; constexpr char kEpochCtrlNode[] = "EpochCtrl"; constexpr char kFilterNode[] = "Filter"; @@ -248,6 +251,9 @@ class DatasetNode : public std::enable_shared_from_this { /// \brief Getter of the number of workers int32_t num_workers() { return num_workers_; } + /// \brief Getter of dataset cache + std::shared_ptr GetDatasetCache() { return cache_; } + /// \brief Setter function for runtime number of workers /// \param[in] num_workers The number of threads in this operator /// \return Shared pointer to the original object @@ -299,7 +305,6 @@ class DatasetNode : public std::enable_shared_from_this { // Used only in the constructor of the class and its derived classes. void AddChild(std::shared_ptr child); std::string PrintColumns(const std::vector &columns) const; - Status AddCacheOp(std::vector> *node_ops); void PrintNode(std::ostream &out, int *level) const; enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 }; enum DataSource mappable_; @@ -360,6 +365,20 @@ class NonMappableSourceNode : public DatasetNode { /// \brief Node name getter /// \return Name of the current node virtual std::string Name() const = 0; + + /// \brief By default, a non-mappable dataset does not support sampling. However, if a cache operator + /// is injected at some other place higher in the tree, that cache can inherit this sampler + /// from the leaf, providing sampling support from the caching layer. + /// This function sets up the sampler for a leaf node that does not use sampling. + /// \param[in] sampler The sampler to set up + /// \return Status of the function + virtual Status SetupSamplerForCache(std::shared_ptr *sampler) = 0; + + /// \brief If a cache has been added into the ascendant tree over this non-mappable source node, then the cache will + /// be executing a sampler for fetching the data. As such, any options in the source node need to be reset to their + /// defaults so that this source node will produce the full set of data into the cache.
+ /// \return Status of the function + virtual Status MakeSimpleProducer() = 0; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index 90acb95463e..d325f9d0998 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -76,7 +76,6 @@ Status MapNode::Build(std::vector> *const node_ops) { auto project_op = std::make_shared(project_columns_); node_ops->push_back(project_op); } - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(map_op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc index 8cfbb9ba598..1e2d20428de 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc @@ -72,8 +72,6 @@ Status AlbumNode::Build(std::vector> *const node_ops) // Argument that is not exposed to user in the API. std::set extensions = {}; - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); - node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, extensions, std::move(schema), std::move(sampler_->SamplerBuild()))); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index 87668c7490a..386afe33ee9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -67,8 +67,6 @@ Status CelebANode::Build(std::vector> *const node_ops // label is like this:0 1 0 0 1...... 
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); - node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_, extensions_, std::move(schema), std::move(sampler_->SamplerBuild()))); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index 5eced16efa9..4e78fe5dcd9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -64,8 +64,6 @@ Status Cifar100Node::Build(std::vector> *const node_o RETURN_IF_NOT_OK( schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); - node_ops->push_back(std::make_shared(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild()))); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 6d99c6c79f3..7d188331c83 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -62,8 +62,6 @@ Status Cifar10Node::Build(std::vector> *const node_op RETURN_IF_NOT_OK( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); - node_ops->push_back(std::make_shared(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild()))); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index ea3a2f25649..f37b62e6565 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -83,84 +83,66 @@ std::vector CLUENode::split(const std::string &s, char delim) { return res; } -// Function to build CLUENode -Status CLUENode::Build(std::vector> *const node_ops) { +std::map CLUENode::CreateKeyMapForBuild() { std::map key_map; if (task_ == "AFQMC") { - if (usage_ == "train") { + if (usage_ == "train" || usage_ == "eval") { key_map["sentence1"] = "sentence1"; key_map["sentence2"] = "sentence2"; key_map["label"] = "label"; - } else if (usage_ == "test") { + } else { // usage_ == "test" key_map["id"] = "id"; key_map["sentence1"] = "sentence1"; key_map["sentence2"] = "sentence2"; - } else if (usage_ == "eval") { - key_map["sentence1"] = "sentence1"; - key_map["sentence2"] = "sentence2"; - key_map["label"] = "label"; } - } else if (task_ == "CMNLI") { - if (usage_ == "train") { + } + if (task_ == "CMNLI") { + if (usage_ == "train" || usage_ == "eval") { key_map["sentence1"] = "sentence1"; key_map["sentence2"] = "sentence2"; key_map["label"] = "label"; - } else if (usage_ == "test") { + } else { // usage_ == "test" key_map["id"] = "id"; key_map["sentence1"] = "sentence1"; key_map["sentence2"] = "sentence2"; - } else if (usage_ == "eval") { - key_map["sentence1"] 
= "sentence1"; - key_map["sentence2"] = "sentence2"; - key_map["label"] = "label"; } - } else if (task_ == "CSL") { - if (usage_ == "train") { + } + if (task_ == "CSL") { + if (usage_ == "train" || usage_ == "eval") { key_map["id"] = "id"; key_map["abst"] = "abst"; key_map["keyword"] = "keyword"; key_map["label"] = "label"; - } else if (usage_ == "test") { + } else { // usage_ == "test" key_map["id"] = "id"; key_map["abst"] = "abst"; key_map["keyword"] = "keyword"; - } else if (usage_ == "eval") { - key_map["id"] = "id"; - key_map["abst"] = "abst"; - key_map["keyword"] = "keyword"; - key_map["label"] = "label"; } - } else if (task_ == "IFLYTEK") { - if (usage_ == "train") { + } + if (task_ == "IFLYTEK") { + if (usage_ == "train" || usage_ == "eval") { key_map["label"] = "label"; key_map["label_des"] = "label_des"; key_map["sentence"] = "sentence"; - } else if (usage_ == "test") { + } else { // usage_ == "test" key_map["id"] = "id"; key_map["sentence"] = "sentence"; - } else if (usage_ == "eval") { - key_map["label"] = "label"; - key_map["label_des"] = "label_des"; - key_map["sentence"] = "sentence"; } - } else if (task_ == "TNEWS") { - if (usage_ == "train") { + } + if (task_ == "TNEWS") { + if (usage_ == "train" || usage_ == "eval") { key_map["label"] = "label"; key_map["label_desc"] = "label_desc"; key_map["sentence"] = "sentence"; key_map["keywords"] = "keywords"; - } else if (usage_ == "test") { + } else { // usage_ == "test" key_map["id"] = "id"; key_map["sentence"] = "sentence"; key_map["keywords"] = "keywords"; - } else if (usage_ == "eval") { - key_map["label"] = "label"; - key_map["label_desc"] = "label_desc"; - key_map["sentence"] = "sentence"; - key_map["keywords"] = "keywords"; } - } else if (task_ == "WSC") { - if (usage_ == "train") { + } + if (task_ == "WSC") { + if (usage_ == "train" || usage_ == "eval") { key_map["span1_index"] = "target/span1_index"; key_map["span2_index"] = "target/span2_index"; key_map["span1_text"] = "target/span1_text"; @@ -168,24 +150,21 @@ Status CLUENode::Build(std::vector> *const node_ops) key_map["idx"] = "idx"; key_map["label"] = "label"; key_map["text"] = "text"; - } else if (usage_ == "test") { + } else { // usage_ == "test" key_map["span1_index"] = "target/span1_index"; key_map["span2_index"] = "target/span2_index"; key_map["span1_text"] = "target/span1_text"; key_map["span2_text"] = "target/span2_text"; key_map["idx"] = "idx"; key_map["text"] = "text"; - } else if (usage_ == "eval") { - key_map["span1_index"] = "target/span1_index"; - key_map["span2_index"] = "target/span2_index"; - key_map["span1_text"] = "target/span1_text"; - key_map["span2_text"] = "target/span2_text"; - key_map["idx"] = "idx"; - key_map["label"] = "label"; - key_map["text"] = "text"; } } + return key_map; +} +// Function to build CLUENode +Status CLUENode::Build(std::vector> *const node_ops) { + auto key_map = CreateKeyMapForBuild(); ColKeyMap ck_map; for (auto &p : key_map) { ck_map.insert({p.first, split(p.second, '/')}); @@ -193,19 +172,13 @@ Status CLUENode::Build(std::vector> *const node_ops) bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - // ClueOp by itself is a non-mappable dataset that does not support sampling. - // However, if a cache operator is injected at some other place higher in the tree, that cache can - // inherit this sampler from the leaf, providing sampling support from the caching layer. - // That is why we save the sampler here in a leaf node that does not use sampling. 
- std::shared_ptr sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - // Sort the dataset files in a lexicographical order std::vector sorted_dataset_files = dataset_files_; std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); - std::shared_ptr clue_op = std::make_shared( - num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files, - connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); + std::shared_ptr clue_op = + std::make_shared(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, + sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_); RETURN_IF_NOT_OK(clue_op->Init()); @@ -222,7 +195,6 @@ Status CLUENode::Build(std::vector> *const node_ops) rows_per_buffer_, &shuffle_op)); node_ops->push_back(shuffle_op); } - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(clue_op); @@ -270,5 +242,27 @@ Status CLUENode::to_json(nlohmann::json *out_json) { *out_json = args; return Status::OK(); } + +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. +// CLUE by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, that cache can +// inherit this sampler from the leaf, providing sampling support from the caching layer. +// That is why we setup the sampler for a leaf node that does not use sampling. +Status CLUENode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this clue node, then the cache will be executing +// a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so +// that this clue node will produce the full set of data into the cache. 
+Status CLUENode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h index cffe8d026ce..033e2518136 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_ +#include #include #include #include @@ -49,6 +50,10 @@ class CLUENode : public NonMappableSourceNode { /// \return A shared pointer to the new copy std::shared_ptr Copy() override; + /// \brief Generate a key map to be used in Build() according to usage and task + /// \return The generated key map + std::map CreateKeyMapForBuild(); + /// \brief a base class override function to create the required runtime dataset op objects for this class /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create /// \return Status Status::OK() if build successfully @@ -85,6 +90,22 @@ class CLUENode : public NonMappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief CLUE by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \param[in] sampler The sampler to setup + /// \return Status of the function + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this clue node, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so + /// that this clue node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 
+ /// \return Status of the function + Status MakeSimpleProducer() override; + private: /// \brief Split string based on a character delimiter /// \return A string vector diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index bde444f1f2f..309619737d4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -122,7 +122,6 @@ Status CocoNode::Build(std::vector> *const node_ops) std::shared_ptr op = std::make_shared(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(op); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index d2ac94b958b..b9cc83e2c2a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -95,12 +95,6 @@ Status CSVNode::ValidateParams() { Status CSVNode::Build(std::vector> *const node_ops) { bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - // CSVOp by itself is a non-mappable dataset that does not support sampling. - // However, if a cache operator is injected at some other place higher in the tree, that cache can - // inherit this sampler from the leaf, providing sampling support from the caching layer. - // That is why we save the sampler here in a leaf node that does not use sampling. - std::shared_ptr sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - // Sort the dataset files in a lexicographical order std::vector sorted_dataset_files = dataset_files_; std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); @@ -119,10 +113,9 @@ Status CSVNode::Build(std::vector> *const node_ops) { } } - std::shared_ptr csv_op = - std::make_shared(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, - rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, - num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); + std::shared_ptr csv_op = std::make_shared( + sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, + num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_); RETURN_IF_NOT_OK(csv_op->Init()); @@ -140,7 +133,6 @@ Status CSVNode::Build(std::vector> *const node_ops) { node_ops->push_back(shuffle_op); } - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(csv_op); @@ -188,5 +180,27 @@ Status CSVNode::to_json(nlohmann::json *out_json) { *out_json = args; return Status::OK(); } + +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. +// CSV by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, that cache can +// inherit this sampler from the leaf, providing sampling support from the caching layer. +// That is why we setup the sampler for a leaf node that does not use sampling. 
+Status CSVNode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this CSV node, then the cache will be executing +// a sampler for fetching the data. As such, any options in the CSV node need to be reset to its defaults so +// that this CSV node will produce the full set of data into the cache. +Status CSVNode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h index 03d45393661..2c774991631 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h @@ -107,6 +107,22 @@ class CSVNode : public NonMappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief CSV by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \param[in] sampler The sampler to setup + /// \return Status of the function + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this CSV node, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the CSV node need to be reset to its defaults so + /// that this CSV node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 
+ /// \return Status of the function + Status MakeSimpleProducer() override; + private: std::vector dataset_files_; char field_delim_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h index 7578c7302f0..0054cf972da 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h @@ -95,10 +95,10 @@ class GeneratorNode : public MappableSourceNode { /// \brief Sampler getter /// \return SamplerObj of the current node - std::shared_ptr Sampler() override { return nullptr; } + std::shared_ptr Sampler() override { return sampler_; } /// \brief Sampler setter - void SetSampler(std::shared_ptr sampler) override {} + void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } private: py::function generator_function_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index 9bfbf852bfb..b2c852764fc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -70,8 +70,6 @@ Status ImageFolderNode::Build(std::vector> *const nod RETURN_IF_NOT_OK( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); - node_ops->push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, recursive_, decode_, exts_, class_indexing_, std::move(schema), std::move(sampler_->SamplerBuild()))); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index bd8f0622a75..16077144b37 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -94,7 +94,6 @@ Status ManifestNode::Build(std::vector> *const node_o manifest_op = std::make_shared(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(manifest_op); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index b76155ec1ee..804abb07630 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -23,8 +23,9 @@ #include #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" - +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" + namespace mindspore { namespace dataset { @@ -203,5 +204,16 @@ Status MindDataNode::GetDatasetSize(const std::shared_ptr &si return Status::OK(); } +// Visitor accepting method for IRNodePass +Status MindDataNode::Accept(IRNodePass *const p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status MindDataNode::AcceptAfter(IRNodePass *p, bool 
*const modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h index abec37ed34a..1434137a25a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h @@ -92,6 +92,18 @@ class MindDataNode : public MappableSourceNode { /// \brief Sampler setter void SetSampler(std::shared_ptr sampler) override { sampler_ = sampler; } + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *p, bool *const modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *p, bool *const modified) override; + private: std::string dataset_file_; // search_for_pattern_ will be true in this mode std::vector dataset_files_; // search_for_pattern_ will be false in this mode diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index 2c32a7fe555..5a3efcfba14 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -57,7 +57,6 @@ Status MnistNode::Build(std::vector> *const node_ops) TensorShape scalar = TensorShape::CreateScalar(); RETURN_IF_NOT_OK( schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(std::make_shared(usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, std::move(schema), diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index af24433064f..4c6cf420c7c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -22,6 +22,7 @@ #include #include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/random.h" #include "minddata/dataset/util/status.h" @@ -105,17 +106,9 @@ Status RandomNode::Build(std::vector> *const node_ops } } - // RandomOp by itself is a non-mappable dataset that does not support sampling. - // However, if a cache operator is injected at some other place higher in the tree, that cache can - // inherit this sampler from the leaf, providing sampling support from the caching layer. - // That is why we save the sampler here in a leaf node that does not use sampling. - // RandomOp doesn't support sampler, should not support sharding, select sampler should just be sequential. 
- std::shared_ptr sampler_ = SelectSampler(total_rows_, false, 1, 0); - std::shared_ptr op; op = std::make_shared(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, - std::move(data_schema_), std::move(sampler_->SamplerBuild())); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); + std::move(data_schema_)); node_ops->push_back(op); @@ -142,5 +135,27 @@ Status RandomNode::GetDatasetSize(const std::shared_ptr &size dataset_size_ = *dataset_size; return Status::OK(); } + +// RandomDataset by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, that cache can +// inherit this sampler from the leaf, providing sampling support from the caching layer. +// That is why we setup the sampler for a leaf node that does not use sampling. +Status RandomNode::SetupSamplerForCache(std::shared_ptr *sampler) { + // RandomOp doesn't support sampler, should not support sharding, select sampler should just be sequential. + *sampler = SelectSampler(total_rows_, false, 1, 0); + return Status::OK(); +} + +// Visitor accepting method for IRNodePass +Status RandomNode::Accept(IRNodePass *p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->Visit(shared_from_base(), modified); +} + +// Visitor accepting method for IRNodePass +Status RandomNode::AcceptAfter(IRNodePass *p, bool *const modified) { + // Downcast shared pointer then call visitor + return p->VisitAfter(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h index 3bfbc4ac156..eeae3cd739a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h @@ -99,6 +99,30 @@ class RandomNode : public NonMappableSourceNode { const std::mt19937 &RandGen() const { return rand_gen_; } const std::unique_ptr &GetDataSchema() const { return data_schema_; } + /// \brief RandomDataset by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// \param[in] sampler The sampler to setup + /// \return Status of the function + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief Random node will always produce the full set of data into the cache + /// \return Status of the function + Status MakeSimpleProducer() override { return Status::OK(); } + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(IRNodePass *p, bool *const modified) override; + + /// \brief Base-class override for accepting IRNodePass visitor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status AcceptAfter(IRNodePass *p, bool *const modified) override; + private: /// \brief A quick inline for producing a random number between (and including) min/max /// \param[in] min minimum number that can be generated. 
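As the Note comments in these hunks repeat, SetupSamplerForCache and MakeSimpleProducer are implemented identically in CLUENode, CSVNode, TextFileNode, and TFRecordNode. A minimal sketch of what the promoted defaults could look like, assuming the shard/shuffle/sample members (shard_id_, num_shards_, shuffle_, num_samples_) were hoisted into NonMappableSourceNode; they currently live in each derived node, so that hoisting is an assumption of this sketch, not part of the patch:

// Sketch only: assumes NonMappableSourceNode owns shard_id_, num_shards_, shuffle_ and num_samples_.
Status NonMappableSourceNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  // Mirror the per-node implementations: derive file shuffling from the shuffle mode,
  // then let SelectSampler build a sampler the downstream cache can execute.
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
  return Status::OK();
}

Status NonMappableSourceNode::MakeSimpleProducer() {
  // Reset sharding/shuffling/sampling so the leaf feeds the full dataset into the cache.
  // Nodes with extra state (e.g. TFRecordNode's shard_equal_rows_) would still override this.
  shard_id_ = 0;
  num_shards_ = 1;
  shuffle_ = ShuffleMode::kFalse;
  num_samples_ = 0;
  return Status::OK();
}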
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index e89adac546b..e836f88c7c5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -73,12 +73,6 @@ Status TextFileNode::ValidateParams() { Status TextFileNode::Build(std::vector> *const node_ops) { bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - // TextFileOp by itself is a non-mappable dataset that does not support sampling. - // However, if a cache operator is injected at some other place higher in the tree, that cache can - // inherit this sampler from the leaf, providing sampling support from the caching layer. - // That is why we save the sampler here in a leaf node that does not use sampling. - std::shared_ptr sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - // Sort the dataset files in a lexicographical order std::vector sorted_dataset_files = dataset_files_; std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); @@ -87,10 +81,10 @@ Status TextFileNode::Build(std::vector> *const node_o auto schema = std::make_unique(); RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - // Create and initalize TextFileOp + // Create and initialize TextFileOp std::shared_ptr text_file_op = std::make_shared( num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, - connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); + connector_que_size_, shuffle_files, num_shards_, shard_id_); RETURN_IF_NOT_OK(text_file_op->Init()); if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { @@ -106,7 +100,6 @@ Status TextFileNode::Build(std::vector> *const node_o rows_per_buffer_, &shuffle_op)); node_ops->push_back(shuffle_op); } - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); // Add TextFileOp node_ops->push_back(text_file_op); @@ -152,5 +145,27 @@ Status TextFileNode::to_json(nlohmann::json *out_json) { *out_json = args; return Status::OK(); } + +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. +// TextFile by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, that cache can +// inherit this sampler from the leaf, providing sampling support from the caching layer. +// That is why we setup the sampler for a leaf node that does not use sampling. +Status TextFileNode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this TextFile node, then the cache will be executing +// a sampler for fetching the data. As such, any options in the TextFile node need to be reset to its defaults so +// that this TextFile node will produce the full set of data into the cache. 
+Status TextFileNode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h index 300251c3f92..9cea20f09aa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h @@ -83,6 +83,22 @@ class TextFileNode : public NonMappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief TextFile by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \param[in] sampler The sampler to setup + /// \return Status of the function + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this TextFile node, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the TextFile node need to be reset to its defaults + /// so that this TextFile node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \return Status of the function + Status MakeSimpleProducer() override; + private: std::vector dataset_files_; int32_t num_samples_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index 901a13ee2bb..b9fbc77e1c4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -121,17 +121,10 @@ Status TFRecordNode::Build(std::vector> *const node_o bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); - // TFReaderOp by itself is a non-mappable dataset that does not support sampling. - // However, if a cache operator is injected at some other place higher in the tree, that cache can - // inherit this sampler from the leaf, providing sampling support from the caching layer. - // That is why we save the sampler here in a leaf node that does not use sampling. 
- std::shared_ptr sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); - // Create and initialize TFReaderOp - std::shared_ptr tf_reader_op = - std::make_shared(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, - std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_, - shard_id_, shard_equal_rows_, std::move(sampler_->SamplerBuild())); + std::shared_ptr tf_reader_op = std::make_shared( + num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema), + connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_); RETURN_IF_NOT_OK(tf_reader_op->Init()); @@ -149,7 +142,6 @@ Status TFRecordNode::Build(std::vector> *const node_o rows_per_buffer_, &shuffle_op)); node_ops->push_back(shuffle_op); } - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); // Add TFReaderOp node_ops->push_back(tf_reader_op); @@ -227,5 +219,29 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) { *out_json = args; return Status::OK(); } + +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. +// TFRecord by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, that cache can +// inherit this sampler from the leaf, providing sampling support from the caching layer. +// That is why we setup the sampler for a leaf node that does not use sampling. +Status TFRecordNode::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this TFRecord node, then the cache will be executing +// a sampler for fetching the data. As such, any options in the TFRecord node need to be reset to its defaults so +// that this TFRecord node will produce the full set of data into the cache. +Status TFRecordNode::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + shard_equal_rows_ = false; + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h index 74b5bc50c84..ad5d259e57d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h @@ -124,6 +124,22 @@ class TFRecordNode : public NonMappableSourceNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief TFRecord by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. 
+ /// \param[in] sampler The sampler to setup + /// \return Status of the function + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this TFRecord node, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the TFRecord node need to be reset to its defaults + /// so that this TFRecord node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \return Status of the function + Status MakeSimpleProducer() override; + private: std::vector dataset_files_; std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index b2e799747ac..f08d40e4934 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -113,7 +113,6 @@ Status VOCNode::Build(std::vector> *const node_ops) { voc_op = std::make_shared(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); - RETURN_IF_NOT_OK(AddCacheOp(node_ops)); node_ops->push_back(voc_op); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc index 34b42a0392e..e0cbbc69309 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -31,8 +31,14 @@ #include "minddata/dataset/engine/ir/datasetops/root_node.h" #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/skip_node.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" +#endif #ifdef ENABLE_PYTHON #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/source/random_node.h" +#ifdef ENABLE_PYTHON #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" #endif #include "minddata/dataset/engine/ir/datasetops/take_node.h" @@ -195,10 +201,10 @@ Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modi } #ifdef ENABLE_PYTHON Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { - return Visit(std::static_pointer_cast(node), modified); + return Visit(std::static_pointer_cast(node), modified); } Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { - return VisitAfter(std::static_pointer_cast(node), modified); + return VisitAfter(std::static_pointer_cast(node), modified); } #endif Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { @@ -207,12 +213,26 @@ Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { return VisitAfter(std::static_pointer_cast(node), modified); } +#ifndef ENABLE_ANDROID +Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { + return Visit(std::static_pointer_cast(node), modified); +} +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { + return VisitAfter(std::static_pointer_cast(node), modified); +} 
+#endif Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { return Visit(std::static_pointer_cast(node), modified); } Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { return VisitAfter(std::static_pointer_cast(node), modified); } +Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { + return Visit(std::static_pointer_cast(node), modified); +} +Status IRNodePass::VisitAfter(std::shared_ptr node, bool *const modified) { + return VisitAfter(std::static_pointer_cast(node), modified); +} Status IRNodePass::Visit(std::shared_ptr node, bool *const modified) { return Visit(std::static_pointer_cast(node), modified); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h index 9130a36a9eb..2e3de519472 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -44,7 +44,6 @@ class TakeNode; class TransferNode; class ZipNode; #ifdef ENABLE_PYTHON -class GeneratorNode; class SyncWaitNode; #endif #ifndef ENABLE_ANDROID @@ -129,14 +128,14 @@ class IRPass : public std::enable_shared_from_this { class IRTreePass : public IRPass { public: /// \brief Run the transformation pass against the IR tree. - /// \param[inout] root_ir Pointer to the IR tree to be transformed. - /// \param[inout] modified Indicate if the tree was modified + /// \param[in,out] root_ir Pointer to the IR tree to be transformed. + /// \param[in,out] modified Indicate if the tree was modified Status Run(std::shared_ptr root_ir, bool *const modified) final; /// \brief Derived classes may implement the runOnTree function to implement tree transformation. /// "modified" flag needs to be set to true if tree is modified during the pass execution. - /// \param[inout] tree The tree to operate on. - /// \param[inout] Indicate if the tree was modified. + /// \param[in,out] tree The tree to operate on. + /// \param[in,out] Indicate if the tree was modified. 
  /// \return Status The status code returned
  virtual Status RunOnTree(std::shared_ptr root_ir, bool *const modified) { return Status::OK(); }
 };
@@ -164,8 +163,8 @@ class IRNodePass : public IRPass {
   ~IRNodePass() = default;
 
   /// \brief Run the transformation pass against the IR tree
-  /// \param[inout] root_ir Pointer to the IR tree to be transformed
-  /// \param[inout] modified Indicator if the tree was changed
+  /// \param[in,out] root_ir Pointer to the IR tree to be transformed
+  /// \param[in,out] modified Indicator if the tree was changed
   Status Run(std::shared_ptr root_ir, bool *const modified) final;
 
   /// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down
@@ -210,8 +209,14 @@ class IRNodePass : public IRPass {
 #endif
   virtual Status Visit(std::shared_ptr node, bool *const modified);
   virtual Status VisitAfter(std::shared_ptr node, bool *const modified);
+#ifndef ENABLE_ANDROID
+  virtual Status Visit(std::shared_ptr node, bool *const modified);
+  virtual Status VisitAfter(std::shared_ptr node, bool *const modified);
+#endif
   virtual Status Visit(std::shared_ptr node, bool *const modified);
   virtual Status VisitAfter(std::shared_ptr node, bool *const modified);
+  virtual Status Visit(std::shared_ptr node, bool *const modified);
+  virtual Status VisitAfter(std::shared_ptr node, bool *const modified);
   virtual Status Visit(std::shared_ptr node, bool *const modified);
   virtual Status VisitAfter(std::shared_ptr node, bool *const modified);
   virtual Status Visit(std::shared_ptr node, bool *const modified);
@@ -270,14 +275,14 @@ class Pass : public std::enable_shared_from_this {
 class TreePass : public Pass {
  public:
   /// \brief Run the transformation pass against the execution tree.
-  /// \param[inout] tree Pointer to the execution tree to be transformed.
-  /// \param[inout] modified Indicate if the tree was modified
+  /// \param[in,out] tree Pointer to the execution tree to be transformed.
+  /// \param[in,out] modified Indicate if the tree was modified
   Status Run(ExecutionTree *tree, bool *const modified) final;
 
   /// \brief Derived classes may implement the runOnTree function to implement tree transformation.
   /// "modified" flag needs to be set to true if tree is modified during the pass execution.
-  /// \param[inout] tree The tree to operate on.
-  /// \param[inout] Indicate of the tree was modified.
+  /// \param[in,out] tree The tree to operate on.
+  /// \param[in,out] Indicate if the tree was modified.
/// \return Status The status code returned virtual Status RunOnTree(ExecutionTree *tree, bool *const modified) { return Status::OK(); } }; @@ -305,8 +310,8 @@ class NodePass : public Pass { ~NodePass() = default; /// \brief Run the transformation pass against the execution tree - /// \param[inout] tree Pointer to the execution tree to be transformed - /// \param[inout] modified Indicator if the tree was changed + /// \param[in,out] tree Pointer to the execution tree to be transformed + /// \param[in,out] modified Indicator if the tree was changed Status Run(ExecutionTree *tree, bool *const modified) final; /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc index d805642945a..feaaec15149 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc @@ -16,207 +16,130 @@ #include #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" -#include "minddata/dataset/engine/execution_tree.h" -#include "minddata/dataset/engine/cache/cache_client.h" -#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" -#include "minddata/dataset/engine/datasetops/cache_merge_op.h" -#include "minddata/dataset/engine/datasetops/cache_op.h" -#include "minddata/dataset/engine/datasetops/source/album_op.h" -#include "minddata/dataset/engine/datasetops/source/celeba_op.h" -#include "minddata/dataset/engine/datasetops/source/cifar_op.h" -#include "minddata/dataset/engine/datasetops/source/coco_op.h" -#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" - -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" -#endif - -#include "minddata/dataset/engine/datasetops/source/mnist_op.h" -#include "minddata/dataset/engine/datasetops/source/random_data_op.h" - -#ifndef ENABLE_ANDROID -#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" -#include "minddata/dataset/engine/datasetops/source/clue_op.h" -#include "minddata/dataset/engine/datasetops/source/csv_op.h" -#include "minddata/dataset/engine/datasetops/source/text_file_op.h" -#endif - +#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h" +#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h" +#include "minddata/dataset/engine/ir/datasetops/cache_node.h" #ifdef ENABLE_PYTHON -#include "minddata/dataset/engine/datasetops/source/generator_op.h" -#include "minddata/dataset/engine/datasetops/source/manifest_op.h" -#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" #endif +#ifndef ENABLE_ANDROID +#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" +#endif +#include "minddata/dataset/engine/ir/datasetops/source/random_node.h" namespace mindspore { namespace dataset { // Constructor -CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {} +CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_node_(nullptr), sampler_(nullptr) {} // Identifies the subtree below this node as a cached descendant tree. -Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr node, bool *const modified) { +// Note that this function will only get called on non-leaf nodes. 
+// For leaf nodes, the other Visit with NonMappableSourceNode or MappableSourceNode argument will be called instead.
+Status CacheTransformPass::CachePass::Visit(std::shared_ptr node, bool *const modified) {
   *modified = false;
-  MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
-  if (is_caching_) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Nested cache operations is not supported!");
+  if (node->IsCached()) {
+    MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
+    is_caching_ = true;
   }
-  is_caching_ = true;
   return Status::OK();
 }
 
-// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
+// Resets the tracking of the cache within the tree and assigns the nodes that will be involved in a cache
 // transformation
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
+Status CacheTransformPass::CachePass::VisitAfter(std::shared_ptr node, bool *const modified) {
   *modified = false;
-  is_caching_ = false;  // We a no longer in a cache subtree. clear the flag.
-  if (leaf_op_) {
-    MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
-    // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
-    // using base class pointers.
-    AddMappableCacheOperators(std::move(leaf_op_), node);
-  } else {
-    // If there was no leaf_op set, then this is a non-mappable scenario.
-
-    if (sampler_) {
-      // Grab the sampler that was saved from the leaf and plug it into the cache op
-      node->SetSampler(std::move(sampler_));
-      MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
+  if (node->IsCached()) {
+    is_caching_ = false;  // We are no longer in a cache subtree. Clear the flag.
+    if (leaf_node_) {
+      MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
+      // Assign the leaf node into the transform pass, using move to null our copy of it,
+      // and also assign the cached node, using base class pointers.
+      // In the case where the cache is directly injected after the leaf node, these two nodes might be the same.
+      cache_pairs_.push_back(std::make_pair(std::move(leaf_node_), node));
     } else {
-      // We're a cache op but no sampler was saved from leaf, so create a default sampler
-      const int64_t num_samples = 0;
-      const int64_t start_index = 0;
-      sampler_ = std::make_shared(num_samples, start_index);
-      node->SetSampler(std::move(sampler_));
-      MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
+      // If there was no leaf_node_ set, then this is a non-mappable scenario.
+      // We only assign the cached node in this case.
+      cached_nodes_.push_back(node);
     }
-
-    // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
-    uint32_t cache_crc = DatasetOp::GenerateCRC(node);
-    RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
   }
   return Status::OK();
 }
 
-// Common code for mappable leaf setup.
-Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) {
-  // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
-  if (is_caching_ && leaf_op_) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
-                  "There is currently no support for multiple leaf nodes under cache.");
-  }
-
-  // If we are a leaf in the caching path, then save this leaf.
-  if (is_caching_) {
-    MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
-    leaf_op_ = std::move(leaf_op);
-  }
-  return Status::OK();
-}
-
-// Common code for non mappable leaf setup.
-Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) {
-  // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
-  if (is_caching_ && leaf_op_) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
-                  "There is currently no support for multiple leaf nodes under cache.");
-  }
-
-  // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf
-  // as save it for use by cache op in ascendant tree.
-  if (is_caching_) {
-    RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
-    MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
-  } else {
-    // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
-    // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
-    std::shared_ptr sampler_from_leaf;
-    RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
-  }
-  return Status::OK();
-}
-
 #ifndef ENABLE_ANDROID
 // Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  if (is_caching_) {
-    // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
-    // TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
-    node->MakeSimpleProducer();
+Status CacheTransformPass::CachePass::Visit(std::shared_ptr node, bool *const modified) {
+  if (node->IsCached()) {
+    MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
+    is_caching_ = true;
   }
-  return NonMappableCacheLeafSetup(std::static_pointer_cast(node));
-}
-
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
+  // Cache might also be injected at a non-leaf node higher in the tree, so is_caching_ might also be set to true
+  // by the other Visit() with DatasetNode argument
   if (is_caching_) {
-    // If we are a ClueOp in a caching tree, then change our config so that it becomes a basic
-    // ClueOp that parses all files. Selection of data will come from the sampler on the cache instead.
-    node->MakeSimpleProducer();
+    MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
+    // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
+    if (leaf_node_) {
+      return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
+                    "There is currently no support for multiple leaf nodes under cache.");
+    }
+    // Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
+    // Note that the sampler for a non-mappable dataset only works if there is a downstream cache.
+    RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
+    // If we are a non-mappable source node in a caching tree, then change our config so that it becomes a basic
+    // source node that parses all files. Selection of data will come from the sampler on the cache instead.
+    RETURN_IF_NOT_OK(node->MakeSimpleProducer());
   }
-  return NonMappableCacheLeafSetup(std::static_pointer_cast(node));
-}
-
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  if (is_caching_) {
-    // If we are a CsvOp in a caching tree, then change our config so that it becomes a basic
-    // CsvOp that parses all files. Selection of data will come from the sampler on the cache instead.
-    node->MakeSimpleProducer();
-  }
-  return NonMappableCacheLeafSetup(std::static_pointer_cast(node));
-}
-
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  if (is_caching_) {
-    // If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic
-    // TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead.
-    node->MakeSimpleProducer();
-  }
-  return NonMappableCacheLeafSetup(std::static_pointer_cast(node));
+  return Status::OK();
 }
 #endif
 
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  return NonMappableCacheLeafSetup(std::static_pointer_cast(node));
+Status CacheTransformPass::CachePass::Visit(std::shared_ptr node, bool *const modified) {
+  if (node->IsCached()) {
+    MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
+    is_caching_ = true;
+  }
+  // Cache might also be injected at a non-leaf node higher in the tree, so is_caching_ might also be set to true
+  // by the other Visit() with DatasetNode argument
+  if (is_caching_) {
+    MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
+    // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
+    if (leaf_node_) {
+      return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
+                    "There is currently no support for multiple leaf nodes under cache.");
+    }
+    // Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
+    // Note that the sampler for a non-mappable dataset only works if there is a downstream cache.
+    RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
+  }
+  return Status::OK();
 }
 
 // Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  return MappableCacheLeafSetup(std::static_pointer_cast(node));
-}
-
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  return MappableCacheLeafSetup(std::static_pointer_cast(node));
-}
-
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  return MappableCacheLeafSetup(std::static_pointer_cast(node));
-}
-
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  return MappableCacheLeafSetup(std::static_pointer_cast(node));
-}
-
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  return MappableCacheLeafSetup(std::static_pointer_cast(node));
-}
-
-// Perform leaf node cache transform identification
-Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) {
-  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+Status CacheTransformPass::CachePass::Visit(std::shared_ptr node, bool *const modified) {
+  if (node->IsCached()) {
+    MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
+    is_caching_ = true;
+  }
+  // Cache might also be injected at a non-leaf node higher in the tree, so is_caching_ might also be set to true
+  // by the other Visit() with DatasetNode argument
+  if (is_caching_) {
+    MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
+    // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
+ if (leaf_node_) { + return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, + "There is currently no support for multiple leaf nodes under cache."); + } + // If we are a leaf in the caching path, then save this leaf + leaf_node_ = node; + } + return Status::OK(); } #ifndef ENABLE_ANDROID // Perform leaf node cache transform identification -Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) { - if (is_caching_) { +Status CacheTransformPass::CachePass::Visit(std::shared_ptr node, bool *const modified) { + if (node->IsCached() || is_caching_) { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "There is currently no support for MindRecordOp under cache."); } @@ -226,102 +149,85 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr no #ifdef ENABLE_PYTHON // Perform leaf node cache transform identification -Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) { - if (is_caching_) { +Status CacheTransformPass::CachePass::Visit(std::shared_ptr node, bool *const modified) { + if (node->IsCached() || is_caching_) { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "There is currently no support for GeneratorOp under cache."); } return Status::OK(); } - -// Perform leaf node cache transform identification -Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache transform identification -Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *const modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} #endif -// Assigns the leaf and cache operators that are involved in a cache transformation -void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr leaf_op, - std::shared_ptr cache_op) { - cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); -} - // constructor CacheTransformPass::CacheTransformPass() {} // Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations -Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *const modified) { +Status CacheTransformPass::RunOnTree(std::shared_ptr root_ir, bool *const modified) { MS_LOG(INFO) << "Pre pass: Cache transform pass started."; // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will // use to execute a transform. 
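For readers new to the pass framework: the Visit() overloads above are selected by double dispatch. Each IR node's Accept() re-dispatches on its own concrete type, which is what lets the pass treat non-mappable leaves, mappable leaves, MindRecordOp-backed nodes, and generator nodes differently while a generic DatasetNode visit tracks is_caching_. The following is a minimal standalone sketch of that mechanism only; the class names are illustrative stand-ins, not the real MindSpore types.

// Minimal sketch of the double dispatch used by an IR node pass.
#include <iostream>
#include <memory>

class LeafNode;  // forward declaration so the pass can overload on it

class NodePassSketch {
 public:
  virtual ~NodePassSketch() = default;
  // Fallback visit for any node type without a more specific overload.
  virtual void Visit(class BaseNode *node) { std::cout << "generic visit\n"; }
  virtual void Visit(LeafNode *node);  // defined once LeafNode is complete
};

class BaseNode {
 public:
  virtual ~BaseNode() = default;
  // Each subclass overrides Accept; the static type of `this` inside that
  // override selects the matching Visit overload at compile time.
  virtual void Accept(NodePassSketch *p) { p->Visit(this); }
};

class LeafNode : public BaseNode {
 public:
  void Accept(NodePassSketch *p) override { p->Visit(this); }
};

void NodePassSketch::Visit(LeafNode *node) { std::cout << "leaf visit\n"; }

int main() {
  NodePassSketch pass;
  std::unique_ptr<BaseNode> generic = std::make_unique<BaseNode>();
  std::unique_ptr<BaseNode> leaf = std::make_unique<LeafNode>();
  generic->Accept(&pass);  // prints "generic visit"
  leaf->Accept(&pass);     // prints "leaf visit"
  return 0;
}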
CachePass cache_pass = CachePass(); - RETURN_IF_NOT_OK(cache_pass.Run(tree, modified)); + RETURN_IF_NOT_OK(cache_pass.Run(root_ir, modified)); - // Then, execute the transform for each pair + // Execute the transform for non-mappable cache + for (auto cached_node : cache_pass.cached_nodes()) { + MS_LOG(DEBUG) << "Cache transform pass: Injecting a non-mappable cache node."; + RETURN_IF_NOT_OK(InjectNonMappableCacheNode(cached_node, cache_pass.sampler())); + } + + // Execute the transform for mappable cache for (auto cache_pair : cache_pass.cache_pairs()) { - MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; - RETURN_IF_NOT_OK( - ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client())); + MS_LOG(DEBUG) << "Cache transform pass: Injecting a mappable cache node."; + RETURN_IF_NOT_OK(InjectMappableCacheNode(cache_pair.first, cache_pair.second)); } MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; return Status::OK(); } -// Helper function to execute the cache transformation. -Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op, - std::shared_ptr cache_op, - std::shared_ptr cache_client) { - // Get local pointers the child/parent of the cache op. It's possible that the parent is null if the cache was - // the root node. It is also possible that cache_child == leaf_op - std::shared_ptr cache_child = cache_op->child(0); - DatasetOp *cache_parent = nullptr; - cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent +// Helper function to execute mappable cache transformation. +// Input: +// Sampler +// | +// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) +// +// Transformed: +// Sampler --> CacheLookupNode -------------------------> +// | | +// | CacheMergeNode +// | | +// LeafNode --> OtherNodes --> CachedNode +Status CacheTransformPass::InjectMappableCacheNode(std::shared_ptr leaf_node, + std::shared_ptr cached_node) { + // Create a cache merge node with defaults + auto cache_merge_node = std::make_shared(nullptr, cached_node->GetDatasetCache()); + // Insert the cache merge node to become the cached_node's parent + RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_merge_node)); // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later. - std::shared_ptr leaf_sampler = leaf_op->sampler(); - - // Construct the merge op with defaults - std::shared_ptr merge_op; - CacheMergeOp::Builder merge_builder; - RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op)); - RETURN_IF_NOT_OK(tree->AssociateNode(merge_op)); - - // Construct the cache lookup op with defaults - std::shared_ptr cache_lookup_op; - CacheLookupOp::Builder lookup_builder; - RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op)); - RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op)); - - // Overwrite the old sampler in this leaf op to become the lookup op - leaf_op->SetSampler(cache_lookup_op); - - // If the cache had a parent, then go into that parent to remove the cache from it's child list and then - // replace it with the merge op. - if (cache_parent != nullptr) { - RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op)); - RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op)); - } else { - // If we didn't have a parent, then the merge op is the root node - RETURN_IF_NOT_OK(tree->AssignRoot(merge_op)); - } - - // Set the cache op to no longer be a parent over it's child. 
This will fully disconnect the old cache op. - // We maintain a local pointer to the old child though. - RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child)); - - // Connect the merge op - RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op))); - RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child))); - - // At this point, the cache op has already had it's children and parents taken away. Calling remove - // on it at this point will not do any node hookups, and instead set internal fields to invalid. - RETURN_IF_NOT_OK(cache_op->Remove()); + std::shared_ptr leaf_sampler = leaf_node->Sampler(); + // Create a cache lookup node with leaf_node's sampler + auto cache_lookup_node = std::make_shared(nullptr, leaf_sampler, cached_node->GetDatasetCache()); + // Insert the cache lookup node as the first child of cache merge node + RETURN_IF_NOT_OK(cache_merge_node->InsertChildAt(0, cache_lookup_node)); + // Overwrite the old sampler in this leaf node to become the cache lookup node + leaf_node->SetSampler(std::static_pointer_cast(cache_lookup_node)); + return Status::OK(); +} +// Helper function to execute non-mappable cache transformation. +// Input: +// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) +// +// Transformed: +// Sampler +// | +// LeafNode --> OtherNodes --> CachedNode --> CacheNode +Status CacheTransformPass::InjectNonMappableCacheNode(std::shared_ptr cached_node, + std::shared_ptr sampler) { + // Create a cache node using the sampler we saved from the leaf + auto cache_node = std::make_shared(nullptr, sampler, cached_node->GetDatasetCache()); + // Insert the cache node to become the cached_node's parent + RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_node)); return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h index 647f80751c5..a785042fd0e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h @@ -20,6 +20,8 @@ #include #include #include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/engine/opt/pass.h" namespace mindspore { @@ -32,11 +34,11 @@ class CacheClient; /// \class CacheTransformPass cache_transform_pass.h /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching /// operations -class CacheTransformPass : public TreePass { +class CacheTransformPass : public IRTreePass { /// \class CachePass /// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache /// transformation. It works in conjunction with the CacheTransformPass - class CachePass : public NodePass { + class CachePass : public IRNodePass { public: /// \brief Constructor /// \param[in] transform_pass Raw pointer back to controlling tree pass @@ -47,138 +49,72 @@ class CacheTransformPass : public TreePass { /// \brief Identifies the subtree below this node as a cached descendant tree. 
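Both injection helpers above rewrite the IR purely by relinking parent/child pointers: InsertAbove() splices the new cache (or cache merge) node between cached_node and its former parent, and InsertChildAt() slots the lookup node under the merge node. Below is a simplified sketch of the splice InsertAbove() has to perform, assuming single-parent nodes; it is not the real DatasetNode implementation.

// Hypothetical minimal tree node illustrating an InsertAbove-style splice.
#include <cassert>
#include <memory>
#include <vector>

struct Node : std::enable_shared_from_this<Node> {
  Node *parent = nullptr;
  std::vector<std::shared_ptr<Node>> children;

  void AddChild(std::shared_ptr<Node> child) {
    child->parent = this;
    children.push_back(std::move(child));
  }

  // Make `mid` the new parent of `this`, preserving this node's old parent.
  void InsertAbove(std::shared_ptr<Node> mid) {
    auto self = shared_from_this();
    if (parent != nullptr) {
      for (auto &slot : parent->children) {
        if (slot == self) slot = mid;  // replace in the old parent's child list
      }
      mid->parent = parent;
    }
    mid->AddChild(self);
  }
};

int main() {
  auto root = std::make_shared<Node>();
  auto leaf = std::make_shared<Node>();
  root->AddChild(leaf);
  auto cache = std::make_shared<Node>();
  leaf->InsertAbove(cache);  // root --> cache --> leaf
  assert(root->children[0] == cache && cache->children[0] == leaf);
  return 0;
}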
/// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status PreRunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; /// \brief Resets the tracking of the cache within the tree and assigns the operators that /// will be involved in a cache transformation /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status VisitAfter(std::shared_ptr node, bool *const modified) override; #ifndef ENABLE_ANDROID - /// \brief Perform leaf node cache transform identifications + /// \brief Perform non-mappable leaf node cache transform identifications /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; #endif - /// \brief Perform leaf node cache transform identifications + /// \brief Perform non-mappable leaf node cache transform identifications /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; - /// \brief Perform leaf node cache transform identifications + /// \brief Perform mappable leaf node cache transform identifications /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const 
modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; #ifdef ENABLE_PYTHON /// \brief Perform leaf node cache transform identifications /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; #endif - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; - #ifndef ENABLE_ANDROID /// \brief Perform leaf node cache transform identifications /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all + /// \param[in,out] modified Indicator if the node was changed at all /// \return Status The status code returned - Status RunOnNode(std::shared_ptr node, bool *const modified) override; + Status Visit(std::shared_ptr node, bool *const modified) override; #endif /// \brief Getter - std::vector, std::shared_ptr>> cache_pairs() { return cache_pairs_; } + std::vector, std::shared_ptr>> cache_pairs() { + return cache_pairs_; + } + + /// \brief Getter + std::vector> cached_nodes() { return cached_nodes_; } + + /// \brief Getter + std::shared_ptr sampler() { return sampler_; } private: - /// \brief Common code for mappable leaf setup. - /// \param[in] node The leaf node performing setup work. - /// \return Status The status code returned - Status MappableCacheLeafSetup(std::shared_ptr leaf_op); - - /// \brief Common code for non-mappable leaf setup. - /// \param[in] node The leaf node performing setup work. 
-  /// \return Status The status code returned
-  Status NonMappableCacheLeafSetup(std::shared_ptr leaf_op);
-
-  /// \brief Assigns the leaf and cache operators that are involved in a cache transformation
-  /// \param[in] leaf_op The leaf operator involved in the cache transform
-  /// \param[in] cache_op The cache operator involved in the cache transform
-  void AddMappableCacheOperators(std::shared_ptr leaf_op, std::shared_ptr cache_op);
-
   bool is_caching_;
-  std::shared_ptr leaf_op_;
-  std::shared_ptr sampler_;
-  // The two operators that work together to establish the cache transform
-  std::vector, std::shared_ptr>> cache_pairs_;
+  std::shared_ptr leaf_node_;
+  std::shared_ptr sampler_;
+  // The two nodes that work together to establish the cache transform
+  std::vector> cached_nodes_;
+  std::vector, std::shared_ptr>> cache_pairs_;
 };

 public:
@@ -189,32 +125,46 @@ class CacheTransformPass : public TreePass {
   ~CacheTransformPass() = default;

   /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
-  /// \param[inout] tree The tree to operate on.
-  /// \param[inout] Indicate of the tree was modified.
+  /// \param[in,out] root_ir The IR tree to operate on.
+  /// \param[in,out] modified Indicator of whether the tree was modified.
   /// \return Status The status code returned
-  Status RunOnTree(ExecutionTree *tree, bool *const modified) override;
+  Status RunOnTree(std::shared_ptr root_ir, bool *const modified) override;

 private:
-  /// \brief Helper function to execute the cache transformation.
+  /// \brief Helper function to execute mappable cache transformation.
   ///
   /// Input:
   ///   Sampler
   ///     |
-  ///   LeafOp --> OtherOps --> CacheOp
+  ///   LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
   ///
   /// Transformed:
-  ///   Sampler --> CacheLookupOp ---------------->
-  ///     |                                       |
-  ///     |                                    MergeOp
-  ///     |                                       |
-  ///   LeafOp --> OtherOps -->
+  ///   Sampler --> CacheLookupNode ------------------------->
+  ///     |                                                  |
+  ///     |                                           CacheMergeNode
+  ///     |                                                  |
+  ///   LeafNode --> OtherNodes --> CachedNode
   ///
-  /// \param[in] leaf_op The leaf node in the transform
-  /// \param[in] cache_op The cache op in the transform (will get removed)
-  /// \param[in] cache_client The cache client
+  /// \param[in] leaf_node The leaf node in the transform
+  /// \param[in] cached_node The node with cache attribute which is involved in the cache transform
   /// \return Status The status code returned
-  Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op,
-                               std::shared_ptr cache_op, std::shared_ptr cache_client);
+  Status InjectMappableCacheNode(std::shared_ptr leaf_node,
+                                 std::shared_ptr cached_node);
+
+  /// \brief Helper function to execute non-mappable cache transformation.
+ /// + /// Input: + /// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) + /// + /// Transformed: + /// Sampler + /// | + /// LeafNode --> OtherNodes --> CachedNode --> CacheNode + /// + /// \param[in] cached_node The node with cache attribute which is involved in the cache transform + /// \param[in] sampler The sampler saved for non-mappable leaf nodes during the CachePass + /// \return Status The status code returned + Status InjectNonMappableCacheNode(std::shared_ptr cached_node, std::shared_ptr sampler); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index ed4b3ebdc51..097fa08c6c9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -24,6 +24,7 @@ #ifdef ENABLE_PYTHON #include "minddata/dataset/engine/opt/post/generator_node_pass.h" #endif +#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h" #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" @@ -53,6 +54,7 @@ Status TreeAdapter::PrePass(std::shared_ptr ir) { actions.emplace_back(std::make_unique()); actions.emplace_back(std::make_unique()); if (usage_ == kDeGetter) actions.emplace_back(std::make_unique()); + actions.emplace_back(std::make_unique()); // Vector of flags for each action std::vector modified(actions.size(), false); // Apply pre-pass actions diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 90f975e44f4..6d8887ea300 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -35,7 +35,7 @@ namespace dataset { // Internal Sampler class forward declaration class SamplerRT; -class SamplerObj : public std::enable_shared_from_this { +class SamplerObj { public: /// \brief Constructor SamplerObj(); @@ -122,7 +122,7 @@ std::shared_ptr RandomSampler(bool replacement = false, int64_ /// Function to create a Sequential Sampler. /// \notes Samples the dataset elements sequentially, same as not having a sampler. -/// \param[in] start_index - Index to start sampling at (dafault to start at first id). +/// \param[in] start_index - Index to start sampling at (default to start at first id). /// \param[in] num_samples - The number of samples to draw (default to all elements). /// \return Shared pointer to the current Sampler. std::shared_ptr SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0); diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index 7926d116074..5f01261ad98 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -465,24 +465,21 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { rc = ccbuilder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); - // In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op. - // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by - // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and - // replace it with the required tree structures for cache lookup op and cache merge op. 
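Note also the tree_adapter.cc hunk above: it appends CacheTransformPass at the end of the pre-pass action list, so cache injection runs after validation, deep-copy, and epoch-control passes. That driver is just an ordered vector of passes with a per-pass modified flag. A hedged sketch of the shape follows; the pass names and the bool-returning Run() are illustrative, not the actual IRTreePass interface.

// Illustrative pre-pass driver: run each registered pass once, in order.
#include <iostream>
#include <memory>
#include <vector>

class IRPassSketch {
 public:
  virtual ~IRPassSketch() = default;
  virtual bool Run() = 0;  // returns true if the IR tree was changed
};

class EpochCtrlSketch : public IRPassSketch {
 public:
  bool Run() override { std::cout << "epoch ctrl pass\n"; return false; }
};

class CacheTransformSketch : public IRPassSketch {
 public:
  bool Run() override { std::cout << "cache transform pass\n"; return true; }
};

int main() {
  std::vector<std::unique_ptr<IRPassSketch>> actions;
  actions.emplace_back(std::make_unique<EpochCtrlSketch>());
  // Appended last, mirroring where the diff registers CacheTransformPass.
  actions.emplace_back(std::make_unique<CacheTransformSketch>());
  std::vector<bool> modified(actions.size(), false);
  for (size_t i = 0; i < actions.size(); ++i) modified[i] = actions[i]->Run();
  return 0;
}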
-
-  std::shared_ptr myCacheOp;
-  rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
+  std::shared_ptr myLookupOp;
+  rc = CacheLookupOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(seq_sampler).Build(&myLookupOp);
+  ASSERT_TRUE(rc.IsOk());
+  std::shared_ptr myMergeOp;
+  rc = CacheMergeOp::Builder().SetNumWorkers(4).SetClient(myClient).Build(&myMergeOp);
+  ASSERT_TRUE(rc.IsOk());
   std::shared_ptr so;
   ImageFolderOp::Builder builder;
-  builder.SetSampler(std::move(seq_sampler))
-    .SetOpConnectorSize(3)
+  builder.SetOpConnectorSize(3)
     .SetNumWorkers(3)
     .SetRowsPerBuffer(2)
     .SetExtensions({".jpg", ".JPEG"})
     .SetRecursive(true)
     .SetImageFolderDir(datasets_root_path_ + "/testPK/data");
   rc = builder.Build(&so);
   ASSERT_TRUE(rc.IsOk());
+  so->SetSampler(myLookupOp);

   // RepeatOp
@@ -495,7 +492,9 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
   rc = myTree->AssociateNode(so);
   ASSERT_TRUE(rc.IsOk());

-  rc = myTree->AssociateNode(myCacheOp);
+  rc = myTree->AssociateNode(myLookupOp);
+  ASSERT_TRUE(rc.IsOk());
+  rc = myTree->AssociateNode(myMergeOp);
   ASSERT_TRUE(rc.IsOk());

   rc = myTree->AssociateNode(myRepeatOp);
@@ -503,9 +502,11 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
   rc = myTree->AssignRoot(myRepeatOp);
   ASSERT_TRUE(rc.IsOk());

-  rc = myRepeatOp->AddChild(myCacheOp);
+  rc = myRepeatOp->AddChild(myMergeOp);
   ASSERT_TRUE(rc.IsOk());
-  rc = myCacheOp->AddChild(so);
+  rc = myMergeOp->AddChild(myLookupOp);
+  ASSERT_TRUE(rc.IsOk());
+  rc = myMergeOp->AddChild(so);
   ASSERT_TRUE(rc.IsOk());

   rc = myTree->Prepare(1);
@@ -532,119 +533,3 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
   rc = myClient->DestroyCache();
   ASSERT_TRUE(rc.IsOk());
 }
-
-//// Simple test with a repeated cache op over random data producer.
-//// The difference in this one is that you do not add the sampler to the cache op directly.
-//// Instead, the sampler is added as part of the leaf op construction. Then, the prepare
-//// phase will pull this up from the leaf and into the cache.
-//// It removes the sampler from the leaf op, which doesn't make sense there anyway for
-//// the RandomDataOp which doesn't support sampling without a cache.
-////
-////    RepeatOp
-////        |
-////     CacheOp
-////        |
-////  RandomDataOp
-////
-TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
-  // Clear the rc of the master thread if any
-  (void)TaskManager::GetMasterThreadRc();
-  Status rc;
-  int32_t rank = 0;  // not used
-  MS_LOG(INFO) << "UT test TestCacheInheritSampler";
-
-  session_id_type env_session;
-  rc = GetSessionFromEnv(&env_session);
-  ASSERT_TRUE(rc.IsOk());
-
-  int64_t num_samples = 0;
-  int64_t start_index = 0;
-  auto seq_sampler = std::make_shared(num_samples, start_index);
-
-  // Start with an empty execution tree
-  auto myTree = std::make_shared();
-
-  // Create a schema using the C api's
-  std::unique_ptr testSchema = std::make_unique();
-
-  // 2 columns.
First column is an "image" 640,480,3 - TensorShape c1Shape({640, 480, 3}); - ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, - rank, // not used - &c1Shape); - - // Column 2 will just be a scalar label number - TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor - ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); - - testSchema->AddColumn(c1); - testSchema->AddColumn(c2); - - // RandomDataOp - std::shared_ptr myRandomDataOp; - rc = RandomDataOp::Builder() - .SetRowsPerBuffer(2) - .SetNumWorkers(4) - .SetDataSchema(std::move(testSchema)) - .SetTotalRows(10) - .SetSampler(std::move(seq_sampler)) - .Build(&myRandomDataOp); - ASSERT_TRUE(rc.IsOk()); - rc = myTree->AssociateNode(myRandomDataOp); - ASSERT_TRUE(rc.IsOk()); - - // CacheOp - CacheClient::Builder ccbuilder; - // use arbitrary session of 1, size of 0, spilling// is true - ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); - std::shared_ptr myClient; - rc = ccbuilder.Build(&myClient); - ASSERT_TRUE(rc.IsOk()); - std::shared_ptr myCacheOp; - rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); - ASSERT_TRUE(rc.IsOk()); - rc = myTree->AssociateNode(myCacheOp); - ASSERT_TRUE(rc.IsOk()); - - // RepeatOp - uint32_t numRepeats = 4; - std::shared_ptr myRepeatOp; - rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); - ASSERT_TRUE(rc.IsOk()); - rc = myTree->AssociateNode(myRepeatOp); - ASSERT_TRUE(rc.IsOk()); - - // Assign tree relations and root - rc = myRepeatOp->AddChild(myCacheOp); - ASSERT_TRUE(rc.IsOk()); - rc = myCacheOp->AddChild(myRandomDataOp); - ASSERT_TRUE(rc.IsOk()); - rc = myTree->AssignRoot(myRepeatOp); - ASSERT_TRUE(rc.IsOk()); - - MS_LOG(INFO) << "Launching tree and begin iteration"; - rc = myTree->Prepare(1); - ASSERT_TRUE(rc.IsOk()); - - std::cout << *myClient << std::endl; - - rc = myTree->Launch(); - ASSERT_TRUE(rc.IsOk()); - - // Start the loop of reading tensors from our pipeline - DatasetIterator dI(myTree); - TensorRow tensorList; - rc = dI.FetchNextTensorRow(&tensorList); - ASSERT_TRUE(rc.IsOk()); - int rowCount = 0; - while (!tensorList.empty()) { - // Don't display these rows, just count them - MS_LOG(INFO) << "Row fetched #: " << rowCount; - rc = dI.FetchNextTensorRow(&tensorList); - ASSERT_TRUE(rc.IsOk()); - rowCount++; - } - ASSERT_EQ(rowCount, 40); - rc = myClient->DestroyCache(); - ASSERT_TRUE(rc.IsOk()); -} diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh index b7867fa5270..79da31b551d 100755 --- a/tests/ut/python/cachetests/cachetest_py.sh +++ b/tests/ut/python/cachetests/cachetest_py.sh @@ -315,6 +315,9 @@ HandleRcExit $? 0 0 PytestCmd "test_cache_nomap.py" "test_cache_nomap_long_file_list" HandleRcExit $? 0 0 +PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1 +HandleRcExit $? 
0 0 + for i in $(seq 1 3) do test_name="test_cache_nomap_multiple_cache${i}" diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 6f5f11fd89d..3f03d6a065e 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -216,7 +216,7 @@ def test_cache_map_failure1(): | Cache | - ImageFolder + Coco """ logger.info("Test cache failure 1") @@ -227,8 +227,9 @@ def test_cache_map_failure1(): some_cache = ds.DatasetCache(session_id=session_id, size=0) - # This DATA_DIR only has 2 images in it - ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + # This DATA_DIR has 6 images in it + ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, + cache=some_cache) decode_op = c_vision.Decode() ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) ds1 = ds1.repeat(4) @@ -302,7 +303,7 @@ def test_cache_map_failure3(): | Batch | - ImageFolder + Mnist """ logger.info("Test cache failure 3") if "SESSION_ID" in os.environ: @@ -312,8 +313,7 @@ def test_cache_map_failure3(): some_cache = ds.DatasetCache(session_id=session_id, size=0) - # This DATA_DIR only has 2 images in it - ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) ds1 = ds1.batch(2) resize_op = c_vision.Resize((224, 224)) ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) @@ -342,7 +342,7 @@ def test_cache_map_failure4(): | Filter | - ImageFolder + CelebA """ logger.info("Test cache failure 4") @@ -353,8 +353,8 @@ def test_cache_map_failure4(): some_cache = ds.DatasetCache(session_id=session_id, size=0) - # This DATA_DIR only has 2 images in it - ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + # This dataset has 4 records + ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"]) decode_op = c_vision.Decode() @@ -382,7 +382,7 @@ def test_cache_map_failure5(): | Map(decode, randomCrop) | - ImageFolder + Manifest """ logger.info("Test cache failure 5") @@ -393,8 +393,8 @@ def test_cache_map_failure5(): some_cache = ds.DatasetCache(session_id=session_id, size=0) - # This DATA_DIR only has 2 images in it - data = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + # This dataset has 4 records + data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) decode_op = c_vision.Decode() @@ -505,7 +505,7 @@ def test_cache_map_failure8(): | Repeat | - ImageFolder + Cifar10 """ logger.info("Test cache failure 8") @@ -516,8 +516,7 @@ def test_cache_map_failure8(): some_cache = ds.DatasetCache(session_id=session_id, size=0) - # This DATA_DIR only has 2 images in it - ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10) decode_op = c_vision.Decode() ds1 = ds1.repeat(4) ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) @@ -545,7 +544,7 @@ def test_cache_map_failure9(): | Take | - ImageFolder + Cifar100 """ logger.info("Test cache failure 9") @@ -556,8 +555,7 @@ def test_cache_map_failure9(): some_cache = ds.DatasetCache(session_id=session_id, size=0) - # This DATA_DIR only has 2 images in it - ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) ds1 = ds1.take(2) decode_op = c_vision.Decode() @@ -587,7 
+585,7 @@ def test_cache_map_failure10(): | Skip | - ImageFolder + VOC """ logger.info("Test cache failure 10") @@ -598,8 +596,8 @@ def test_cache_map_failure10(): some_cache = ds.DatasetCache(session_id=session_id, size=0) - # This DATA_DIR only has 2 images in it - ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + # This dataset has 9 records + ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) ds1 = ds1.skip(1) decode_op = c_vision.Decode() diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index 48b17f6b6cd..5316b37aff6 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -1913,6 +1913,217 @@ def test_cache_nomap_long_file_list(): logger.info("test_cache_nomap_long_file_list Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_failure1(): + """ + Test nested cache (failure) + + Repeat + | + Cache + | + Map(decode) + | + Cache + | + TFRecord + + """ + logger.info("Test cache nomap failure 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) + ds1 = ds1.repeat(4) + + with pytest.raises(RuntimeError) as e: + ds1.get_batch_size() + assert "Nested cache operations" in str(e.value) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert "Nested cache operations" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_nomap_failure1 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_failure2(): + """ + Test zip under cache (failure) + + repeat + | + Cache + | + Map(decode) + | + Zip + | | + Random Random + + """ + logger.info("Test cache nomap failure 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, + shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) + schema.add_column('label', de_type=mstype.uint8, shape=[1]) + + ds1 = ds.RandomDataset(schema=schema) + ds2 = ds.RandomDataset(schema=schema) + dsz = ds.zip((ds1, ds2)) + decode_op = c_vision.Decode() + dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache) + dsz = dsz.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in dsz.create_dict_iterator(): + num_iter += 1 + assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_nomap_failure2 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_failure3(): + """ + Test batch under cache (failure) + + repeat + | + Cache + | + Map(resize) + | + Batch + | + Clue + """ + logger.info("Test cache nomap 
failure 3") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train') + ds1 = ds1.batch(2) + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) + ds1 = ds1.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_nomap_failure3 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_failure4(): + """ + Test filter under cache (failure) + + repeat + | + Cache + | + Map(decode) + | + Filter + | + CSV + + """ + logger.info("Test cache nomap failure 4") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4']) + ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"]) + + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_nomap_failure4 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_failure5(): + """ + Test Map with non-deterministic TensorOps under cache (failure) + + repeat + | + Cache + | + Map(decode, randomCrop) + | + TextFile + + """ + logger.info("Test cache nomap failure 5") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + data = ds.TextFileDataset(TEXT_FILE_DATA_DIR) + random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) + decode_op = c_vision.Decode() + + data = data.map(input_columns=["image"], operations=decode_op) + data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache) + data = data.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in data.create_dict_iterator(): + num_iter += 1 + assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_nomap_failure5 Ended.\n') + + if __name__ == '__main__': test_cache_nomap_basic1() test_cache_nomap_basic2()
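All five new failure tests follow the same pattern: the C++ validation pass returns a Status built from StatusCode::kNotImplementedYet plus __LINE__/__FILE__ and a message, and the Python iterator surfaces that as a RuntimeError whose message the test matches with pytest.raises. The sketch below mirrors only the shape of that status-carrying pattern; StatusCode::kNotImplementedYet and the message string come from the diff, while the class body is simplified and is not the real mindspore Status implementation.

// Simplified sketch of an error Status carrying code, location, and message.
#include <iostream>
#include <sstream>
#include <string>

enum class StatusCode { kOK, kNotImplementedYet };

class Status {
 public:
  Status() : code_(StatusCode::kOK) {}
  Status(StatusCode code, int line, const char *file, const std::string &msg) : code_(code) {
    std::ostringstream oss;
    oss << msg << " (" << file << ":" << line << ")";  // attach source location
    msg_ = oss.str();
  }
  bool IsOk() const { return code_ == StatusCode::kOK; }
  const std::string &ToString() const { return msg_; }

 private:
  StatusCode code_;
  std::string msg_;
};

int main() {
  Status s(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
           "There is currently no support for multiple leaf nodes under cache.");
  // A binding layer would raise this to Python as a RuntimeError message.
  if (!s.IsOk()) std::cout << s.ToString() << "\n";
  return 0;
}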