Migrate CacheTransformPass

This commit is contained in:
Lixia Chen 2021-01-13 13:47:34 -05:00
parent 3708624a25
commit 922e1f4f36
60 changed files with 1275 additions and 809 deletions

View File

@ -237,8 +237,11 @@ Status CacheOp::Accept(NodePass *p, bool *const modified) {
return p->RunOnNode(shared_from_base<CacheOp>(), modified);
}
// A public wrapper for creating the cache through the client
Status CacheOp::CreateCache(uint32_t cache_crc) {
Status CacheOp::PrepareNodePostAction() {
// Run any common code from super class first before adding our own
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
// Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache
uint32_t cache_crc = DatasetOp::GenerateCRC(shared_from_this());
// This is a non-mappable cache op so the id's need to be generated.
// Construct the cache
const bool generate_ids = true;

View File

@ -141,11 +141,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
bool AllowCacheMiss() override { return false; }
/// \brief Base-class override for the name of this operator
std::string Name() const override { return kCacheOp; }
/// \brief A public wrapper for creating the cache through the client
/// \param[in] cache_crc The crc that identifies the cache
/// \see cache_pass.cc
/// \return Status return code
Status CreateCache(uint32_t cache_crc);
Status PrepareNodePostAction() override;
private:
WaitPost rows_cache_done_;

View File

@ -33,11 +33,7 @@
namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
: builder_device_id_(0),
builder_num_devices_(1),
builder_num_samples_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -74,7 +70,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
builder_device_id_, std::move(builder_sampler_));
builder_device_id_);
RETURN_IF_NOT_OK(clue_op->Init());
*op = std::move(clue_op);
@ -94,8 +90,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
@ -552,16 +548,6 @@ Status ClueOp::ComputeColMap() {
return Status::OK();
}
// Turn this clue op into a plain producer of the full data set.
// When a cache has been placed above this op in the tree, the cache is the one executing a sampler to fetch
// the data, so the options held by this op are put back to their defaults here.
void ClueOp::MakeSimpleProducer() {
  // Row-count and shuffle options back to their defaults.
  num_samples_ = 0;
  shuffle_files_ = false;
  // Device/shard options back to their defaults: one device, id 0.
  num_devices_ = 1;
  device_id_ = 0;
}
// Visitor accept method for NodePass
Status ClueOp::Accept(NodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor

View File

@ -122,14 +122,6 @@ class ClueOp : public ParallelOp {
// @return - the a string vector
std::vector<std::string> split(const std::string &s, char delim);
// Setter method: stores the given sampler in the builder.
// @param sampler - shared pointer to the sampler (std::shared_ptr<SamplerRT>), taken by value
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
// The argument is already a by-value copy, so it is moved into the member rather than copied again.
builder_sampler_ = std::move(sampler);
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@ -141,13 +133,12 @@ class ClueOp : public ParallelOp {
std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_;
std::map<std::string, std::string> builder_cols_to_keyword_;
std::shared_ptr<SamplerRT> builder_sampler_;
};
// Constructor of ClueOp
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);
bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~ClueOp() = default;
@ -182,11 +173,6 @@ class ClueOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return clue_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
/// that this clue op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "ClueOp"; }

View File

@ -29,11 +29,7 @@
namespace mindspore {
namespace dataset {
CsvOp::Builder::Builder()
: builder_device_id_(0),
builder_num_devices_(1),
builder_num_samples_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -65,8 +61,7 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_,
std::move(builder_sampler_));
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(csv_op->Init());
*op = std::move(csv_op);
@ -77,8 +72,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim),
column_default_list_(column_default),
@ -920,16 +915,6 @@ Status CsvOp::ComputeColMap() {
return Status::OK();
}
// Turn this csv op into a plain producer of the full data set.
// When a cache has been placed above this op in the tree, the cache is the one executing a sampler to fetch
// the data, so the options held by this op are put back to their defaults here.
void CsvOp::MakeSimpleProducer() {
  // Row-count and shuffle options back to their defaults.
  num_samples_ = 0;
  shuffle_files_ = false;
  // Device/shard options back to their defaults: one device, id 0.
  num_devices_ = 1;
  device_id_ = 0;
}
// Visitor accept method for NodePass
Status CsvOp::Accept(NodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor

View File

@ -241,14 +241,6 @@ class CsvOp : public ParallelOp {
return *this;
}
// Setter method: stores the given sampler in the builder.
// @param sampler - shared pointer to the sampler (std::shared_ptr<SamplerRT>), taken by value
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
// The argument is already a by-value copy, so it is moved into the member rather than copied again.
builder_sampler_ = std::move(sampler);
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@ -262,7 +254,6 @@ class CsvOp : public ParallelOp {
char builder_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
std::vector<std::string> builder_column_name_list_;
std::shared_ptr<SamplerRT> builder_sampler_;
};
// Constructor of CsvOp
@ -271,8 +262,7 @@ class CsvOp : public ParallelOp {
CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
std::shared_ptr<SamplerRT> sampler);
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~CsvOp() = default;
@ -308,11 +298,6 @@ class CsvOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return csv_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
/// that this csv op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.

View File

@ -34,8 +34,7 @@ RandomDataOp::Builder::Builder()
builder_num_workers_(0),
builder_op_connector_size_(0),
builder_rows_per_buffer_(0),
builder_total_rows_(0),
builder_sampler_(nullptr) {
builder_total_rows_(0) {
// Some arguments to the RandomDataOp have a default argument that is taken from the config.
// The user may override these defaults by using the builder set methods.
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
@ -48,9 +47,8 @@ RandomDataOp::Builder::Builder()
Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) {
RETURN_IF_NOT_OK(SanityCheck());
*out_op =
std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_));
*out_op = std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
builder_total_rows_, std::move(builder_data_schema_));
return Status::OK();
}
@ -65,8 +63,8 @@ Status RandomDataOp::Builder::SanityCheck() const {
// Constructor for RandomDataOp
RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
std::unique_ptr<DataSchema> data_schema)
: ParallelOp(num_workers, op_connector_size),
buffer_id_(0),
rows_per_buffer_(rows_per_buffer),
total_rows_(total_rows),
@ -80,8 +78,7 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64
if (total_rows_ == 0) {
total_rows_ = GenRandomInt(1, kMaxTotalRows);
}
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random
// schema.
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random schema.
// See details of generateSchema function to learn what type of schema it will create.
if (data_schema_ == nullptr) {
GenerateSchema();

View File

@ -117,14 +117,6 @@ class RandomDataOp : public ParallelOp {
return *this;
}
// Setter method: stores the given sampler in the builder.
// @param sampler - shared pointer to the sampler (std::shared_ptr<SamplerRT>), taken by value
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
// The argument is already a by-value copy, so it is moved into the member rather than copied again.
builder_sampler_ = std::move(sampler);
return *this;
}
private:
/**
* Check if the required parameters are set by the builder.
@ -133,7 +125,6 @@ class RandomDataOp : public ParallelOp {
Status SanityCheck() const;
std::unique_ptr<DataSchema> builder_data_schema_;
std::shared_ptr<SamplerRT> builder_sampler_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
@ -148,11 +139,10 @@ class RandomDataOp : public ParallelOp {
* @param rows_per_buffer - The number of rows in each DataBuffer
* @param data_schema - A user-provided schema
* @param total_rows - The total number of rows in the dataset
* @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
* @return Builder - The modified builder by reference
*/
RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
std::unique_ptr<DataSchema> data_schema);
/**
* Destructor

View File

@ -34,11 +34,7 @@
namespace mindspore {
namespace dataset {
TextFileOp::Builder::Builder()
: builder_device_id_(0),
builder_num_devices_(1),
builder_total_rows_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
: builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -74,7 +70,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
builder_num_devices_, builder_device_id_, std::move(builder_sampler_));
builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(text_file_op->Init());
*op = std::move(text_file_op);
@ -83,9 +79,8 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id,
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
@ -504,16 +499,6 @@ Status TextFileOp::ComputeColMap() {
return Status::OK();
}
// Turn this text file op into a plain producer of the full data set.
// When a cache has been placed above this op in the tree, the cache is the one executing a sampler to fetch
// the data, so the options held by this op are put back to their defaults here.
void TextFileOp::MakeSimpleProducer() {
  // Row-count and shuffle options back to their defaults.
  total_rows_ = 0;
  shuffle_files_ = false;
  // Device/shard options back to their defaults: one device, id 0.
  num_devices_ = 1;
  device_id_ = 0;
}
// Visitor accept method for NodePass
Status TextFileOp::Accept(NodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor

View File

@ -112,14 +112,6 @@ class TextFileOp : public ParallelOp {
return *this;
}
// Setter method: stores the given sampler in the builder.
// @param sampler - shared pointer to the sampler (std::shared_ptr<SamplerRT>), taken by value
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
// The argument is already a by-value copy, so it is moved into the member rather than copied again.
builder_sampler_ = std::move(sampler);
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@ -131,7 +123,6 @@ class TextFileOp : public ParallelOp {
std::vector<std::string> builder_text_files_list_;
bool builder_shuffle_files_;
std::unique_ptr<DataSchema> builder_schema_;
std::shared_ptr<SamplerRT> builder_sampler_;
};
// Constructor of TextFileOp
@ -145,10 +136,9 @@ class TextFileOp : public ParallelOp {
// @param columns_to_load - the names of the columns to load data from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
// @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);
bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~TextFileOp() = default;
@ -187,11 +177,6 @@ class TextFileOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return text_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
/// that this text file op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.

View File

@ -44,11 +44,7 @@
namespace mindspore {
namespace dataset {
TFReaderOp::Builder::Builder()
: builder_device_id_(0),
builder_num_devices_(1),
builder_total_rows_(0),
builder_equal_rows_per_shard_(false),
builder_sampler_(nullptr) {
: builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_worker_connector_size_ = config_manager->worker_connector_size();
@ -122,8 +118,7 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>(
builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_,
builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_,
builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_,
std::move(builder_sampler_));
builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_);
RETURN_IF_NOT_OK(new_tf_reader_op->Init());
*out_tf_reader_op = std::move(new_tf_reader_op);
@ -134,8 +129,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
int64_t total_num_rows, std::vector<std::string> dataset_files_list,
std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
int32_t device_id, bool equal_rows_per_shard)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
@ -1043,17 +1038,6 @@ Status TFReaderOp::ComputeColMap() {
return Status::OK();
}
// Turn this tf reader into a plain producer of the full data set.
// When a cache has been placed above this op in the tree, the cache is the one executing a sampler to fetch
// the data, so the options held by this op are put back to their defaults here.
void TFReaderOp::MakeSimpleProducer() {
  // Shuffle, shard-balancing and row-count options back to their defaults.
  equal_rows_per_shard_ = false;
  shuffle_files_ = false;
  total_rows_ = 0;
  // Device/shard options back to their defaults: one device, id 0.
  num_devices_ = 1;
  device_id_ = 0;
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status TFReaderOp::PrepareNodePostAction() {

View File

@ -153,17 +153,8 @@ class TFReaderOp : public ParallelOp {
return *this;
}
// Setter method: stores the given sampler in the builder.
// @param sampler - shared pointer to the sampler (std::shared_ptr<SamplerRT>), taken by value
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
// The argument is already a by-value copy, so it is moved into the member rather than copied again.
builder_sampler_ = std::move(sampler);
return *this;
}
private:
std::unique_ptr<DataSchema> builder_data_schema_;
std::shared_ptr<SamplerRT> builder_sampler_;
int32_t builder_device_id_;
int32_t builder_num_devices_;
int32_t builder_num_workers_;
@ -189,11 +180,10 @@ class TFReaderOp : public ParallelOp {
// @param columns_to_load - the names of the columns to load data from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
// @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler);
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
// Default destructor
~TFReaderOp() = default;
@ -246,11 +236,6 @@ class TFReaderOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return dataset_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
/// that this tf reader will produce the full set of data into the cache.
void MakeSimpleProducer();
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call its superclass version first
@ -387,7 +372,7 @@ class TFReaderOp : public ParallelOp {
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Caculate number of rows in each shard.
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();

View File

@ -320,7 +320,6 @@ Status ExecutionTree::PostAction() {
// The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
// This is because Python API binding to TensorOperation is still in progress.
post_actions.push_back(std::make_unique<CacheErrorPass>());
post_actions.push_back(std::make_unique<CacheTransformPass>());
post_actions.push_back(std::make_unique<RepeatPass>());
#endif

View File

@ -19,6 +19,7 @@
#include <memory>
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/util/status.h"
namespace mindspore::dataset {
@ -29,6 +30,9 @@ class DatasetCache {
virtual Status ValidateParams() = 0;
virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) = 0;
virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) = 0;
};
} // namespace mindspore::dataset

View File

@ -16,6 +16,8 @@
#include <memory>
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
namespace mindspore {
@ -44,5 +46,28 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data
return Status::OK();
}
// Create the runtime CacheLookupOp that talks to this cache's client.
// @param num_workers - number of workers given to the CacheLookupOp builder
// @param ds - output; receives the newly built lookup op on success (left untouched on failure)
// @param sampler - IR sampler; the result of its SamplerBuild() is handed to the lookup op builder
// @return Status - Status::OK() on success, otherwise the error from a failed check or from the builder
Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
                                             std::shared_ptr<SamplerObj> sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
  // Both pointers are dereferenced below; report a Status instead of crashing on a null argument.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output pointer for the cache lookup op is null.");
  CHECK_FAIL_RETURN_UNEXPECTED(sampler != nullptr, "Sampler for the cache lookup op is null.");
  std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
  RETURN_IF_NOT_OK(CacheLookupOp::Builder()
                     .SetNumWorkers(num_workers)
                     .SetClient(cache_client_)
                     .SetSampler(sampler->SamplerBuild())
                     .Build(&lookup_op));
  *ds = lookup_op;
  return Status::OK();
}
// Create the runtime CacheMergeOp that talks to this cache's client.
// @param num_workers - number of workers given to the CacheMergeOp builder
// @param ds - output; receives the newly built merge op on success (left untouched on failure)
// @return Status - Status::OK() on success, otherwise the error from a failed check or from the builder
Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
  // ds is dereferenced below; report a Status instead of crashing on a null argument.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output pointer for the cache merge op is null.");
  std::shared_ptr<CacheMergeOp> merge_op = nullptr;
  RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
  *ds = merge_op;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -56,6 +56,11 @@ class DatasetCacheImpl : public DatasetCache {
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) override;
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
Status ValidateParams() override { return Status::OK(); }
~DatasetCacheImpl() = default;

View File

@ -16,6 +16,8 @@
#include <memory>
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
namespace mindspore {
@ -46,5 +48,29 @@ Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Create the runtime CacheLookupOp that talks to this pre-built cache's client.
// @param num_workers - number of workers given to the CacheLookupOp builder
// @param ds - output; receives the newly built lookup op on success (left untouched on failure)
// @param sampler - IR sampler; the result of its SamplerBuild() is handed to the lookup op builder
// @return Status - Status::OK() on success, otherwise the error from a failed check or from the builder
Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
                                                 std::shared_ptr<SamplerObj> sampler) {
  CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
  // Both pointers are dereferenced below; report a Status instead of crashing on a null argument.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output pointer for the cache lookup op is null.");
  CHECK_FAIL_RETURN_UNEXPECTED(sampler != nullptr, "Sampler for the cache lookup op is null.");
  std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
  RETURN_IF_NOT_OK(CacheLookupOp::Builder()
                     .SetNumWorkers(num_workers)
                     .SetClient(cache_client_)
                     .SetSampler(sampler->SamplerBuild())
                     .Build(&lookup_op));
  *ds = lookup_op;
  return Status::OK();
}
// Create the runtime CacheMergeOp that talks to this pre-built cache's client.
// @param num_workers - number of workers given to the CacheMergeOp builder
// @param ds - output; receives the newly built merge op on success (left untouched on failure)
// @return Status - Status::OK() on success, otherwise the error from a failed check or from the builder
Status PreBuiltDatasetCache::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
  CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
  // ds is dereferenced below; report a Status instead of crashing on a null argument.
  CHECK_FAIL_RETURN_UNEXPECTED(ds != nullptr, "Output pointer for the cache merge op is null.");
  std::shared_ptr<CacheMergeOp> merge_op = nullptr;
  RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
  *ds = merge_op;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -40,6 +40,11 @@ class PreBuiltDatasetCache : public DatasetCache {
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *const ds) override;
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) override;
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
Status ValidateParams() override { return Status::OK(); }
Status to_json(nlohmann::json *out_json) override;

View File

@ -8,6 +8,9 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
bucket_batch_by_length_node.cc
build_sentence_piece_vocab_node.cc
build_vocab_node.cc
cache_lookup_node.cc
cache_merge_node.cc
cache_node.cc
concat_node.cc
epoch_ctrl_node.cc
filter_node.cc

View File

@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor.
// @param child - node attached to this one through AddChild()
// @param sampler - sampler object kept in sampler_
// @param cache - dataset cache, forwarded to the DatasetNode base
// The runtime lookup op and the remembered copy both start out as nullptr; they are filled in later by
// Build() and Copy() respectively.
CacheLookupNode::CacheLookupNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
                                 std::shared_ptr<DatasetCache> cache)
    : DatasetNode(std::move(cache)), sampler_(std::move(sampler)), lookup_op_(nullptr), lookup_node_copy_(nullptr) {
  // sampler is a by-value parameter, so it is moved into the member (as is done for cache) instead of copied.
  this->AddChild(child);
}
// Write this node's description, which is just its name, to the given stream.
void CacheLookupNode::Print(std::ostream &out) const {
  out << Name();
}
std::shared_ptr<DatasetNode> CacheLookupNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<CacheLookupNode>(nullptr, sampler, cache_);
lookup_node_copy_ = node;
return node;
}
// Validate this node's parameters; the only thing checked here is the sampler.
Status CacheLookupNode::ValidateParams() {
  Status sampler_status = ValidateDatasetSampler("CacheNode", sampler_);
  RETURN_IF_NOT_OK(sampler_status);
  return Status::OK();
}
// Build the runtime op for this node and append it to node_ops.
// @param node_ops - vector that receives the created lookup op
// @return Status - Status::OK() on success, otherwise the first error encountered
Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
  // A lookup node cannot exist without a cache to look things up in.
  CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
                               "Internal error. Attempt to create a cache lookup node without cache client.");
  // The cache is built first, then asked to create the lookup op.
  RETURN_IF_NOT_OK(cache_->Build());
  // The created op is kept in lookup_op_ so that SamplerBuild() can return it later.
  RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
  node_ops->emplace_back(lookup_op_);
  return Status::OK();
}
// SamplerObj override. No new object is created here: the node remembered by the last Copy() call is
// returned, viewed through its SamplerObj base (nullptr if Copy() has not been called yet).
std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
  // The CacheLookupNode should already have been copied, so the stored copy is returned as-is.
  return lookup_node_copy_;
}
std::shared_ptr<SamplerRT> CacheLookupNode::SamplerBuild() {
// Runtime cache lookup op should already been built, so we just return it here
auto lookup_op = std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_);
return std::shared_ptr<SamplerRT>(lookup_op);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \brief IR node for a cache lookup. The class derives from both DatasetNode and SamplerObj and overrides
///     members of each: the DatasetNode side builds the runtime lookup op, the SamplerObj side hands out
///     the already-made copy / already-built op.
class CacheLookupNode : public DatasetNode, public SamplerObj {
public:
/// \brief Constructor
/// \param[in] child The node added as this node's child
/// \param[in] sampler The sampler object stored in this node and passed on when the lookup op is created
/// \param[in] cache The dataset cache, forwarded to the DatasetNode base class
CacheLookupNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~CacheLookupNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCacheLookupNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object. The new node is also remembered so SamplerCopy() can return it.
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to convert a SamplerObj class into a runtime sampler object.
///     For this class the runtime object is the lookup op created by Build().
/// \return Shared pointer to the runtime lookup op as a SamplerRT (nullptr if it has not been built)
std::shared_ptr<SamplerRT> SamplerBuild() override;
/// \brief A base class override function to copy a SamplerObj class.
///     For this class the copy is the node produced by the last call to Copy().
/// \return Shared pointer to the previously made copy as a SamplerObj (nullptr if Copy() was not called)
std::shared_ptr<SamplerObj> SamplerCopy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Sampler object given at construction time.
std::shared_ptr<SamplerObj> sampler_;
// Runtime lookup op; nullptr until Build() creates it.
std::shared_ptr<DatasetOp> lookup_op_;
// Node produced by the last Copy() call; nullptr until Copy() is called.
std::shared_ptr<CacheLookupNode> lookup_node_copy_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor: hands the cache to the DatasetNode base, flags the node as n-ary and registers the given child.
CacheMergeNode::CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache)
    : DatasetNode(std::move(cache)) {
  // Flag first, child second: the order of the original statements is kept.
  nary_op_ = true;
  AddChild(child);
}
// Write the description of this node to the stream; only the node name is printed.
void CacheMergeNode::Print(std::ostream &out) const {
  out << Name();
}
// Create a new CacheMergeNode that uses the same cache as this node.
std::shared_ptr<DatasetNode> CacheMergeNode::Copy() {
  // The copy is constructed with a null child; only the cache pointer is carried over.
  return std::make_shared<CacheMergeNode>(nullptr, cache_);
}
// Parameter validation: this node has nothing to check, so it always reports success.
Status CacheMergeNode::ValidateParams() {
  return Status::OK();
}
// Build the runtime cache merge op through the cache client and append it to node_ops.
// Returns an error if there is no cache client or if the cache client fails to build or to create the op.
Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
  // A merge node cannot work without a cache client.
  CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
                               "Internal error. Attempt to create a cache merge node without cache client.");

  // Build the cache client first, then ask it for the merge op.
  RETURN_IF_NOT_OK(cache_->Build());
  std::shared_ptr<DatasetOp> runtime_merge_op;
  RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &runtime_merge_op));

  node_ops->push_back(runtime_merge_op);
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \brief IR node for a cache merge. Its Build() asks the cache client to create the runtime cache merge op.
class CacheMergeNode : public DatasetNode {
public:
/// \brief Constructor
/// \param[in] child The dataset node to add as the child of this node
/// \param[in] cache The dataset cache handed to the DatasetNode base class
CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~CacheMergeNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCacheMergeNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor: hands the cache to the DatasetNode base, keeps the sampler and registers the given child.
CacheNode::CacheNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
                     std::shared_ptr<DatasetCache> cache)
    : DatasetNode(std::move(cache)), sampler_(sampler) {
  AddChild(child);
}
// Write the description of this node to the stream; only the node name is printed.
void CacheNode::Print(std::ostream &out) const {
  out << Name();
}
std::shared_ptr<DatasetNode> CacheNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<CacheNode>(nullptr, sampler, cache_);
return node;
}
// Parameter validation: the only parameter checked for this node is its sampler.
Status CacheNode::ValidateParams() {
  // Any error from the sampler check is passed back to the caller.
  RETURN_IF_NOT_OK(ValidateDatasetSampler("CacheNode", sampler_));

  return Status::OK();
}
// Build the runtime cache op through the cache client, attach the runtime sampler to it and append it to node_ops.
// \param[out] node_ops The vector that receives the created runtime op
// \return Status::OK() on success; an error if the cache client or the sampler is missing, or if the cache
//     client fails to build or to create the cache op
Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
  CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
                               "Internal error. Attempt to create a cache node without cache client.");
  // Copy() accepts a node whose sampler_ is null, so a null sampler can reach this point.
  // Report it as an error instead of dereferencing a null pointer below.
  CHECK_FAIL_RETURN_UNEXPECTED(sampler_ != nullptr,
                               "Internal error. Attempt to create a cache node without sampler.");

  // Build the cache client first, then ask it for the cache op.
  RETURN_IF_NOT_OK(cache_->Build());
  std::shared_ptr<DatasetOp> cache_op = nullptr;
  RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));

  // The cache op runs the sampler built from this node's sampler object.
  cache_op->SetSampler(sampler_->SamplerBuild());
  node_ops->push_back(cache_op);
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \brief IR node for a cache. Its Build() asks the cache client to create the runtime cache op and gives that
///     op the runtime sampler built from sampler_.
class CacheNode : public DatasetNode {
public:
/// \brief Constructor
/// \param[in] child The dataset node to add as the child of this node
/// \param[in] sampler The sampler object kept in sampler_
/// \param[in] cache The dataset cache handed to the DatasetNode base class
CacheNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~CacheNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCacheNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
// Sampler object of this node; Build() converts it into the runtime sampler of the cache op
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_

View File

@ -204,15 +204,6 @@ std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int
return SequentialSampler(0, num_samples);
}
// Append a runtime cache op to node_ops when this node has a cache client; otherwise do nothing.
Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
  // Nothing to add for a node without a cache.
  if (cache_ == nullptr) {
    return Status::OK();
  }

  // Build the cache client, then let it create the cache op.
  RETURN_IF_NOT_OK(cache_->Build());
  std::shared_ptr<DatasetOp> created_cache_op;
  RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &created_cache_op));
  node_ops->push_back(created_cache_op);
  return Status::OK();
}
// Constructor that runs the default constructor first and then stores the given dataset cache
DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() {
  cache_ = dataset_cache;
}

View File

@ -53,6 +53,9 @@ constexpr char kBatchNode[] = "Batch";
constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength";
constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab";
constexpr char kBuildVocabNode[] = "BuildVocab";
constexpr char kCacheLookupNode[] = "CacheLookup";
constexpr char kCacheMergeNode[] = "CacheMerge";
constexpr char kCacheNode[] = "Cache";
constexpr char kConcatNode[] = "Concat";
constexpr char kEpochCtrlNode[] = "EpochCtrl";
constexpr char kFilterNode[] = "Filter";
@ -248,6 +251,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \brief Getter of the number of workers
/// \return The value of num_workers_
int32_t num_workers() {
  return num_workers_;
}
/// \brief Getter of dataset cache
/// \return Shared pointer to the dataset cache of this node (cache_)
std::shared_ptr<DatasetCache> GetDatasetCache() {
  return cache_;
}
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object
@ -299,7 +305,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
// Used only in the constructor of the class and its derived classes.
void AddChild(std::shared_ptr<DatasetNode> child);
std::string PrintColumns(const std::vector<std::string> &columns) const;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
void PrintNode(std::ostream &out, int *level) const;
enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 };
enum DataSource mappable_;
@ -360,6 +365,20 @@ class NonMappableSourceNode : public DatasetNode {
/// \brief Node name getter
/// \return Name of the current node
virtual std::string Name() const = 0;
/// \brief By default a non-mappable dataset does not support sampling. However, if a cache operator
/// is injected at some other place higher in the tree, that cache can inherit this sampler
/// from the leaf, providing sampling support from the caching layer.
/// This function sets up the sampler for a leaf node that does not use sampling.
/// \param[out] sampler Receives the sampler that is set up; implementations write the result through this pointer
/// \return Status of the function
virtual Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) = 0;
/// \brief If a cache has been added above this non-mappable source node in the tree, then the cache will
/// be executing a sampler for fetching the data. As such, any options in the source node need to be reset to their
/// defaults so that this source node will produce the full set of data into the cache.
/// \return Status of the function
virtual Status MakeSimpleProducer() = 0;
};
} // namespace dataset
} // namespace mindspore

View File

@ -76,7 +76,6 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto project_op = std::make_shared<ProjectOp>(project_columns_);
node_ops->push_back(project_op);
}
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(map_op);
return Status::OK();

View File

@ -72,8 +72,6 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
// Argument that is not exposed to user in the API.
std::set<std::string> extensions = {};
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, extensions, std::move(schema),
std::move(sampler_->SamplerBuild())));

View File

@ -67,8 +67,6 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
// label is like this:0 1 0 0 1......
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, usage_, extensions_, std::move(schema),
std::move(sampler_->SamplerBuild())));

View File

@ -64,8 +64,6 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->SamplerBuild())));

View File

@ -62,8 +62,6 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->SamplerBuild())));

View File

@ -83,84 +83,66 @@ std::vector<std::string> CLUENode::split(const std::string &s, char delim) {
return res;
}
// Function to build CLUENode
Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
std::map<std::string, std::string> CLUENode::CreateKeyMapForBuild() {
std::map<std::string, std::string> key_map;
if (task_ == "AFQMC") {
if (usage_ == "train") {
if (usage_ == "train" || usage_ == "eval") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
} else if (usage_ == "test") {
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
} else if (usage_ == "eval") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
}
} else if (task_ == "CMNLI") {
if (usage_ == "train") {
}
if (task_ == "CMNLI") {
if (usage_ == "train" || usage_ == "eval") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
} else if (usage_ == "test") {
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
} else if (usage_ == "eval") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
}
} else if (task_ == "CSL") {
if (usage_ == "train") {
}
if (task_ == "CSL") {
if (usage_ == "train" || usage_ == "eval") {
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
key_map["label"] = "label";
} else if (usage_ == "test") {
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
} else if (usage_ == "eval") {
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
key_map["label"] = "label";
}
} else if (task_ == "IFLYTEK") {
if (usage_ == "train") {
}
if (task_ == "IFLYTEK") {
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
key_map["label_des"] = "label_des";
key_map["sentence"] = "sentence";
} else if (usage_ == "test") {
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["sentence"] = "sentence";
} else if (usage_ == "eval") {
key_map["label"] = "label";
key_map["label_des"] = "label_des";
key_map["sentence"] = "sentence";
}
} else if (task_ == "TNEWS") {
if (usage_ == "train") {
}
if (task_ == "TNEWS") {
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
key_map["label_desc"] = "label_desc";
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
} else if (usage_ == "test") {
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
} else if (usage_ == "eval") {
key_map["label"] = "label";
key_map["label_desc"] = "label_desc";
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
}
} else if (task_ == "WSC") {
if (usage_ == "train") {
}
if (task_ == "WSC") {
if (usage_ == "train" || usage_ == "eval") {
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
@ -168,24 +150,21 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
key_map["idx"] = "idx";
key_map["label"] = "label";
key_map["text"] = "text";
} else if (usage_ == "test") {
} else { // usage_ == "test"
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
key_map["span2_text"] = "target/span2_text";
key_map["idx"] = "idx";
key_map["text"] = "text";
} else if (usage_ == "eval") {
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
key_map["span2_text"] = "target/span2_text";
key_map["idx"] = "idx";
key_map["label"] = "label";
key_map["text"] = "text";
}
}
return key_map;
}
// Function to build CLUENode
Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto key_map = CreateKeyMapForBuild();
ColKeyMap ck_map;
for (auto &p : key_map) {
ck_map.insert({p.first, split(p.second, '/')});
@ -193,19 +172,13 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// ClueOp by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
// Sort the dataset files in a lexicographical order
std::vector<std::string> sorted_dataset_files = dataset_files_;
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
std::shared_ptr<ClueOp> clue_op =
std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(clue_op->Init());
@ -222,7 +195,6 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
rows_per_buffer_, &shuffle_op));
node_ops->push_back(shuffle_op);
}
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(clue_op);
@ -270,5 +242,27 @@ Status CLUENode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// A CLUE dataset is non-mappable and does not sample by itself. When a cache operator sits somewhere above
// this leaf in the tree, the cache can take over the sampler selected here and so provide the sampling.
// This function selects that sampler from the node's own settings and writes it to *sampler.
Status CLUENode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  // Files are shuffled for both the global and the files shuffle mode.
  bool shuffle_files = false;
  if (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles) {
    shuffle_files = true;
  }

  *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
  return Status::OK();
}
// When a cache has been placed above this CLUE node in the tree, the cache runs the sampler that fetches the data.
// The sampling-related options of this node are therefore reset to their defaults, so that the node produces
// the full set of data into the cache.
Status CLUENode::MakeSimpleProducer() {
  // No row limit and no shuffling
  num_samples_ = 0;
  shuffle_ = ShuffleMode::kFalse;
  // A single shard, and this node is shard 0
  num_shards_ = 1;
  shard_id_ = 0;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
@ -49,6 +50,10 @@ class CLUENode : public NonMappableSourceNode {
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief Generate a key map to be used in Build() according to usage and task
/// \return The generated key map
std::map<std::string, std::string> CreateKeyMapForBuild();
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
@ -85,6 +90,22 @@ class CLUENode : public NonMappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief CLUE by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
/// That is why we set up the sampler for a leaf node that does not use sampling.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[out] sampler Receives the sampler that is set up; the result is written through this pointer
/// \return Status of the function
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added above this CLUE node in the tree, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the CLUE node need to be reset to their defaults so
/// that this CLUE node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function
Status MakeSimpleProducer() override;
private:
/// \brief Split string based on a character delimiter
/// \return A string vector

View File

@ -122,7 +122,6 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(op);

View File

@ -95,12 +95,6 @@ Status CSVNode::ValidateParams() {
Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// CSVOp by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
// Sort the dataset files in a lexicographical order
std::vector<std::string> sorted_dataset_files = dataset_files_;
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
@ -119,10 +113,9 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
}
}
std::shared_ptr<CsvOp> csv_op =
std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(csv_op->Init());
@ -140,7 +133,6 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(shuffle_op);
}
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(csv_op);
@ -188,5 +180,27 @@ Status CSVNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// A CSV dataset is non-mappable and does not sample by itself. When a cache operator sits somewhere above
// this leaf in the tree, the cache can take over the sampler selected here and so provide the sampling.
// This function selects that sampler from the node's own settings and writes it to *sampler.
Status CSVNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  // Files are shuffled for both the global and the files shuffle mode.
  bool shuffle_files = false;
  if (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles) {
    shuffle_files = true;
  }

  *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
  return Status::OK();
}
// When a cache has been placed above this CSV node in the tree, the cache runs the sampler that fetches the data.
// The sampling-related options of this node are therefore reset to their defaults, so that the node produces
// the full set of data into the cache.
Status CSVNode::MakeSimpleProducer() {
  // No row limit and no shuffling
  num_samples_ = 0;
  shuffle_ = ShuffleMode::kFalse;
  // A single shard, and this node is shard 0
  num_shards_ = 1;
  shard_id_ = 0;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -107,6 +107,22 @@ class CSVNode : public NonMappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief CSV by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
/// That is why we set up the sampler for a leaf node that does not use sampling.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[out] sampler Receives the sampler that is set up; the result is written through this pointer
/// \return Status of the function
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added above this CSV node in the tree, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the CSV node need to be reset to their defaults so
/// that this CSV node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function
Status MakeSimpleProducer() override;
private:
std::vector<std::string> dataset_files_;
char field_delim_;

View File

@ -95,10 +95,10 @@ class GeneratorNode : public MappableSourceNode {
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return nullptr; }
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override {}
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
py::function generator_function_;

View File

@ -70,8 +70,6 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->SamplerBuild())));

View File

@ -94,7 +94,6 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
manifest_op =
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_);
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(manifest_op);

View File

@ -23,8 +23,9 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
@ -203,5 +204,16 @@ Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
return Status::OK();
}
// Visitor accepting method for IRNodePass: forwards this node to the pass's Visit()
Status MindDataNode::Accept(IRNodePass *const p, bool *const modified) {
  // Downcast the shared pointer to MindDataNode, then call the visitor with it
  auto self = shared_from_base<MindDataNode>();
  return p->Visit(self, modified);
}
// Visitor accepting method for IRNodePass: forwards this node to the pass's VisitAfter()
Status MindDataNode::AcceptAfter(IRNodePass *p, bool *const modified) {
  // Downcast the shared pointer to MindDataNode, then call the visitor with it
  auto self = shared_from_base<MindDataNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -92,6 +92,18 @@ class MindDataNode : public MappableSourceNode {
/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
/// \brief Base-class override for accepting IRNodePass visitor; calls Visit() of the pass with this node
/// \param[in] p The pass (visitor) that visits this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor; calls VisitAfter() of the pass with this node
/// \param[in] p The pass (visitor) that visits this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *const modified) override;
private:
std::string dataset_file_; // search_for_pattern_ will be true in this mode
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode

View File

@ -57,7 +57,6 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema),

View File

@ -22,6 +22,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
@ -105,17 +106,9 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
}
}
// RandomOp by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
// RandomOp doesn't support sampler, should not support sharding, select sampler should just be sequential.
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(total_rows_, false, 1, 0);
std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema_), std::move(sampler_->SamplerBuild()));
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
std::move(data_schema_));
node_ops->push_back(op);
@ -142,5 +135,27 @@ Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
dataset_size_ = *dataset_size;
return Status::OK();
}
// RandomDataset is a non-mappable dataset and does not itself support sampling.
// A cache operator injected higher in the tree can nevertheless inherit a sampler from this leaf,
// which provides sampling support from the caching layer. That is why a sampler is set up here
// for a leaf node that does not use sampling.
Status RandomNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  // RandomOp doesn't support sampler and should not support sharding, so the selected sampler
  // should just be sequential: no shuffle, one shard, shard id 0.
  constexpr bool kShuffle = false;
  constexpr int kNumShards = 1;
  constexpr int kShardId = 0;
  *sampler = SelectSampler(total_rows_, kShuffle, kNumShards, kShardId);
  return Status::OK();
}
// IRNodePass visitor entry point (pre-visit) for RandomNode
Status RandomNode::Accept(IRNodePass *p, bool *const modified) {
  // Obtain a shared pointer typed as the concrete node so the pass sees a RandomNode
  auto self = shared_from_base<RandomNode>();
  return p->Visit(self, modified);
}
// IRNodePass visitor entry point (post-visit) for RandomNode
Status RandomNode::AcceptAfter(IRNodePass *p, bool *const modified) {
  // Obtain a shared pointer typed as the concrete node so the pass sees a RandomNode
  auto self = shared_from_base<RandomNode>();
  return p->VisitAfter(self, modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -99,6 +99,30 @@ class RandomNode : public NonMappableSourceNode {
const std::mt19937 &RandGen() const { return rand_gen_; }
const std::unique_ptr<DataSchema> &GetDataSchema() const { return data_schema_; }
/// \brief RandomDataset by itself is a non-mappable dataset that does not support sampling.
///     However, if a cache operator is injected at some other place higher in the tree, that cache can
///     inherit this sampler from the leaf, providing sampling support from the caching layer.
///     That is why we set up the sampler for a leaf node that does not use sampling.
/// \param[out] sampler Receives the sampler selected for the cache (no shuffle, a single shard)
/// \return Status of the function
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief Random node will always produce the full set of data into the cache, so there is
///     nothing to reset here; this override is a no-op.
/// \return Status of the function (always OK)
Status MakeSimpleProducer() override { return Status::OK(); }
/// \brief Base-class override for accepting IRNodePass visitor. Forwards this node, typed as RandomNode,
///     to the pass's Visit method.
/// \param[in] p The pass that visits this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor. Forwards this node, typed as RandomNode,
///     to the pass's VisitAfter method.
/// \param[in] p The pass that visits this node
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *p, bool *const modified) override;
private:
/// \brief A quick inline for producing a random number between (and including) min/max
/// \param[in] min minimum number that can be generated.

View File

@ -73,12 +73,6 @@ Status TextFileNode::ValidateParams() {
Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// TextFileOp by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
// Sort the dataset files in a lexicographical order
std::vector<std::string> sorted_dataset_files = dataset_files_;
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
@ -87,10 +81,10 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
// Create and initalize TextFileOp
// Create and initialize TextFileOp
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(text_file_op->Init());
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
@ -106,7 +100,6 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
rows_per_buffer_, &shuffle_op));
node_ops->push_back(shuffle_op);
}
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
// Add TextFileOp
node_ops->push_back(text_file_op);
@ -152,5 +145,27 @@ Status TextFileNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// TextFile is a non-mappable dataset and does not itself support sampling.
// A cache operator injected higher in the tree can nevertheless inherit a sampler from this leaf,
// which provides sampling support from the caching layer. That is why a sampler is set up here
// for a leaf node that does not use sampling.
Status TextFileNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  // The shuffle flag handed to SelectSampler is set for both the global and the files shuffle modes
  bool wants_shuffle = false;
  if (shuffle_ == ShuffleMode::kGlobal) {
    wants_shuffle = true;
  } else if (shuffle_ == ShuffleMode::kFiles) {
    wants_shuffle = true;
  }
  *sampler = SelectSampler(num_samples_, wants_shuffle, num_shards_, shard_id_);
  return Status::OK();
}
// When a cache sits somewhere above this TextFile node, the cache runs a sampler to fetch the data.
// The TextFile node's own options are therefore reset to their defaults so that it produces the
// full set of data into the cache.
Status TextFileNode::MakeSimpleProducer() {
  // Reset the sample count and the shuffle mode
  num_samples_ = 0;
  shuffle_ = ShuffleMode::kFalse;
  // Reset sharding to a single shard with id 0
  num_shards_ = 1;
  shard_id_ = 0;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -83,6 +83,22 @@ class TextFileNode : public NonMappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief TextFile by itself is a non-mappable dataset that does not support sampling.
///     However, if a cache operator is injected at some other place higher in the tree, that cache can
///     inherit this sampler from the leaf, providing sampling support from the caching layer.
///     That is why we set up the sampler for a leaf node that does not use sampling.
///     Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[out] sampler Receives the sampler selected from this node's num_samples, shuffle mode,
///     num_shards and shard_id settings
/// \return Status of the function
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this TextFile node, then the cache will be executing
///     a sampler for fetching the data. As such, any options in the TextFile node need to be reset to their defaults
///     so that this TextFile node will produce the full set of data into the cache.
///     The options reset are shard_id, num_shards, shuffle and num_samples.
///     Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function
Status MakeSimpleProducer() override;
private:
std::vector<std::string> dataset_files_;
int32_t num_samples_;

View File

@ -121,17 +121,10 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// TFReaderOp by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
// Create and initialize TFReaderOp
std::shared_ptr<TFReaderOp> tf_reader_op =
std::make_shared<TFReaderOp>(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files,
std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
shard_id_, shard_equal_rows_, std::move(sampler_->SamplerBuild()));
std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>(
num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema),
connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_);
RETURN_IF_NOT_OK(tf_reader_op->Init());
@ -149,7 +142,6 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
rows_per_buffer_, &shuffle_op));
node_ops->push_back(shuffle_op);
}
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
// Add TFReaderOp
node_ops->push_back(tf_reader_op);
@ -227,5 +219,29 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// TFRecord is a non-mappable dataset and does not itself support sampling.
// A cache operator injected higher in the tree can nevertheless inherit a sampler from this leaf,
// which provides sampling support from the caching layer. That is why a sampler is set up here
// for a leaf node that does not use sampling.
Status TFRecordNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
  // The shuffle flag handed to SelectSampler is set for both the global and the files shuffle modes
  bool wants_shuffle = false;
  if (shuffle_ == ShuffleMode::kGlobal) {
    wants_shuffle = true;
  } else if (shuffle_ == ShuffleMode::kFiles) {
    wants_shuffle = true;
  }
  *sampler = SelectSampler(num_samples_, wants_shuffle, num_shards_, shard_id_);
  return Status::OK();
}
// When a cache sits somewhere above this TFRecord node, the cache runs a sampler to fetch the data.
// The TFRecord node's own options are therefore reset to their defaults so that it produces the
// full set of data into the cache.
Status TFRecordNode::MakeSimpleProducer() {
  // Reset the sample count and the shuffle mode
  num_samples_ = 0;
  shuffle_ = ShuffleMode::kFalse;
  // Reset sharding to a single shard with id 0 and turn off equal-rows sharding
  num_shards_ = 1;
  shard_id_ = 0;
  shard_equal_rows_ = false;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -124,6 +124,22 @@ class TFRecordNode : public NonMappableSourceNode {
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
/// \brief TFRecord by itself is a non-mappable dataset that does not support sampling.
///     However, if a cache operator is injected at some other place higher in the tree, that cache can
///     inherit this sampler from the leaf, providing sampling support from the caching layer.
///     That is why we set up the sampler for a leaf node that does not use sampling.
///     Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[out] sampler Receives the sampler selected from this node's num_samples, shuffle mode,
///     num_shards and shard_id settings
/// \return Status of the function
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this TFRecord node, then the cache will be executing
///     a sampler for fetching the data. As such, any options in the TFRecord node need to be reset to their defaults
///     so that this TFRecord node will produce the full set of data into the cache.
///     The options reset are shard_id, num_shards, shuffle, num_samples and shard_equal_rows.
///     Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function
Status MakeSimpleProducer() override;
private:
std::vector<std::string> dataset_files_;
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string

View File

@ -113,7 +113,6 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
voc_op =
std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
node_ops->push_back(voc_op);
return Status::OK();

View File

@ -31,8 +31,14 @@
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
@ -195,10 +201,10 @@ Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modi
}
#ifdef ENABLE_PYTHON
Status IRNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
return Visit(std::static_pointer_cast<MappableSourceNode>(node), modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
return VisitAfter(std::static_pointer_cast<MappableSourceNode>(node), modified);
}
#endif
Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
@ -207,12 +213,26 @@ Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *const modified) {
  // Default handling: delegate to the generic DatasetNode overload
  std::shared_ptr<DatasetNode> as_base = std::static_pointer_cast<DatasetNode>(node);
  return VisitAfter(as_base, modified);
}
#ifndef ENABLE_ANDROID
Status IRNodePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
  // Default handling: delegate to the MappableSourceNode overload
  std::shared_ptr<MappableSourceNode> as_base = std::static_pointer_cast<MappableSourceNode>(node);
  return Visit(as_base, modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified) {
  // Default handling: delegate to the MappableSourceNode overload
  std::shared_ptr<MappableSourceNode> as_base = std::static_pointer_cast<MappableSourceNode>(node);
  return VisitAfter(as_base, modified);
}
#endif
Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *const modified) {
  // Default handling: delegate to the generic DatasetNode overload
  std::shared_ptr<DatasetNode> as_base = std::static_pointer_cast<DatasetNode>(node);
  return Visit(as_base, modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified) {
  // Default handling: delegate to the generic DatasetNode overload
  std::shared_ptr<DatasetNode> as_base = std::static_pointer_cast<DatasetNode>(node);
  return VisitAfter(as_base, modified);
}
Status IRNodePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
  // Default handling: delegate to the NonMappableSourceNode overload
  std::shared_ptr<NonMappableSourceNode> as_base = std::static_pointer_cast<NonMappableSourceNode>(node);
  return Visit(as_base, modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified) {
  // Default handling: delegate to the NonMappableSourceNode overload
  std::shared_ptr<NonMappableSourceNode> as_base = std::static_pointer_cast<NonMappableSourceNode>(node);
  return VisitAfter(as_base, modified);
}
Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *const modified) {
  // Default handling: delegate to the generic DatasetNode overload
  std::shared_ptr<DatasetNode> as_base = std::static_pointer_cast<DatasetNode>(node);
  return Visit(as_base, modified);
}

View File

@ -44,7 +44,6 @@ class TakeNode;
class TransferNode;
class ZipNode;
#ifdef ENABLE_PYTHON
class GeneratorNode;
class SyncWaitNode;
#endif
#ifndef ENABLE_ANDROID
@ -129,14 +128,14 @@ class IRPass : public std::enable_shared_from_this<IRPass> {
class IRTreePass : public IRPass {
public:
/// \brief Run the transformation pass against the IR tree.
/// \param[inout] root_ir Pointer to the IR tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified
/// \param[in,out] root_ir Pointer to the IR tree to be transformed.
/// \param[in,out] modified Indicate if the tree was modified
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final;
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate if the tree was modified.
/// \param[in,out] tree The tree to operate on.
/// \param[in,out] modified Indicate if the tree was modified.
/// \return Status The status code returned
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) { return Status::OK(); }
};
@ -164,8 +163,8 @@ class IRNodePass : public IRPass {
~IRNodePass() = default;
/// \brief Run the transformation pass against the IR tree
/// \param[inout] root_ir Pointer to the IR tree to be transformed
/// \param[inout] modified Indicator if the tree was changed
/// \param[in,out] root_ir Pointer to the IR tree to be transformed
/// \param[in,out] modified Indicator if the tree was changed
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final;
/// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down
@ -210,8 +209,14 @@ class IRNodePass : public IRPass {
#endif
virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified);
#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified);
#endif
virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<RandomNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<RenameNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified);
@ -270,14 +275,14 @@ class Pass : public std::enable_shared_from_this<Pass> {
class TreePass : public Pass {
public:
/// \brief Run the transformation pass against the execution tree.
/// \param[inout] tree Pointer to the execution tree to be transformed.
/// \param[inout] modified Indicate if the tree was modified
/// \param[in,out] tree Pointer to the execution tree to be transformed.
/// \param[in,out] modified Indicate if the tree was modified
Status Run(ExecutionTree *tree, bool *const modified) final;
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \param[in,out] tree The tree to operate on.
/// \param[in,out] modified Indicate if the tree was modified.
/// \return Status The status code returned
virtual Status RunOnTree(ExecutionTree *tree, bool *const modified) { return Status::OK(); }
};
@ -305,8 +310,8 @@ class NodePass : public Pass {
~NodePass() = default;
/// \brief Run the transformation pass against the execution tree
/// \param[inout] tree Pointer to the execution tree to be transformed
/// \param[inout] modified Indicator if the tree was changed
/// \param[in,out] tree Pointer to the execution tree to be transformed
/// \param[in,out] modified Indicator if the tree was changed
Status Run(ExecutionTree *tree, bool *const modified) final;
/// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down

View File

@ -16,207 +16,130 @@
#include <vector>
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#endif
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#endif
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {}
CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_node_(nullptr), sampler_(nullptr) {}
// Identifies the subtree below this node as a cached descendant tree.
Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
// Note that this function will only get called on non-leaf nodes.
// For leaf nodes, the other Visit with NonMappableSourceNode or MappableSourceNode argument will be called instead.
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
*modified = false;
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
if (is_caching_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Nested cache operations is not supported!");
if (node->IsCached()) {
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
}
is_caching_ = true;
return Status::OK();
}
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// Resets the tracking of the cache within the tree and assigns the nodes that will be involved in a cache
// transformation
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
Status CacheTransformPass::CachePass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
*modified = false;
is_caching_ = false; // We a no longer in a cache subtree. clear the flag.
if (leaf_op_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
AddMappableCacheOperators(std::move(leaf_op_), node);
} else {
// If there was no leaf_op set, then this is a non-mappable scenario.
if (sampler_) {
// Grab the sampler that was saved from the leaf and plug it into the cache op
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
if (node->IsCached()) {
is_caching_ = false; // We are no longer in a cache subtree. Clear the flag.
if (leaf_node_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf node into the transform pass, using move to null our copy of it,
// and also assign the cached node, using base class pointers.
// In the cases where cache is directly injected after the leaf node, these two nodes might be the same.
cache_pairs_.push_back(std::make_pair(std::move(leaf_node_), node));
} else {
// We're a cache op but no sampler was saved from leaf, so create a default sampler
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
// If there was no leaf_node_ set, then this is a non-mappable scenario.
// We only assign the cached node in this case.
cached_nodes_.push_back(node);
}
// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
}
return Status::OK();
}
// Common code for mappable leaf setup.
Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"There is currently no support for multiple leaf nodes under cache.");
}
// If we are a leaf in the caching path, then save this leaf.
if (is_caching_) {
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
leaf_op_ = std::move(leaf_op);
}
return Status::OK();
}
// Common code for non mappable leaf setup.
Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"There is currently no support for multiple leaf nodes under cache.");
}
// Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf
// as save it for use by cache op in ascendant tree.
if (is_caching_) {
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
} else {
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
std::shared_ptr<SamplerRT> sampler_from_leaf;
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
}
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) {
if (is_caching_) {
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
if (node->IsCached()) {
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *const modified) {
// Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
// by the other Visit() with DatasetNode argument
if (is_caching_) {
// If we are a ClueOp in a caching tree, then change our config so that it becomes a basic
// ClueOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (leaf_node_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"There is currently no support for multiple leaf nodes under cache.");
}
// Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
// Note that a sampler for a non-mappable dataset only works if there is a downstream cache.
RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
// If we are a non-mappable source node in a caching tree, then change our config so that it becomes a basic
// source node that parses all files. Selection of data will come from the sampler on the cache instead.
RETURN_IF_NOT_OK(node->MakeSimpleProducer());
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *const modified) {
if (is_caching_) {
// If we are a CsvOp in a caching tree, then change our config so that it becomes a basic
// CsvOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *const modified) {
if (is_caching_) {
// If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic
// TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
return Status::OK();
}
#endif
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *const modified) {
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
if (node->IsCached()) {
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
}
// Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
// by the other Visit() with DatasetNode argument
if (is_caching_) {
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (leaf_node_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"There is currently no support for multiple leaf nodes under cache.");
}
// Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
// Note that a sampler for a non-mappable dataset only works if there is a downstream cache.
RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
}
return Status::OK();
}
// Leaf node cache transform identification for an ImageFolderOp.
// The op is forwarded, through its DatasetOp base pointer, to the shared mappable leaf setup.
// `modified` is not written here.
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) {
  std::shared_ptr<DatasetOp> image_folder_leaf = std::static_pointer_cast<DatasetOp>(node);
  return MappableCacheLeafSetup(image_folder_leaf);
}
// Leaf node cache transform identification for an AlbumOp.
// The op is forwarded, through its DatasetOp base pointer, to the shared mappable leaf setup.
// `modified` is not written here.
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) {
  std::shared_ptr<DatasetOp> album_leaf = std::static_pointer_cast<DatasetOp>(node);
  return MappableCacheLeafSetup(album_leaf);
}
// Leaf node cache transform identification for a MnistOp.
// The op is forwarded, through its DatasetOp base pointer, to the shared mappable leaf setup.
// `modified` is not written here.
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) {
  std::shared_ptr<DatasetOp> mnist_leaf = std::static_pointer_cast<DatasetOp>(node);
  return MappableCacheLeafSetup(mnist_leaf);
}
// Leaf node cache transform identification for a CifarOp.
// The op is forwarded, through its DatasetOp base pointer, to the shared mappable leaf setup.
// `modified` is not written here.
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) {
  std::shared_ptr<DatasetOp> cifar_leaf = std::static_pointer_cast<DatasetOp>(node);
  return MappableCacheLeafSetup(cifar_leaf);
}
// Leaf node cache transform identification for a CocoOp.
// The op is forwarded, through its DatasetOp base pointer, to the shared mappable leaf setup.
// `modified` is not written here.
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) {
  std::shared_ptr<DatasetOp> coco_leaf = std::static_pointer_cast<DatasetOp>(node);
  return MappableCacheLeafSetup(coco_leaf);
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
if (node->IsCached()) {
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
}
// Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
// by the other Visit() with DatasetNode argument
if (is_caching_) {
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (leaf_node_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"There is currently no support for multiple leaf nodes under cache.");
}
// If we are a leaf in the caching path, then save this leaf
leaf_node_ = node;
}
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) {
if (is_caching_) {
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
if (node->IsCached() || is_caching_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"There is currently no support for MindRecordOp under cache.");
}
@ -226,102 +149,85 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> no
#ifdef ENABLE_PYTHON
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
if (is_caching_) {
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
if (node->IsCached() || is_caching_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"There is currently no support for GeneratorOp under cache.");
}
return Status::OK();
}
// Leaf node cache transform identification for a ManifestOp.
// The op is forwarded, through its DatasetOp base pointer, to the shared mappable leaf setup.
// `modified` is not written here.
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) {
  std::shared_ptr<DatasetOp> manifest_leaf = std::static_pointer_cast<DatasetOp>(node);
  return MappableCacheLeafSetup(manifest_leaf);
}
// Leaf node cache transform identification for a VOCOp.
// The op is forwarded, through its DatasetOp base pointer, to the shared mappable leaf setup.
// `modified` is not written here.
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) {
  std::shared_ptr<DatasetOp> voc_leaf = std::static_pointer_cast<DatasetOp>(node);
  return MappableCacheLeafSetup(voc_leaf);
}
#endif
// Records the leaf and cache operators that are involved in a cache transformation.
// \param[in] leaf_op The leaf operator involved in the cache transform
// \param[in] cache_op The cache operator involved in the cache transform
// The pair is appended to cache_pairs_. Both arguments are received by value, so they are moved into the
// stored pair instead of being copied a second time.
void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
                                                              std::shared_ptr<CacheOp> cache_op) {
  cache_pairs_.emplace_back(std::move(leaf_op), std::move(cache_op));
}
// Constructor: performs no work of its own.
CacheTransformPass::CacheTransformPass() = default;
// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *const modified) {
Status CacheTransformPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
// use to execute a transform.
CachePass cache_pass = CachePass();
RETURN_IF_NOT_OK(cache_pass.Run(tree, modified));
RETURN_IF_NOT_OK(cache_pass.Run(root_ir, modified));
// Then, execute the transform for each pair
// Execute the transform for non-mappable cache
for (auto cached_node : cache_pass.cached_nodes()) {
MS_LOG(DEBUG) << "Cache transform pass: Injecting a non-mappable cache node.";
RETURN_IF_NOT_OK(InjectNonMappableCacheNode(cached_node, cache_pass.sampler()));
}
// Execute the transform for mappable cache
for (auto cache_pair : cache_pass.cache_pairs()) {
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
RETURN_IF_NOT_OK(
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()));
MS_LOG(DEBUG) << "Cache transform pass: Injecting a mappable cache node.";
RETURN_IF_NOT_OK(InjectMappableCacheNode(cache_pair.first, cache_pair.second));
}
MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
return Status::OK();
}
// Helper function to execute the cache transformation.
Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op,
std::shared_ptr<CacheClient> cache_client) {
// Get local pointers to the child and parent of the cache op. It's possible that the parent is null if the
// cache was the root node. It is also possible that cache_child == leaf_op.
std::shared_ptr<DatasetOp> cache_child = cache_op->child(0);
DatasetOp *cache_parent = nullptr;
cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent
// Helper function to execute mappable cache transformation.
// Input:
//   Sampler
//     |
//   LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
//
// Transformed:
//   Sampler --> CacheLookupNode ------------------------->
//     |                                                   |
//     |                                             CacheMergeNode
//     |                                                   |
//   LeafNode --> OtherNodes --> CachedNode
Status CacheTransformPass::InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,
std::shared_ptr<DatasetNode> cached_node) {
// Create a cache merge node with defaults
auto cache_merge_node = std::make_shared<CacheMergeNode>(nullptr, cached_node->GetDatasetCache());
// Insert the cache merge node to become the cached_node's parent
RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_merge_node));
// Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later.
std::shared_ptr<SamplerRT> leaf_sampler = leaf_op->sampler();
// Construct the merge op with defaults
std::shared_ptr<CacheMergeOp> merge_op;
CacheMergeOp::Builder merge_builder;
RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op));
RETURN_IF_NOT_OK(tree->AssociateNode(merge_op));
// Construct the cache lookup op with defaults
std::shared_ptr<CacheLookupOp> cache_lookup_op;
CacheLookupOp::Builder lookup_builder;
RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op));
RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op));
// Overwrite the old sampler in this leaf op to become the lookup op
leaf_op->SetSampler(cache_lookup_op);
// If the cache had a parent, then go into that parent to remove the cache from its child list and then
// replace it with the merge op.
if (cache_parent != nullptr) {
RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op));
RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op));
} else {
// If we didn't have a parent, then the merge op is the root node
RETURN_IF_NOT_OK(tree->AssignRoot(merge_op));
}
// Set the cache op to no longer be a parent over its child. This will fully disconnect the old cache op.
// We maintain a local pointer to the old child though.
RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child));
// Connect the merge op
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op)));
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child)));
// At this point, the cache op has already had its children and parents taken away. Calling remove
// on it at this point will not do any node hookups, and will instead set internal fields to invalid.
RETURN_IF_NOT_OK(cache_op->Remove());
std::shared_ptr<SamplerObj> leaf_sampler = leaf_node->Sampler();
// Create a cache lookup node with leaf_node's sampler
auto cache_lookup_node = std::make_shared<CacheLookupNode>(nullptr, leaf_sampler, cached_node->GetDatasetCache());
// Insert the cache lookup node as the first child of cache merge node
RETURN_IF_NOT_OK(cache_merge_node->InsertChildAt(0, cache_lookup_node));
// Overwrite the old sampler in this leaf node to become the cache lookup node
leaf_node->SetSampler(std::static_pointer_cast<SamplerObj>(cache_lookup_node));
return Status::OK();
}
// Helper function that performs the cache transformation for a non-mappable leaf.
// Input:
//   LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
//
// Transformed:
//                                                Sampler
//                                                   |
//   LeafNode --> OtherNodes --> CachedNode --> CacheNode
Status CacheTransformPass::InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node,
                                                      std::shared_ptr<SamplerObj> sampler) {
  // Build the new CacheNode from the sampler saved from the leaf and the DatasetCache carried by cached_node
  std::shared_ptr<CacheNode> injected_cache =
    std::make_shared<CacheNode>(nullptr, sampler, cached_node->GetDatasetCache());
  // Place the new node directly above cached_node so that it becomes cached_node's parent
  RETURN_IF_NOT_OK(cached_node->InsertAbove(injected_cache));
  return Status::OK();
}
} // namespace dataset

View File

@ -20,6 +20,8 @@
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
@ -32,11 +34,11 @@ class CacheClient;
/// \class CacheTransformPass cache_transform_pass.h
/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
/// operations
class CacheTransformPass : public TreePass {
class CacheTransformPass : public IRTreePass {
/// \class CachePass
/// \brief This is a node pass whose job is to identify and set up the nodes that will be involved in a cache
/// transformation. It works in conjunction with the CacheTransformPass.
class CachePass : public NodePass {
class CachePass : public IRNodePass {
public:
/// \brief Constructor
/// \param[in] transform_pass Raw pointer back to controlling tree pass
@ -47,138 +49,72 @@ class CacheTransformPass : public TreePass {
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override;
/// \brief Resets the tracking of the cache within the tree and assigns the operators that
/// will be involved in a cache transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override;
#ifndef ENABLE_ANDROID
/// \brief Perform leaf node cache transform identifications
/// \brief Perform non-mappable leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ClueOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CsvOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override;
#endif
/// \brief Perform leaf node cache transform identifications
/// \brief Perform non-mappable leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<RandomNode> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \brief Perform mappable leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override;
#ifdef ENABLE_PYTHON
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override;
#endif
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) override;
#ifndef ENABLE_ANDROID
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified) override;
#endif
/// \brief Getter
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs() { return cache_pairs_; }
std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs() {
return cache_pairs_;
}
/// \brief Getter
std::vector<std::shared_ptr<DatasetNode>> cached_nodes() { return cached_nodes_; }
/// \brief Getter
std::shared_ptr<SamplerObj> sampler() { return sampler_; }
private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The status code returned
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The status code returned
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
/// \param[in] leaf_op The leaf operator involved in the cache transform
/// \param[in] cache_op The cache operator involved in the cache transform
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);
bool is_caching_;
std::shared_ptr<DatasetOp> leaf_op_;
std::shared_ptr<SamplerRT> sampler_;
// The two operators that work together to establish the cache transform
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
std::shared_ptr<MappableSourceNode> leaf_node_;
std::shared_ptr<SamplerObj> sampler_;
// The two nodes that work together to establish the cache transform
std::vector<std::shared_ptr<DatasetNode>> cached_nodes_;
std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs_;
};
public:
@ -189,32 +125,46 @@ class CacheTransformPass : public TreePass {
~CacheTransformPass() = default;
/// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
/// \param[inout] tree The tree to operate on.
/// \param[inout] modified Indicates whether the tree was modified.
/// \param[in,out] root_ir The root of the IR tree to operate on.
/// \param[in,out] modified Indicates whether the tree was modified.
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *const modified) override;
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
private:
/// \brief Helper function to execute the cache transformation.
/// \brief Helper function to execute mappable cache transformation.
///
/// Input:
/// Sampler
/// |
/// LeafOp --> OtherOps --> CacheOp
/// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
///
/// Transformed:
/// Sampler --> CacheLookupOp ---------------->
/// | |
/// | MergeOp
/// | |
/// LeafOp --> OtherOps -->
/// Sampler --> CacheLookupNode ------------------------->
/// | |
/// | CacheMergeNode
/// | |
/// LeafNode --> OtherNodes --> CachedNode
///
/// \param[in] leaf_op The leaf node in the transform
/// \param[in] cache_op The cache op in the transform (will get removed)
/// \param[in] cache_client The cache client
/// \param[in] leaf_node The leaf node in the transform
/// \param[in] cached_node The node with cache attribute which is involved in the cache transform
/// \return Status The status code returned
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
Status InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,
std::shared_ptr<DatasetNode> cached_node);
/// \brief Helper function to execute non-mappable cache transformation.
///
/// Input:
/// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
///
/// Transformed:
/// Sampler
/// |
/// LeafNode --> OtherNodes --> CachedNode --> CacheNode
///
/// \param[in] cached_node The node with cache attribute which is involved in the cache transform
/// \param[in] sampler The sampler saved for non-mappable leaf nodes during the CachePass
/// \return Status The status code returned
Status InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node, std::shared_ptr<SamplerObj> sampler);
};
} // namespace dataset
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
@ -53,6 +54,7 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
actions.emplace_back(std::make_unique<NodeRemovalPass>());
actions.emplace_back(std::make_unique<EpochCtrlPass>());
if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
actions.emplace_back(std::make_unique<CacheTransformPass>());
// Vector of flags for each action
std::vector<bool> modified(actions.size(), false);
// Apply pre-pass actions

View File

@ -35,7 +35,7 @@ namespace dataset {
// Internal Sampler class forward declaration
class SamplerRT;
class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
class SamplerObj {
public:
/// \brief Constructor
SamplerObj();
@ -122,7 +122,7 @@ std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement = false, int64_
/// Function to create a Sequential Sampler.
/// \notes Samples the dataset elements sequentially, same as not having a sampler.
/// \param[in] start_index - Index to start sampling at (dafault to start at first id).
/// \param[in] start_index - Index to start sampling at (default to start at first id).
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \return Shared pointer to the current Sampler.
std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);

View File

@ -465,24 +465,21 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
rc = ccbuilder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
// A mappable dataset uses a complex interaction of a cache lookup op and a cache merge op.
// Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
// adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and
// replace it with the required tree structures for cache lookup op and cache merge op.
std::shared_ptr<CacheOp> myCacheOp;
rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
std::shared_ptr<CacheLookupOp> myLookupOp;
rc = CacheLookupOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(seq_sampler).Build(&myLookupOp);
std::shared_ptr<CacheMergeOp> myMergeOp;
rc = CacheMergeOp::Builder().SetNumWorkers(4).SetClient(myClient).Build(&myMergeOp);
std::shared_ptr<ImageFolderOp> so;
ImageFolderOp::Builder builder;
builder.SetSampler(std::move(seq_sampler))
.SetOpConnectorSize(3)
builder.SetOpConnectorSize(3)
.SetNumWorkers(3)
.SetRowsPerBuffer(2)
.SetExtensions({".jpg", ".JPEG"})
.SetRecursive(true)
.SetImageFolderDir(datasets_root_path_ + "/testPK/data");
rc = builder.Build(&so);
so->SetSampler(myLookupOp);
ASSERT_TRUE(rc.IsOk());
// RepeatOp
@ -495,7 +492,9 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
rc = myTree->AssociateNode(so);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp);
rc = myTree->AssociateNode(myLookupOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myMergeOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
@ -503,9 +502,11 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
rc = myTree->AssignRoot(myRepeatOp);
ASSERT_TRUE(rc.IsOk());
rc = myRepeatOp->AddChild(myCacheOp);
rc = myRepeatOp->AddChild(myMergeOp);
ASSERT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(so);
rc = myMergeOp->AddChild(myLookupOp);
ASSERT_TRUE(rc.IsOk());
rc = myMergeOp->AddChild(so);
ASSERT_TRUE(rc.IsOk());
rc = myTree->Prepare(1);
@ -532,119 +533,3 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
rc = myClient->DestroyCache();
ASSERT_TRUE(rc.IsOk());
}
//// Simple test with a repeated cache op over random data producer.
//// The difference in this one is that you do not add the sampler to the cache op directly.
//// Instead, the sampler is added as part of the leaf op construction. Then, the prepare
//// phase will pull this up from the leaf and into the cache.
//// It removes the sampler from the leaf op, which doesn't make sense there anyway for
//// the RandomDataOp which doesn't support sampling without a cache.
////
//// RepeatOp
//// |
//// CacheOp
//// |
//// RandomDataOp
////
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
// Builds RepeatOp -> CacheOp -> RandomDataOp, where the sampler is given to the RandomDataOp builder rather
// than to the CacheOp, then runs the tree and counts the rows that come out.
// Clear the rc of the master thread if any
(void)TaskManager::GetMasterThreadRc();
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestCacheInheritSampler";
// The cache session id is read from the environment
session_id_type env_session;
rc = GetSessionFromEnv(&env_session);
ASSERT_TRUE(rc.IsOk());
int64_t num_samples = 0;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
// Create a schema using the C APIs
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
// 2 columns. First column is an "image" 640,480,3
TensorShape c1Shape({640, 480, 3});
ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
rank, // not used
&c1Shape);
// Column 2 will just be a scalar label number
TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
testSchema->AddColumn(c1);
testSchema->AddColumn(c2);
// RandomDataOp: the leaf op; the sequential sampler is attached here, not on the CacheOp
std::shared_ptr<RandomDataOp> myRandomDataOp;
rc = RandomDataOp::Builder()
.SetRowsPerBuffer(2)
.SetNumWorkers(4)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(10)
.SetSampler(std::move(seq_sampler))
.Build(&myRandomDataOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
ASSERT_TRUE(rc.IsOk());
// CacheOp
CacheClient::Builder ccbuilder;
// Use the session id read from the environment, a cache memory size of 4, and spilling enabled
ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = ccbuilder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
std::shared_ptr<CacheOp> myCacheOp;
rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp);
ASSERT_TRUE(rc.IsOk());
// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
ASSERT_TRUE(rc.IsOk());
// Assign tree relations and root
rc = myRepeatOp->AddChild(myCacheOp);
ASSERT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(myRandomDataOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare(1);
ASSERT_TRUE(rc.IsOk());
std::cout << *myClient << std::endl;
rc = myTree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
ASSERT_TRUE(rc.IsOk());
int rowCount = 0;
// An empty row ends the loop
while (!tensorList.empty()) {
// Don't display these rows, just count them
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
ASSERT_TRUE(rc.IsOk());
rowCount++;
}
// SetTotalRows(10) on the leaf combined with numRepeats = 4 gives the expected 40 rows
ASSERT_EQ(rowCount, 40);
rc = myClient->DestroyCache();
ASSERT_TRUE(rc.IsOk());
}

View File

@ -315,6 +315,9 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_long_file_list"
HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1
HandleRcExit $? 0 0
for i in $(seq 1 3)
do
test_name="test_cache_nomap_multiple_cache${i}"

View File

@ -216,7 +216,7 @@ def test_cache_map_failure1():
|
Cache
|
ImageFolder
Coco
"""
logger.info("Test cache failure 1")
@ -227,8 +227,9 @@ def test_cache_map_failure1():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
# This DATA_DIR has 6 images in it
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
ds1 = ds1.repeat(4)
@ -302,7 +303,7 @@ def test_cache_map_failure3():
|
Batch
|
ImageFolder
Mnist
"""
logger.info("Test cache failure 3")
if "SESSION_ID" in os.environ:
@ -312,8 +313,7 @@ def test_cache_map_failure3():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
ds1 = ds1.batch(2)
resize_op = c_vision.Resize((224, 224))
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
@ -342,7 +342,7 @@ def test_cache_map_failure4():
|
Filter
|
ImageFolder
CelebA
"""
logger.info("Test cache failure 4")
@ -353,8 +353,8 @@ def test_cache_map_failure4():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
# This dataset has 4 records
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])
decode_op = c_vision.Decode()
@ -382,7 +382,7 @@ def test_cache_map_failure5():
|
Map(decode, randomCrop)
|
ImageFolder
Manifest
"""
logger.info("Test cache failure 5")
@ -393,8 +393,8 @@ def test_cache_map_failure5():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
# This dataset has 4 records
data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
decode_op = c_vision.Decode()
@ -505,7 +505,7 @@ def test_cache_map_failure8():
|
Repeat
|
ImageFolder
Cifar10
"""
logger.info("Test cache failure 8")
@ -516,8 +516,7 @@ def test_cache_map_failure8():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10)
decode_op = c_vision.Decode()
ds1 = ds1.repeat(4)
ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
@ -545,7 +544,7 @@ def test_cache_map_failure9():
|
Take
|
ImageFolder
Cifar100
"""
logger.info("Test cache failure 9")
@ -556,8 +555,7 @@ def test_cache_map_failure9():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
ds1 = ds1.take(2)
decode_op = c_vision.Decode()
@ -587,7 +585,7 @@ def test_cache_map_failure10():
|
Skip
|
ImageFolder
VOC
"""
logger.info("Test cache failure 10")
@ -598,8 +596,8 @@ def test_cache_map_failure10():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
# This dataset has 9 records
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
ds1 = ds1.skip(1)
decode_op = c_vision.Decode()

View File

@ -1913,6 +1913,217 @@ def test_cache_nomap_long_file_list():
logger.info("test_cache_nomap_long_file_list Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure1():
    """
    Test nested cache (failure)

        Repeat
          |
        Cache
          |
      Map(decode)
          |
        Cache
          |
      TFRecord

    The same cache object is attached to both the TFRecord leaf and the Map
    above it. The test expects a RuntimeError mentioning "Nested cache
    operations" from get_batch_size() and again from iteration, and expects
    that no row is ever produced.
    """
    logger.info("Test cache nomap failure 1")
    # The cache session is created outside the test and passed in via the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    nested_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    pipeline = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=nested_cache)
    decode = c_vision.Decode()
    pipeline = pipeline.map(operations=decode, input_columns=["image"], cache=nested_cache)
    pipeline = pipeline.repeat(4)

    # The nested cache must be rejected when only querying the batch size.
    with pytest.raises(RuntimeError) as err:
        pipeline.get_batch_size()
    assert "Nested cache operations" in str(err.value)

    # It must be rejected again when the pipeline is actually iterated.
    rows_seen = 0
    with pytest.raises(RuntimeError) as err:
        for _ in pipeline.create_dict_iterator(num_epochs=1):
            rows_seen += 1
    assert "Nested cache operations" in str(err.value)

    assert rows_seen == 0
    logger.info('test_cache_nomap_failure1 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure2():
    """
    Test zip under cache (failure)

           repeat
             |
           Cache
             |
         Map(decode)
             |
            Zip
           |    |
      Random    Random

    Two Random datasets are zipped and the Map above the Zip carries the cache.
    The test expects iteration to raise a RuntimeError stating that ZipNode is
    not supported under a cache, with no row produced.
    """
    logger.info("Test cache nomap failure 2")
    # The cache session is created outside the test and passed in via the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    map_cache = ds.DatasetCache(session_id=session_id, size=0)

    random_schema = ds.Schema()
    random_schema.add_column('image', de_type=mstype.uint8,
                             shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    random_schema.add_column('label', de_type=mstype.uint8, shape=[1])

    left = ds.RandomDataset(schema=random_schema)
    right = ds.RandomDataset(schema=random_schema)
    zipped = ds.zip((left, right))
    decode = c_vision.Decode()
    zipped = zipped.map(input_columns=["image"], operations=decode, cache=map_cache)
    zipped = zipped.repeat(4)

    rows_seen = 0
    with pytest.raises(RuntimeError) as err:
        for _ in zipped.create_dict_iterator():
            rows_seen += 1
    assert "ZipNode is not supported as a descendant operator under a cache" in str(err.value)

    assert rows_seen == 0
    logger.info('test_cache_nomap_failure2 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure3():
    """
    Test batch under cache (failure)

           repeat
             |
           Cache
             |
         Map(resize)
             |
           Batch
             |
           Clue

    A Batch sits between the CLUE leaf and the cached Map. The test expects
    iteration to raise a RuntimeError stating that BatchNode is not supported
    under a cache, with no row produced.
    """
    logger.info("Test cache nomap failure 3")
    # The cache session is created outside the test and passed in via the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    map_cache = ds.DatasetCache(session_id=session_id, size=0)

    pipeline = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train')
    pipeline = pipeline.batch(2)
    resize = c_vision.Resize((224, 224))
    pipeline = pipeline.map(input_columns=["image"], operations=resize, cache=map_cache)
    pipeline = pipeline.repeat(4)

    rows_seen = 0
    with pytest.raises(RuntimeError) as err:
        for _ in pipeline.create_dict_iterator():
            rows_seen += 1
    assert "BatchNode is not supported as a descendant operator under a cache" in str(err.value)

    assert rows_seen == 0
    logger.info('test_cache_nomap_failure3 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure4():
    """
    Test filter under cache (failure)

           repeat
             |
           Cache
             |
         Map(decode)
             |
           Filter
             |
            CSV

    A Filter sits between the CSV leaf and the cached Map. The test expects
    iteration to raise a RuntimeError stating that FilterNode is not supported
    under a cache, with no row produced.
    """
    logger.info("Test cache nomap failure 4")
    # The cache session is created outside the test and passed in via the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    map_cache = ds.DatasetCache(session_id=session_id, size=0)

    pipeline = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                             column_names=['col1', 'col2', 'col3', 'col4'])
    pipeline = pipeline.filter(predicate=lambda data: data < 11, input_columns=["label"])

    decode = c_vision.Decode()
    pipeline = pipeline.map(input_columns=["image"], operations=decode, cache=map_cache)
    pipeline = pipeline.repeat(4)

    rows_seen = 0
    with pytest.raises(RuntimeError) as err:
        for _ in pipeline.create_dict_iterator():
            rows_seen += 1
    assert "FilterNode is not supported as a descendant operator under a cache" in str(err.value)

    assert rows_seen == 0
    logger.info('test_cache_nomap_failure4 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure5():
    """
    Test Map with non-deterministic TensorOps under cache (failure)

           repeat
             |
           Cache
             |
       Map(decode, randomCrop)
             |
          TextFile

    Decode and RandomCrop are applied by two consecutive map calls; only the
    RandomCrop map carries the cache. The test expects iteration to raise a
    RuntimeError stating that a MapNode with non-deterministic operations is
    not supported under a cache, with no row produced.
    """
    logger.info("Test cache nomap failure 5")
    # The cache session is created outside the test and passed in via the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    map_cache = ds.DatasetCache(session_id=session_id, size=0)

    pipeline = ds.TextFileDataset(TEXT_FILE_DATA_DIR)
    random_crop = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode = c_vision.Decode()

    pipeline = pipeline.map(input_columns=["image"], operations=decode)
    pipeline = pipeline.map(input_columns=["image"], operations=random_crop, cache=map_cache)
    pipeline = pipeline.repeat(4)

    rows_seen = 0
    with pytest.raises(RuntimeError) as err:
        for _ in pipeline.create_dict_iterator():
            rows_seen += 1
    assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(err.value)

    assert rows_seen == 0
    logger.info('test_cache_nomap_failure5 Ended.\n')
if __name__ == '__main__':
test_cache_nomap_basic1()
test_cache_nomap_basic2()