forked from mindspore-Ecosystem/mindspore
Migrate CacheTransformPass
This commit is contained in:
parent
3708624a25
commit
922e1f4f36
|
@ -237,8 +237,11 @@ Status CacheOp::Accept(NodePass *p, bool *const modified) {
|
|||
return p->RunOnNode(shared_from_base<CacheOp>(), modified);
|
||||
}
|
||||
|
||||
// A public wrapper for creating the cache through the client
|
||||
Status CacheOp::CreateCache(uint32_t cache_crc) {
|
||||
Status CacheOp::PrepareNodePostAction() {
|
||||
// Run any common code from super class first before adding our own
|
||||
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
|
||||
// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
|
||||
uint32_t cache_crc = DatasetOp::GenerateCRC(shared_from_this());
|
||||
// This is a non-mappable cache op so the id's need to be generated.
|
||||
// Construct the cache
|
||||
const bool generate_ids = true;
|
||||
|
|
|
@ -141,11 +141,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
|||
bool AllowCacheMiss() override { return false; }
|
||||
/// \brief Base-class override for the name of this operator
|
||||
std::string Name() const override { return kCacheOp; }
|
||||
/// \brief A public wrapper for creating the cache through the client
|
||||
/// \param[in] cache_crc The crc that identifies the cache
|
||||
/// \see cache_pass.cc
|
||||
/// \return Status return code
|
||||
Status CreateCache(uint32_t cache_crc);
|
||||
Status PrepareNodePostAction() override;
|
||||
|
||||
private:
|
||||
WaitPost rows_cache_done_;
|
||||
|
|
|
@ -33,11 +33,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
ClueOp::Builder::Builder()
|
||||
: builder_device_id_(0),
|
||||
builder_num_devices_(1),
|
||||
builder_num_samples_(0),
|
||||
builder_shuffle_files_(false),
|
||||
builder_sampler_(nullptr) {
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||
|
@ -74,7 +70,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
|
|||
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
|
||||
builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
|
||||
builder_device_id_, std::move(builder_sampler_));
|
||||
builder_device_id_);
|
||||
RETURN_IF_NOT_OK(clue_op->Init());
|
||||
*op = std::move(clue_op);
|
||||
|
||||
|
@ -94,8 +90,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
|
|||
|
||||
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
|
||||
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
|
||||
bool shuffle_files, int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
num_rows_per_shard_(0),
|
||||
all_num_rows_(0),
|
||||
|
@ -552,16 +548,6 @@ Status ClueOp::ComputeColMap() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
|
||||
// that this clue op will produce the full set of data into the cache.
|
||||
void ClueOp::MakeSimpleProducer() {
|
||||
device_id_ = 0;
|
||||
num_devices_ = 1;
|
||||
shuffle_files_ = false;
|
||||
num_samples_ = 0;
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status ClueOp::Accept(NodePass *p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
|
|
|
@ -122,14 +122,6 @@ class ClueOp : public ParallelOp {
|
|||
// @return - the a string vector
|
||||
std::vector<std::string> split(const std::string &s, char delim);
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
builder_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t builder_device_id_;
|
||||
int32_t builder_num_devices_;
|
||||
|
@ -141,13 +133,12 @@ class ClueOp : public ParallelOp {
|
|||
std::vector<std::string> builder_clue_files_list_;
|
||||
bool builder_shuffle_files_;
|
||||
std::map<std::string, std::string> builder_cols_to_keyword_;
|
||||
std::shared_ptr<SamplerRT> builder_sampler_;
|
||||
};
|
||||
|
||||
// Constructor of ClueOp
|
||||
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id);
|
||||
|
||||
// Default destructor
|
||||
~ClueOp() = default;
|
||||
|
@ -182,11 +173,6 @@ class ClueOp : public ParallelOp {
|
|||
// @return Vector of the input file names
|
||||
std::vector<std::string> FileNames() { return clue_files_list_; }
|
||||
|
||||
/// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
|
||||
/// that this clue op will produce the full set of data into the cache.
|
||||
void MakeSimpleProducer();
|
||||
|
||||
// Op name getter
|
||||
// @return Name of the current Op
|
||||
std::string Name() const override { return "ClueOp"; }
|
||||
|
|
|
@ -29,11 +29,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CsvOp::Builder::Builder()
|
||||
: builder_device_id_(0),
|
||||
builder_num_devices_(1),
|
||||
builder_num_samples_(0),
|
||||
builder_shuffle_files_(false),
|
||||
builder_sampler_(nullptr) {
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||
|
@ -65,8 +61,7 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
|
|||
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
|
||||
builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
|
||||
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_,
|
||||
std::move(builder_sampler_));
|
||||
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
|
||||
RETURN_IF_NOT_OK(csv_op->Init());
|
||||
*op = std::move(csv_op);
|
||||
|
||||
|
@ -77,8 +72,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
|
|||
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
|
||||
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
|
||||
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
|
||||
int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
|
||||
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
|
||||
int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
csv_files_list_(std::move(csv_files_list)),
|
||||
field_delim_(field_delim),
|
||||
column_default_list_(column_default),
|
||||
|
@ -920,16 +915,6 @@ Status CsvOp::ComputeColMap() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
|
||||
// that this csv op will produce the full set of data into the cache.
|
||||
void CsvOp::MakeSimpleProducer() {
|
||||
device_id_ = 0;
|
||||
num_devices_ = 1;
|
||||
shuffle_files_ = false;
|
||||
num_samples_ = 0;
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status CsvOp::Accept(NodePass *p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
|
|
|
@ -241,14 +241,6 @@ class CsvOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
builder_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t builder_device_id_;
|
||||
int32_t builder_num_devices_;
|
||||
|
@ -262,7 +254,6 @@ class CsvOp : public ParallelOp {
|
|||
char builder_field_delim_;
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
|
||||
std::vector<std::string> builder_column_name_list_;
|
||||
std::shared_ptr<SamplerRT> builder_sampler_;
|
||||
};
|
||||
|
||||
// Constructor of CsvOp
|
||||
|
@ -271,8 +262,7 @@ class CsvOp : public ParallelOp {
|
|||
CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
|
||||
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
|
||||
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
|
||||
std::shared_ptr<SamplerRT> sampler);
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
|
||||
|
||||
// Default destructor
|
||||
~CsvOp() = default;
|
||||
|
@ -308,11 +298,6 @@ class CsvOp : public ParallelOp {
|
|||
// @return Vector of the input file names
|
||||
std::vector<std::string> FileNames() { return csv_files_list_; }
|
||||
|
||||
/// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
|
||||
/// that this csv op will produce the full set of data into the cache.
|
||||
void MakeSimpleProducer();
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
|
|
|
@ -34,8 +34,7 @@ RandomDataOp::Builder::Builder()
|
|||
builder_num_workers_(0),
|
||||
builder_op_connector_size_(0),
|
||||
builder_rows_per_buffer_(0),
|
||||
builder_total_rows_(0),
|
||||
builder_sampler_(nullptr) {
|
||||
builder_total_rows_(0) {
|
||||
// Some arguments to the RandomDataOp have a default argument that is taken from the config.
|
||||
// The user may override these defaults by using the builder set methods.
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
|
@ -48,9 +47,8 @@ RandomDataOp::Builder::Builder()
|
|||
Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
|
||||
*out_op =
|
||||
std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
|
||||
builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_));
|
||||
*out_op = std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
|
||||
builder_total_rows_, std::move(builder_data_schema_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -65,8 +63,8 @@ Status RandomDataOp::Builder::SanityCheck() const {
|
|||
|
||||
// Constructor for RandomDataOp
|
||||
RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
|
||||
std::unique_ptr<DataSchema> data_schema)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
buffer_id_(0),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
total_rows_(total_rows),
|
||||
|
@ -80,8 +78,7 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64
|
|||
if (total_rows_ == 0) {
|
||||
total_rows_ = GenRandomInt(1, kMaxTotalRows);
|
||||
}
|
||||
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random
|
||||
// schema.
|
||||
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random schema.
|
||||
// See details of generateSchema function to learn what type of schema it will create.
|
||||
if (data_schema_ == nullptr) {
|
||||
GenerateSchema();
|
||||
|
|
|
@ -117,14 +117,6 @@ class RandomDataOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
builder_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* Check if the required parameters are set by the builder.
|
||||
|
@ -133,7 +125,6 @@ class RandomDataOp : public ParallelOp {
|
|||
Status SanityCheck() const;
|
||||
|
||||
std::unique_ptr<DataSchema> builder_data_schema_;
|
||||
std::shared_ptr<SamplerRT> builder_sampler_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
int64_t builder_rows_per_buffer_;
|
||||
|
@ -148,11 +139,10 @@ class RandomDataOp : public ParallelOp {
|
|||
* @param rows_per_buffer - The number of rows in each DataBuffer
|
||||
* @param data_schema - A user-provided schema
|
||||
* @param total_rows - The total number of rows in the dataset
|
||||
* @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
|
||||
* @return Builder - The modified builder by reference
|
||||
*/
|
||||
RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
|
||||
std::unique_ptr<DataSchema> data_schema);
|
||||
|
||||
/**
|
||||
* Destructor
|
||||
|
|
|
@ -34,11 +34,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
TextFileOp::Builder::Builder()
|
||||
: builder_device_id_(0),
|
||||
builder_num_devices_(1),
|
||||
builder_total_rows_(0),
|
||||
builder_shuffle_files_(false),
|
||||
builder_sampler_(nullptr) {
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||
|
@ -74,7 +70,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
|
|||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
|
||||
std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
|
||||
builder_num_devices_, builder_device_id_, std::move(builder_sampler_));
|
||||
builder_num_devices_, builder_device_id_);
|
||||
RETURN_IF_NOT_OK(text_file_op->Init());
|
||||
*op = std::move(text_file_op);
|
||||
|
||||
|
@ -83,9 +79,8 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
|
|||
|
||||
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
|
||||
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id,
|
||||
std::shared_ptr<SamplerRT> sampler)
|
||||
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
device_id_(device_id),
|
||||
num_devices_(num_device),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
|
@ -504,16 +499,6 @@ Status TextFileOp::ComputeColMap() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
|
||||
// that this text file op will produce the full set of data into the cache.
|
||||
void TextFileOp::MakeSimpleProducer() {
|
||||
device_id_ = 0;
|
||||
num_devices_ = 1;
|
||||
shuffle_files_ = false;
|
||||
total_rows_ = 0;
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status TextFileOp::Accept(NodePass *p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
|
|
|
@ -112,14 +112,6 @@ class TextFileOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
builder_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t builder_device_id_;
|
||||
int32_t builder_num_devices_;
|
||||
|
@ -131,7 +123,6 @@ class TextFileOp : public ParallelOp {
|
|||
std::vector<std::string> builder_text_files_list_;
|
||||
bool builder_shuffle_files_;
|
||||
std::unique_ptr<DataSchema> builder_schema_;
|
||||
std::shared_ptr<SamplerRT> builder_sampler_;
|
||||
};
|
||||
|
||||
// Constructor of TextFileOp
|
||||
|
@ -145,10 +136,9 @@ class TextFileOp : public ParallelOp {
|
|||
// @param columns_to_load - the names of the columns to load data from.
|
||||
// @param shuffle_files - whether or not to shuffle the files before reading data.
|
||||
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
|
||||
// @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
|
||||
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
|
||||
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id);
|
||||
|
||||
// Default destructor
|
||||
~TextFileOp() = default;
|
||||
|
@ -187,11 +177,6 @@ class TextFileOp : public ParallelOp {
|
|||
// @return Vector of the input file names
|
||||
std::vector<std::string> FileNames() { return text_files_list_; }
|
||||
|
||||
/// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
|
||||
/// that this text file op will produce the full set of data into the cache.
|
||||
void MakeSimpleProducer();
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
|
|
|
@ -44,11 +44,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
TFReaderOp::Builder::Builder()
|
||||
: builder_device_id_(0),
|
||||
builder_num_devices_(1),
|
||||
builder_total_rows_(0),
|
||||
builder_equal_rows_per_shard_(false),
|
||||
builder_sampler_(nullptr) {
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_worker_connector_size_ = config_manager->worker_connector_size();
|
||||
|
@ -122,8 +118,7 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
|
|||
std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>(
|
||||
builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_,
|
||||
builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_,
|
||||
builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_,
|
||||
std::move(builder_sampler_));
|
||||
builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_);
|
||||
|
||||
RETURN_IF_NOT_OK(new_tf_reader_op->Init());
|
||||
*out_tf_reader_op = std::move(new_tf_reader_op);
|
||||
|
@ -134,8 +129,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
|
|||
int64_t total_num_rows, std::vector<std::string> dataset_files_list,
|
||||
std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
|
||||
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
|
||||
int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler)
|
||||
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
|
||||
int32_t device_id, bool equal_rows_per_shard)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
device_id_(device_id),
|
||||
num_devices_(num_device),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
|
@ -1043,17 +1038,6 @@ Status TFReaderOp::ComputeColMap() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
|
||||
// that this tf reader will produce the full set of data into the cache.
|
||||
void TFReaderOp::MakeSimpleProducer() {
|
||||
device_id_ = 0;
|
||||
num_devices_ = 1;
|
||||
total_rows_ = 0;
|
||||
shuffle_files_ = false;
|
||||
equal_rows_per_shard_ = false;
|
||||
}
|
||||
|
||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||
// their role.
|
||||
Status TFReaderOp::PrepareNodePostAction() {
|
||||
|
|
|
@ -153,17 +153,8 @@ class TFReaderOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
builder_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<DataSchema> builder_data_schema_;
|
||||
std::shared_ptr<SamplerRT> builder_sampler_;
|
||||
int32_t builder_device_id_;
|
||||
int32_t builder_num_devices_;
|
||||
int32_t builder_num_workers_;
|
||||
|
@ -189,11 +180,10 @@ class TFReaderOp : public ParallelOp {
|
|||
// @param columns_to_load - the names of the columns to load data from.
|
||||
// @param shuffle_files - whether or not to shuffle the files before reading data.
|
||||
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
|
||||
// @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
|
||||
TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
|
||||
std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
|
||||
int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
|
||||
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler);
|
||||
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
|
||||
|
||||
// Default destructor
|
||||
~TFReaderOp() = default;
|
||||
|
@ -246,11 +236,6 @@ class TFReaderOp : public ParallelOp {
|
|||
// @return Vector of the input file names
|
||||
std::vector<std::string> FileNames() { return dataset_files_list_; }
|
||||
|
||||
/// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
|
||||
/// that this tf reader will produce the full set of data into the cache.
|
||||
void MakeSimpleProducer();
|
||||
|
||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||
// their role.
|
||||
// @notes Derived versions of this function should always call it's superclass version first
|
||||
|
@ -387,7 +372,7 @@ class TFReaderOp : public ParallelOp {
|
|||
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count);
|
||||
|
||||
// Caculate number of rows in each shard.
|
||||
// Calculate number of rows in each shard.
|
||||
// @return Status - the error code returned.
|
||||
Status CalculateNumRowsPerShard();
|
||||
|
||||
|
|
|
@ -320,7 +320,6 @@ Status ExecutionTree::PostAction() {
|
|||
// The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
|
||||
// This is because Python API binding to TensorOperation is still in progress.
|
||||
post_actions.push_back(std::make_unique<CacheErrorPass>());
|
||||
post_actions.push_back(std::make_unique<CacheTransformPass>());
|
||||
post_actions.push_back(std::make_unique<RepeatPass>());
|
||||
#endif
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
#include "minddata/dataset/include/samplers.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
|
@ -29,6 +30,9 @@ class DatasetCache {
|
|||
virtual Status ValidateParams() = 0;
|
||||
virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
||||
std::shared_ptr<SamplerObj> sampler) = 0;
|
||||
virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
};
|
||||
} // namespace mindspore::dataset
|
||||
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#include <memory>
|
||||
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -44,5 +46,28 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
||||
std::shared_ptr<SamplerObj> sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
|
||||
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
|
||||
.SetNumWorkers(num_workers)
|
||||
.SetClient(cache_client_)
|
||||
.SetSampler(sampler->SamplerBuild())
|
||||
.Build(&lookup_op));
|
||||
*ds = lookup_op;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheMergeOp> merge_op = nullptr;
|
||||
RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
|
||||
*ds = merge_op;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,6 +56,11 @@ class DatasetCacheImpl : public DatasetCache {
|
|||
|
||||
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
||||
|
||||
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
||||
std::shared_ptr<SamplerObj> sampler) override;
|
||||
|
||||
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
||||
|
||||
Status ValidateParams() override { return Status::OK(); }
|
||||
|
||||
~DatasetCacheImpl() = default;
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#include <memory>
|
||||
|
||||
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -46,5 +48,29 @@ Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
||||
std::shared_ptr<SamplerObj> sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
|
||||
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
|
||||
.SetNumWorkers(num_workers)
|
||||
.SetClient(cache_client_)
|
||||
.SetSampler(sampler->SamplerBuild())
|
||||
.Build(&lookup_op));
|
||||
*ds = lookup_op;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PreBuiltDatasetCache::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheMergeOp> merge_op = nullptr;
|
||||
RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
|
||||
*ds = merge_op;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,6 +40,11 @@ class PreBuiltDatasetCache : public DatasetCache {
|
|||
|
||||
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *const ds) override;
|
||||
|
||||
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
||||
std::shared_ptr<SamplerObj> sampler) override;
|
||||
|
||||
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
||||
|
||||
Status ValidateParams() override { return Status::OK(); }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
|
|
@ -8,6 +8,9 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
|
|||
bucket_batch_by_length_node.cc
|
||||
build_sentence_piece_vocab_node.cc
|
||||
build_vocab_node.cc
|
||||
cache_lookup_node.cc
|
||||
cache_merge_node.cc
|
||||
cache_node.cc
|
||||
concat_node.cc
|
||||
epoch_ctrl_node.cc
|
||||
filter_node.cc
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CacheLookupNode::CacheLookupNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)), sampler_(sampler), lookup_op_(nullptr), lookup_node_copy_(nullptr) {
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
void CacheLookupNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
std::shared_ptr<DatasetNode> CacheLookupNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<CacheLookupNode>(nullptr, sampler, cache_);
|
||||
lookup_node_copy_ = node;
|
||||
return node;
|
||||
}
|
||||
|
||||
Status CacheLookupNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("CacheNode", sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
|
||||
"Internal error. Attempt to create a cache lookup node without cache client.");
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
|
||||
node_ops->push_back(lookup_op_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
|
||||
// CacheLookupNode should already been copied, so we just return it here
|
||||
return std::static_pointer_cast<SamplerObj>(lookup_node_copy_);
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerRT> CacheLookupNode::SamplerBuild() {
|
||||
// Runtime cache lookup op should already been built, so we just return it here
|
||||
auto lookup_op = std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_);
|
||||
return std::shared_ptr<SamplerRT>(lookup_op);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CacheLookupNode : public DatasetNode, public SamplerObj {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CacheLookupNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheLookupNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCacheLookupNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to convert a SamplerObj class into a runtime sampler object
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
std::shared_ptr<SamplerRT> SamplerBuild() override;
|
||||
|
||||
/// \brief a base class override function to copy a SamplerObj class
|
||||
/// \return Shared pointers to the newly copied SamplerObj
|
||||
std::shared_ptr<SamplerObj> SamplerCopy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
/// \return Status Status::OK() if build successfully
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::shared_ptr<DatasetOp> lookup_op_;
|
||||
std::shared_ptr<CacheLookupNode> lookup_node_copy_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CacheMergeNode::CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)) {
|
||||
nary_op_ = true;
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
void CacheMergeNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
std::shared_ptr<DatasetNode> CacheMergeNode::Copy() {
|
||||
auto node = std::make_shared<CacheMergeNode>(nullptr, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
Status CacheMergeNode::ValidateParams() { return Status::OK(); }
|
||||
|
||||
Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
|
||||
"Internal error. Attempt to create a cache merge node without cache client.");
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
std::shared_ptr<DatasetOp> merge_op = nullptr;
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
|
||||
node_ops->push_back(merge_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CacheMergeNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheMergeNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCacheMergeNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
/// \return Status Status::OK() if build successfully
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CacheNode::CacheNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)), sampler_(sampler) {
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
void CacheNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
std::shared_ptr<DatasetNode> CacheNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<CacheNode>(nullptr, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
Status CacheNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("CacheNode", sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
|
||||
"Internal error. Attempt to create a cache node without cache client.");
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
|
||||
cache_op->SetSampler(sampler_->SamplerBuild());
|
||||
node_ops->push_back(cache_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CacheNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CacheNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCacheNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
/// \return Status Status::OK() if build successfully
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_
|
|
@ -204,15 +204,6 @@ std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int
|
|||
return SequentialSampler(0, num_samples);
|
||||
}
|
||||
|
||||
Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||
if (cache_ != nullptr) {
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
std::shared_ptr<DatasetOp> cache_op;
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
|
||||
node_ops->push_back(cache_op);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Constructor to initialize the cache
|
||||
DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; }
|
||||
|
||||
|
|
|
@ -53,6 +53,9 @@ constexpr char kBatchNode[] = "Batch";
|
|||
constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength";
|
||||
constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab";
|
||||
constexpr char kBuildVocabNode[] = "BuildVocab";
|
||||
constexpr char kCacheLookupNode[] = "CacheLookup";
|
||||
constexpr char kCacheMergeNode[] = "CacheMerge";
|
||||
constexpr char kCacheNode[] = "Cache";
|
||||
constexpr char kConcatNode[] = "Concat";
|
||||
constexpr char kEpochCtrlNode[] = "EpochCtrl";
|
||||
constexpr char kFilterNode[] = "Filter";
|
||||
|
@ -248,6 +251,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
/// \brief Getter of the number of workers
|
||||
int32_t num_workers() { return num_workers_; }
|
||||
|
||||
/// \brief Getter of dataset cache
|
||||
std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; }
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
/// \return Shared pointer to the original object
|
||||
|
@ -299,7 +305,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
// Used only in the constructor of the class and its derived classes.
|
||||
void AddChild(std::shared_ptr<DatasetNode> child);
|
||||
std::string PrintColumns(const std::vector<std::string> &columns) const;
|
||||
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
|
||||
void PrintNode(std::ostream &out, int *level) const;
|
||||
enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 };
|
||||
enum DataSource mappable_;
|
||||
|
@ -360,6 +365,20 @@ class NonMappableSourceNode : public DatasetNode {
|
|||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
/// \brief By default non-mappable dataset does not support sampling. However, if a cache operator
|
||||
/// is injected at some other place higher in the tree, that cache can inherit this sampler
|
||||
/// from the leaf, providing sampling support from the caching layer.
|
||||
/// This function sets up the sampler for a leaf node that does not use sampling.
|
||||
/// \param[in] sampler The sampler to setup
|
||||
/// \return Status of the function
|
||||
virtual Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) = 0;
|
||||
|
||||
/// \brief If a cache has been added into the ascendant tree over this non-mappable source node, then the cache will
|
||||
/// be executing a sampler for fetching the data. As such, any options in the source node need to be reset to its
|
||||
/// defaults so that this source node will produce the full set of data into the cache.
|
||||
/// \return Status of the function
|
||||
virtual Status MakeSimpleProducer() = 0;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,7 +76,6 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
auto project_op = std::make_shared<ProjectOp>(project_columns_);
|
||||
node_ops->push_back(project_op);
|
||||
}
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(map_op);
|
||||
return Status::OK();
|
||||
|
|
|
@ -72,8 +72,6 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
// Argument that is not exposed to user in the API.
|
||||
std::set<std::string> extensions = {};
|
||||
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, extensions, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
|
|
@ -67,8 +67,6 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
// label is like this:0 1 0 0 1......
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, usage_, extensions_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
|
|
@ -64,8 +64,6 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
|
|
@ -62,8 +62,6 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
|
|||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
|
|
@ -83,84 +83,66 @@ std::vector<std::string> CLUENode::split(const std::string &s, char delim) {
|
|||
return res;
|
||||
}
|
||||
|
||||
// Function to build CLUENode
|
||||
Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
std::map<std::string, std::string> CLUENode::CreateKeyMapForBuild() {
|
||||
std::map<std::string, std::string> key_map;
|
||||
if (task_ == "AFQMC") {
|
||||
if (usage_ == "train") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
} else if (usage_ == "test") {
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
}
|
||||
} else if (task_ == "CMNLI") {
|
||||
if (usage_ == "train") {
|
||||
}
|
||||
if (task_ == "CMNLI") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
} else if (usage_ == "test") {
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
}
|
||||
} else if (task_ == "CSL") {
|
||||
if (usage_ == "train") {
|
||||
}
|
||||
if (task_ == "CSL") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
key_map["label"] = "label";
|
||||
} else if (usage_ == "test") {
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
key_map["label"] = "label";
|
||||
}
|
||||
} else if (task_ == "IFLYTEK") {
|
||||
if (usage_ == "train") {
|
||||
}
|
||||
if (task_ == "IFLYTEK") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_des"] = "label_des";
|
||||
key_map["sentence"] = "sentence";
|
||||
} else if (usage_ == "test") {
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence"] = "sentence";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_des"] = "label_des";
|
||||
key_map["sentence"] = "sentence";
|
||||
}
|
||||
} else if (task_ == "TNEWS") {
|
||||
if (usage_ == "train") {
|
||||
}
|
||||
if (task_ == "TNEWS") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_desc"] = "label_desc";
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
} else if (usage_ == "test") {
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_desc"] = "label_desc";
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
}
|
||||
} else if (task_ == "WSC") {
|
||||
if (usage_ == "train") {
|
||||
}
|
||||
if (task_ == "WSC") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
|
@ -168,24 +150,21 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
key_map["idx"] = "idx";
|
||||
key_map["label"] = "label";
|
||||
key_map["text"] = "text";
|
||||
} else if (usage_ == "test") {
|
||||
} else { // usage_ == "test"
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
key_map["span2_text"] = "target/span2_text";
|
||||
key_map["idx"] = "idx";
|
||||
key_map["text"] = "text";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
key_map["span2_text"] = "target/span2_text";
|
||||
key_map["idx"] = "idx";
|
||||
key_map["label"] = "label";
|
||||
key_map["text"] = "text";
|
||||
}
|
||||
}
|
||||
return key_map;
|
||||
}
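For reference, the slash-separated values returned by CreateKeyMapForBuild() are what Build() later splits into nested column key paths. A minimal sketch of that hand-off, assuming ColKeyMap maps a column name to its list of nested JSON keys (illustration only, not part of this change):

std::map<std::string, std::string> key_map = CreateKeyMapForBuild();  // e.g. {"span1_index", "target/span1_index"} for WSC
ColKeyMap ck_map;
for (auto &p : key_map) {
  ck_map.insert({p.first, split(p.second, '/')});  // "target/span1_index" -> {"target", "span1_index"}
}
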
// Function to build CLUENode
|
||||
Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto key_map = CreateKeyMapForBuild();
|
||||
ColKeyMap ck_map;
|
||||
for (auto &p : key_map) {
|
||||
ck_map.insert({p.first, split(p.second, '/')});
|
||||
|
@ -193,19 +172,13 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// ClueOp by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
||||
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
|
||||
std::shared_ptr<ClueOp> clue_op =
|
||||
std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
|
||||
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
|
||||
|
||||
RETURN_IF_NOT_OK(clue_op->Init());
|
||||
|
||||
|
@ -222,7 +195,6 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(clue_op);
|
||||
|
||||
|
@ -270,5 +242,27 @@ Status CLUENode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
// CLUE by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
Status CLUENode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If a cache has been added into the ascendant tree over this clue node, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so
|
||||
// that this clue node will produce the full set of data into the cache.
|
||||
Status CLUENode::MakeSimpleProducer() {
|
||||
shard_id_ = 0;
|
||||
num_shards_ = 1;
|
||||
shuffle_ = ShuffleMode::kFalse;
|
||||
num_samples_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
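The two hooks added above are consumed by the migrated cache transform pass that appears later in this diff. Roughly, the calling side looks like the following sketch (simplified from that pass; leaf stands for a std::shared_ptr<NonMappableSourceNode> that sits under a cache):

std::shared_ptr<SamplerObj> sampler;
RETURN_IF_NOT_OK(leaf->SetupSamplerForCache(&sampler));  // the cache layer will drive sampling with this sampler
RETURN_IF_NOT_OK(leaf->MakeSimpleProducer());            // the leaf reverts to an unsharded, unshuffled full producer
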
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -49,6 +50,10 @@ class CLUENode : public NonMappableSourceNode {
|
|||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief Generate a key map to be used in Build() according to usage and task
|
||||
/// \return The generated key map
|
||||
std::map<std::string, std::string> CreateKeyMapForBuild();
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
/// \return Status Status::OK() if build successfully
|
||||
|
@ -85,6 +90,22 @@ class CLUENode : public NonMappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief CLUE by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
/// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
/// \param[in] sampler The sampler to setup
|
||||
/// \return Status of the function
|
||||
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
|
||||
|
||||
/// \brief If a cache has been added into the ascendant tree over this clue node, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so
|
||||
/// that this clue node will produce the full set of data into the cache.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
/// \return Status of the function
|
||||
Status MakeSimpleProducer() override;
|
||||
|
||||
private:
|
||||
/// \brief Split string based on a character delimiter
|
||||
/// \return A string vector
|
||||
|
|
|
@ -122,7 +122,6 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
std::shared_ptr<CocoOp> op =
|
||||
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(op);
|
||||
|
||||
|
|
|
@ -95,12 +95,6 @@ Status CSVNode::ValidateParams() {
|
|||
Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// CSVOp by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
@ -119,10 +113,9 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<CsvOp> csv_op =
|
||||
std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
|
||||
rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
|
||||
num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
|
||||
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
|
||||
sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
|
||||
num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
|
||||
|
||||
RETURN_IF_NOT_OK(csv_op->Init());
|
||||
|
||||
|
@ -140,7 +133,6 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(csv_op);
|
||||
|
||||
|
@ -188,5 +180,27 @@ Status CSVNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
// CSV by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
Status CSVNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If a cache has been added into the ascendant tree over this CSV node, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the CSV node need to be reset to its defaults so
|
||||
// that this CSV node will produce the full set of data into the cache.
|
||||
Status CSVNode::MakeSimpleProducer() {
|
||||
shard_id_ = 0;
|
||||
num_shards_ = 1;
|
||||
shuffle_ = ShuffleMode::kFalse;
|
||||
num_samples_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -107,6 +107,22 @@ class CSVNode : public NonMappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief CSV by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
/// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
/// \param[in] sampler The sampler to setup
|
||||
/// \return Status of the function
|
||||
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
|
||||
|
||||
/// \brief If a cache has been added into the ascendant tree over this CSV node, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the CSV node need to be reset to its defaults so
|
||||
/// that this CSV node will produce the full set of data into the cache.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
/// \return Status of the function
|
||||
Status MakeSimpleProducer() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
char field_delim_;
|
||||
|
|
|
@ -95,10 +95,10 @@ class GeneratorNode : public MappableSourceNode {
|
|||
|
||||
/// \brief Sampler getter
|
||||
/// \return SamplerObj of the current node
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return nullptr; }
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override {}
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
py::function generator_function_;
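GeneratorNode previously stubbed out its sampler accessors; with this change a generic IR pass can read and replace the sampler through the common MappableSourceNode interface. A hypothetical snippet (not part of this change; new_sampler is an illustrative variable):

std::shared_ptr<SamplerObj> current = node->Sampler();  // now returns the stored sampler_ instead of nullptr
node->SetSampler(new_sampler);                          // and the setter actually updates the node
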
@ -70,8 +70,6 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
|
|||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
recursive_, decode_, exts_, class_indexing_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
|
|
|
@ -94,7 +94,6 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
manifest_op =
|
||||
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
|
||||
class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_);
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(manifest_op);
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
|
@ -203,5 +204,16 @@ Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status MindDataNode::Accept(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<MindDataNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status MindDataNode::AcceptAfter(IRNodePass *p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<MindDataNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
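With Accept()/AcceptAfter() overridden, an IRNodePass can now dispatch on MindDataNode specifically. A hypothetical pass that mirrors the cache pass behaviour for MindRecord leaves could look like this (sketch only; the class name is illustrative):

class RejectMindDataUnderCachePass : public IRNodePass {
 public:
  Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified) override {
    *modified = false;
    if (node->IsCached()) {
      return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "MindRecord is not supported under cache.");
    }
    return Status::OK();
  }
};
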
@ -92,6 +92,18 @@ class MindDataNode : public MappableSourceNode {
|
|||
/// \brief Sampler setter
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(IRNodePass *p, bool *const modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *p, bool *const modified) override;
|
||||
|
||||
private:
|
||||
std::string dataset_file_; // search_for_pattern_ will be true in this mode
|
||||
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode
|
||||
|
|
|
@ -57,7 +57,6 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema),
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
|
@ -105,17 +106,9 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
}
|
||||
}
|
||||
|
||||
// RandomOp by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
// RandomOp doesn't support sampler, should not support sharding, select sampler should just be sequential.
|
||||
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(total_rows_, false, 1, 0);
|
||||
|
||||
std::shared_ptr<RandomDataOp> op;
|
||||
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
|
||||
std::move(data_schema_), std::move(sampler_->SamplerBuild()));
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
std::move(data_schema_));
|
||||
|
||||
node_ops->push_back(op);
|
||||
|
||||
|
@ -142,5 +135,27 @@ Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
|
|||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// RandomDataset by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
Status RandomNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
|
||||
// RandomOp doesn't support sampler, should not support sharding, select sampler should just be sequential.
|
||||
*sampler = SelectSampler(total_rows_, false, 1, 0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status RandomNode::Accept(IRNodePass *p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<RandomNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status RandomNode::AcceptAfter(IRNodePass *p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<RandomNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -99,6 +99,30 @@ class RandomNode : public NonMappableSourceNode {
|
|||
const std::mt19937 &RandGen() const { return rand_gen_; }
|
||||
const std::unique_ptr<DataSchema> &GetDataSchema() const { return data_schema_; }
|
||||
|
||||
/// \brief RandomDataset by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
/// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
/// \param[in] sampler The sampler to setup
|
||||
/// \return Status of the function
|
||||
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
|
||||
|
||||
/// \brief Random node will always produce the full set of data into the cache
|
||||
/// \return Status of the function
|
||||
Status MakeSimpleProducer() override { return Status::OK(); }
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(IRNodePass *p, bool *const modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *p, bool *const modified) override;
|
||||
|
||||
private:
|
||||
/// \brief A quick inline for producing a random number between (and including) min/max
|
||||
/// \param[in] min minimum number that can be generated.
|
||||
|
|
|
@ -73,12 +73,6 @@ Status TextFileNode::ValidateParams() {
|
|||
Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// TextFileOp by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
@ -87,10 +81,10 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
|
||||
// Create and initalize TextFileOp
|
||||
// Create and initialize TextFileOp
|
||||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_);
|
||||
RETURN_IF_NOT_OK(text_file_op->Init());
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
|
||||
|
@ -106,7 +100,6 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
// Add TextFileOp
|
||||
node_ops->push_back(text_file_op);
|
||||
|
@ -152,5 +145,27 @@ Status TextFileNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
// TextFile by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
Status TextFileNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If a cache has been added into the ascendant tree over this TextFile node, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the TextFile node need to be reset to its defaults so
|
||||
// that this TextFile node will produce the full set of data into the cache.
|
||||
Status TextFileNode::MakeSimpleProducer() {
|
||||
shard_id_ = 0;
|
||||
num_shards_ = 1;
|
||||
shuffle_ = ShuffleMode::kFalse;
|
||||
num_samples_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -83,6 +83,22 @@ class TextFileNode : public NonMappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief TextFile by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
/// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
/// \param[in] sampler The sampler to setup
|
||||
/// \return Status of the function
|
||||
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
|
||||
|
||||
/// \brief If a cache has been added into the ascendant tree over this TextFile node, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the TextFile node need to be reset to its defaults
|
||||
/// so that this TextFile node will produce the full set of data into the cache.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
/// \return Status of the function
|
||||
Status MakeSimpleProducer() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
int32_t num_samples_;
|
||||
|
|
|
@ -121,17 +121,10 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// TFReaderOp by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
|
||||
// Create and initialize TFReaderOp
|
||||
std::shared_ptr<TFReaderOp> tf_reader_op =
|
||||
std::make_shared<TFReaderOp>(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files,
|
||||
std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
|
||||
shard_id_, shard_equal_rows_, std::move(sampler_->SamplerBuild()));
|
||||
std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>(
|
||||
num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema),
|
||||
connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_);
|
||||
|
||||
RETURN_IF_NOT_OK(tf_reader_op->Init());
|
||||
|
||||
|
@ -149,7 +142,6 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
// Add TFReaderOp
|
||||
node_ops->push_back(tf_reader_op);
|
||||
|
@ -227,5 +219,29 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
|
|||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
// TFRecord by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
Status TFRecordNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If a cache has been added into the ascendant tree over this TFRecord node, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the TFRecord node need to be reset to its defaults so
|
||||
// that this TFRecord node will produce the full set of data into the cache.
|
||||
Status TFRecordNode::MakeSimpleProducer() {
|
||||
shard_id_ = 0;
|
||||
num_shards_ = 1;
|
||||
shuffle_ = ShuffleMode::kFalse;
|
||||
num_samples_ = 0;
|
||||
shard_equal_rows_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
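As a concrete illustration of the reset above (sketch only; tfrecord_node is a hypothetical variable, the resulting member values are taken directly from the function body):

RETURN_IF_NOT_OK(tfrecord_node->MakeSimpleProducer());
// afterwards: shard_id_ == 0, num_shards_ == 1, shuffle_ == ShuffleMode::kFalse,
//             num_samples_ == 0, shard_equal_rows_ == false
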
@ -124,6 +124,22 @@ class TFRecordNode : public NonMappableSourceNode {
|
|||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief TFRecord by itself is a non-mappable dataset that does not support sampling.
|
||||
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
/// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
/// \param[in] sampler The sampler to setup
|
||||
/// \return Status of the function
|
||||
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
|
||||
|
||||
/// \brief If a cache has been added into the ascendant tree over this TFRecord node, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the TFRecord node need to be reset to its defaults
|
||||
/// so that this TFRecord node will produce the full set of data into the cache.
|
||||
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
/// \return Status of the function
|
||||
Status MakeSimpleProducer() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string
|
||||
|
|
|
@ -113,7 +113,6 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
voc_op =
|
||||
std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
RETURN_IF_NOT_OK(AddCacheOp(node_ops));
|
||||
|
||||
node_ops->push_back(voc_op);
|
||||
return Status::OK();
|
||||
|
|
|
@ -31,8 +31,14 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#endif
|
||||
#ifdef ENABLE_PYTHON
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#ifdef ENABLE_PYTHON
|
||||
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
|
@ -195,10 +201,10 @@ Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modi
|
|||
}
|
||||
#ifdef ENABLE_PYTHON
|
||||
Status IRNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
return Visit(std::static_pointer_cast<MappableSourceNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
return VisitAfter(std::static_pointer_cast<MappableSourceNode>(node), modified);
|
||||
}
|
||||
#endif
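Because GeneratorNode now falls through to the MappableSourceNode visitor (and MindDataNode does the same later in this file), a pass only needs to override that one overload to see every mappable leaf. A hypothetical example, assuming IRNodePass exposes a MappableSourceNode overload as the fallthrough above suggests (sketch only):

class MappableLeafCounter : public IRNodePass {
 public:
  Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override {
    *modified = false;
    ++num_mappable_leaves_;  // counts GeneratorNode, MindDataNode and the other mappable sources
    return Status::OK();
  }
  int32_t num_mappable_leaves_ = 0;
};
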
Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
|
||||
|
@ -207,12 +213,26 @@ Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
|
|||
Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status IRNodePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<MappableSourceNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<MappableSourceNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
|
|
@ -44,7 +44,6 @@ class TakeNode;
|
|||
class TransferNode;
|
||||
class ZipNode;
|
||||
#ifdef ENABLE_PYTHON
|
||||
class GeneratorNode;
|
||||
class SyncWaitNode;
|
||||
#endif
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -129,14 +128,14 @@ class IRPass : public std::enable_shared_from_this<IRPass> {
|
|||
class IRTreePass : public IRPass {
|
||||
public:
|
||||
/// \brief Run the transformation pass against the IR tree.
|
||||
/// \param[inout] root_ir Pointer to the IR tree to be transformed.
|
||||
/// \param[inout] modified Indicate if the tree was modified
|
||||
/// \param[in,out] root_ir Pointer to the IR tree to be transformed.
|
||||
/// \param[in,out] modified Indicate if the tree was modified
|
||||
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final;
|
||||
|
||||
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
|
||||
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate if the tree was modified.
|
||||
/// \param[in,out] tree The tree to operate on.
|
||||
/// \param[in,out] Indicate if the tree was modified.
|
||||
/// \return Status The status code returned
|
||||
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) { return Status::OK(); }
|
||||
};
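For orientation, a derived IR tree pass only needs to override RunOnTree(); a do-nothing skeleton looks like this (hypothetical sketch, the class name is illustrative):

class NoopIRTreePass : public IRTreePass {
 public:
  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override {
    *modified = false;  // set to true only if the tree is actually changed
    return Status::OK();
  }
};
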
@ -164,8 +163,8 @@ class IRNodePass : public IRPass {
|
|||
~IRNodePass() = default;
|
||||
|
||||
/// \brief Run the transformation pass against the IR tree
|
||||
/// \param[inout] root_ir Pointer to the IR tree to be transformed
|
||||
/// \param[inout] modified Indicator if the tree was changed
|
||||
/// \param[in,out] root_ir Pointer to the IR tree to be transformed
|
||||
/// \param[in,out] modified Indicator if the tree was changed
|
||||
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final;
|
||||
|
||||
/// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down
|
||||
|
@ -210,8 +209,14 @@ class IRNodePass : public IRPass {
|
|||
#endif
|
||||
virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified);
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified);
|
||||
#endif
|
||||
virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified);
|
||||
virtual Status Visit(std::shared_ptr<RandomNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified);
|
||||
virtual Status Visit(std::shared_ptr<RenameNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified);
|
||||
virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified);
|
||||
|
@ -270,14 +275,14 @@ class Pass : public std::enable_shared_from_this<Pass> {
|
|||
class TreePass : public Pass {
|
||||
public:
|
||||
/// \brief Run the transformation pass against the execution tree.
|
||||
/// \param[inout] tree Pointer to the execution tree to be transformed.
|
||||
/// \param[inout] modified Indicate if the tree was modified
|
||||
/// \param[in,out] tree Pointer to the execution tree to be transformed.
|
||||
/// \param[in,out] modified Indicate if the tree was modified
|
||||
Status Run(ExecutionTree *tree, bool *const modified) final;
|
||||
|
||||
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
|
||||
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \param[in,out] tree The tree to operate on.
|
||||
/// \param[in,out] Indicate if the tree was modified.
|
||||
/// \return Status The status code returned
|
||||
virtual Status RunOnTree(ExecutionTree *tree, bool *const modified) { return Status::OK(); }
|
||||
};
|
||||
|
@ -305,8 +310,8 @@ class NodePass : public Pass {
|
|||
~NodePass() = default;
|
||||
|
||||
/// \brief Run the transformation pass against the execution tree
|
||||
/// \param[inout] tree Pointer to the execution tree to be transformed
|
||||
/// \param[inout] modified Indicator if the tree was changed
|
||||
/// \param[in,out] tree Pointer to the execution tree to be transformed
|
||||
/// \param[in,out] modified Indicator if the tree was changed
|
||||
Status Run(ExecutionTree *tree, bool *const modified) final;
|
||||
|
||||
/// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down
|
||||
|
|
|
@ -16,207 +16,130 @@
|
|||
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/album_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
|
||||
#ifdef ENABLE_PYTHON
|
||||
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#endif
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor
CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {}
CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_node_(nullptr), sampler_(nullptr) {}

// Identifies the subtree below this node as a cached descendant tree.
Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
// Note that this function will only get called on non-leaf nodes.
// For leaf nodes, the other Visit with NonMappableSourceNode or MappableSourceNode argument will be called instead.
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
*modified = false;
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
if (is_caching_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Nested cache operations is not supported!");
if (node->IsCached()) {
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
}
is_caching_ = true;
return Status::OK();
}

// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// Resets the tracking of the cache within the tree and assigns the nodes that will be involved in a cache
// transformation
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
Status CacheTransformPass::CachePass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
*modified = false;
is_caching_ = false;  // We are no longer in a cache subtree. Clear the flag.
if (leaf_op_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
AddMappableCacheOperators(std::move(leaf_op_), node);
} else {
// If there was no leaf_op set, then this is a non-mappable scenario.

if (sampler_) {
// Grab the sampler that was saved from the leaf and plug it into the cache op
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
if (node->IsCached()) {
is_caching_ = false;  // We are no longer in a cache subtree. Clear the flag.
if (leaf_node_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf node into the transform pass, using move to null our copy of it,
// and also assign the cached node, using base class pointers.
// In the cases where cache is directly injected after the leaf node, these two nodes might be the same.
cache_pairs_.push_back(std::make_pair(std::move(leaf_node_), node));
} else {
// We're a cache op but no sampler was saved from leaf, so create a default sampler
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
// If there was no leaf_node_ set, then this is a non-mappable scenario.
// We only assign the cached node in this case.
cached_nodes_.push_back(node);
}

// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
}

return Status::OK();
}
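The pre-order Visit() and post-order VisitAfter() above cooperate through the is_caching_ flag and the saved leaf. A self-contained illustration of that pattern, deliberately not using the MindSpore API (sketch only):

#include <memory>
#include <vector>
struct Node {
  bool cached = false;
  std::vector<std::shared_ptr<Node>> children;
};
// Mirrors CachePass: set the flag on the way down when a cached node is entered,
// clear it on the way back up at that same node.
void Walk(const std::shared_ptr<Node> &n, bool *in_cache_subtree) {
  bool entered = false;
  if (n->cached && !*in_cache_subtree) {
    *in_cache_subtree = true;  // Visit()
    entered = true;
  }
  for (auto &child : n->children) Walk(child, in_cache_subtree);
  if (entered) *in_cache_subtree = false;  // VisitAfter()
}
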
// Common code for mappable leaf setup.
|
||||
Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (is_caching_ && leaf_op_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
|
||||
// If we are a leaf in the caching path, then save this leaf.
|
||||
if (is_caching_) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
|
||||
leaf_op_ = std::move(leaf_op);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Common code for non mappable leaf setup.
|
||||
Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (is_caching_ && leaf_op_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
|
||||
// Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf
|
||||
// as save it for use by cache op in ascendant tree.
|
||||
if (is_caching_) {
|
||||
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
|
||||
} else {
|
||||
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
|
||||
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
|
||||
std::shared_ptr<SamplerRT> sampler_from_leaf;
|
||||
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) {
|
||||
if (is_caching_) {
|
||||
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
|
||||
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
|
||||
node->MakeSimpleProducer();
|
||||
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
|
||||
if (node->IsCached()) {
|
||||
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
|
||||
is_caching_ = true;
|
||||
}
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *const modified) {
|
||||
// Cache might also be injected into a non-leaf node higher up in the tree, so is_caching_ might also be set to true
|
||||
// by the other Visit() with DatasetNode argument
|
||||
if (is_caching_) {
|
||||
// If we are a ClueOp in a caching tree, then change our config so that it becomes a basic
|
||||
// ClueOp that parses all files. Selection of data will come from the sampler on the cache instead.
|
||||
node->MakeSimpleProducer();
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (leaf_node_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
// Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
|
||||
// Note that sampler for non mappable dataset only works if there is a downstream cache.
|
||||
RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
|
||||
// If we are a non-mappable source node in a caching tree, then change our config so that it becomes a basic
|
||||
// source node that parses all files. Selection of data will come from the sampler on the cache instead.
|
||||
RETURN_IF_NOT_OK(node->MakeSimpleProducer());
|
||||
}
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *const modified) {
|
||||
if (is_caching_) {
|
||||
// If we are a CsvOp in a caching tree, then change our config so that it becomes a basic
|
||||
// CsvOp that parses all files. Selection of data will come from the sampler on the cache instead.
|
||||
node->MakeSimpleProducer();
|
||||
}
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *const modified) {
|
||||
if (is_caching_) {
|
||||
// If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic
|
||||
// TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead.
|
||||
node->MakeSimpleProducer();
|
||||
}
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *const modified) {
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
|
||||
if (node->IsCached()) {
|
||||
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
|
||||
is_caching_ = true;
|
||||
}
|
||||
// Cache might also be injected into a non-leaf node higher up in the tree, so is_caching_ might also be set to true
|
||||
// by the other Visit() with DatasetNode argument
|
||||
if (is_caching_) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (leaf_node_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
// Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree.
|
||||
// Note that sampler for non mappable dataset only works if there is a downstream cache.
|
||||
RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
|
||||
if (node->IsCached()) {
|
||||
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
|
||||
is_caching_ = true;
|
||||
}
|
||||
// Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true
|
||||
// by the other Visit() with DatasetNode argument
|
||||
if (is_caching_) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (leaf_node_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
// If we are a leaf in the caching path, then save this leaf
|
||||
leaf_node_ = node;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) {
|
||||
if (is_caching_) {
|
||||
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
|
||||
if (node->IsCached() || is_caching_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"There is currently no support for MindRecordOp under cache.");
|
||||
}
|
||||
|
@ -226,102 +149,85 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> no
|
|||
|
||||
#ifdef ENABLE_PYTHON
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
|
||||
if (is_caching_) {
|
||||
Status CacheTransformPass::CachePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
|
||||
if (node->IsCached() || is_caching_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"There is currently no support for GeneratorOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Assigns the leaf and cache operators that are involved in a cache transformation
|
||||
void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
|
||||
std::shared_ptr<CacheOp> cache_op) {
|
||||
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
|
||||
}
|
||||
|
||||
// constructor
|
||||
CacheTransformPass::CacheTransformPass() {}
|
||||
|
||||
// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
|
||||
Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *const modified) {
|
||||
Status CacheTransformPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
|
||||
MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
|
||||
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
|
||||
// use to execute a transform.
|
||||
CachePass cache_pass = CachePass();
|
||||
RETURN_IF_NOT_OK(cache_pass.Run(tree, modified));
|
||||
RETURN_IF_NOT_OK(cache_pass.Run(root_ir, modified));
|
||||
|
||||
// Then, execute the transform for each pair
|
||||
// Execute the transform for non-mappable cache
|
||||
for (auto cached_node : cache_pass.cached_nodes()) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Injecting a non-mappable cache node.";
|
||||
RETURN_IF_NOT_OK(InjectNonMappableCacheNode(cached_node, cache_pass.sampler()));
|
||||
}
|
||||
|
||||
// Execute the transform for mappable cache
|
||||
for (auto cache_pair : cache_pass.cache_pairs()) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
|
||||
RETURN_IF_NOT_OK(
|
||||
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()));
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Injecting a mappable cache node.";
|
||||
RETURN_IF_NOT_OK(InjectMappableCacheNode(cache_pair.first, cache_pair.second));
|
||||
}
|
||||
MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to execute the cache transformation.
|
||||
Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
|
||||
std::shared_ptr<DatasetOp> cache_op,
|
||||
std::shared_ptr<CacheClient> cache_client) {
|
||||
// Get local pointers the child/parent of the cache op. It's possible that the parent is null if the cache was
|
||||
// the root node. It is also possible that cache_child == leaf_op
|
||||
std::shared_ptr<DatasetOp> cache_child = cache_op->child(0);
|
||||
DatasetOp *cache_parent = nullptr;
|
||||
cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent
|
||||
// Helper function to execute mappable cache transformation.
|
||||
// Input:
|
||||
// Sampler
|
||||
// |
|
||||
// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
|
||||
//
|
||||
// Transformed:
|
||||
// Sampler --> CacheLookupNode ------------------------->
|
||||
// | |
|
||||
// | CacheMergeNode
|
||||
// | |
|
||||
// LeafNode --> OtherNodes --> CachedNode
|
||||
Status CacheTransformPass::InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,
|
||||
std::shared_ptr<DatasetNode> cached_node) {
|
||||
// Create a cache merge node with defaults
|
||||
auto cache_merge_node = std::make_shared<CacheMergeNode>(nullptr, cached_node->GetDatasetCache());
|
||||
// Insert the cache merge node to become the cached_node's parent
|
||||
RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_merge_node));
|
||||
|
||||
// Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later.
|
||||
std::shared_ptr<SamplerRT> leaf_sampler = leaf_op->sampler();
|
||||
|
||||
// Construct the merge op with defaults
|
||||
std::shared_ptr<CacheMergeOp> merge_op;
|
||||
CacheMergeOp::Builder merge_builder;
|
||||
RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op));
|
||||
RETURN_IF_NOT_OK(tree->AssociateNode(merge_op));
|
||||
|
||||
// Construct the cache lookup op with defaults
|
||||
std::shared_ptr<CacheLookupOp> cache_lookup_op;
|
||||
CacheLookupOp::Builder lookup_builder;
|
||||
RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op));
|
||||
RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op));
|
||||
|
||||
// Overwrite the old sampler in this leaf op to become the lookup op
|
||||
leaf_op->SetSampler(cache_lookup_op);
|
||||
|
||||
// If the cache had a parent, then go into that parent to remove the cache from it's child list and then
|
||||
// replace it with the merge op.
|
||||
if (cache_parent != nullptr) {
|
||||
RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op));
|
||||
RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op));
|
||||
} else {
|
||||
// If we didn't have a parent, then the merge op is the root node
|
||||
RETURN_IF_NOT_OK(tree->AssignRoot(merge_op));
|
||||
}
|
||||
|
||||
// Set the cache op to no longer be a parent over it's child. This will fully disconnect the old cache op.
|
||||
// We maintain a local pointer to the old child though.
|
||||
RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child));
|
||||
|
||||
// Connect the merge op
|
||||
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op)));
|
||||
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child)));
|
||||
|
||||
// At this point, the cache op has already had it's children and parents taken away. Calling remove
|
||||
// on it at this point will not do any node hookups, and instead set internal fields to invalid.
|
||||
RETURN_IF_NOT_OK(cache_op->Remove());
|
||||
std::shared_ptr<SamplerObj> leaf_sampler = leaf_node->Sampler();
|
||||
// Create a cache lookup node with leaf_node's sampler
|
||||
auto cache_lookup_node = std::make_shared<CacheLookupNode>(nullptr, leaf_sampler, cached_node->GetDatasetCache());
|
||||
// Insert the cache lookup node as the first child of cache merge node
|
||||
RETURN_IF_NOT_OK(cache_merge_node->InsertChildAt(0, cache_lookup_node));
|
||||
// Overwrite the old sampler in this leaf node to become the cache lookup node
|
||||
leaf_node->SetSampler(std::static_pointer_cast<SamplerObj>(cache_lookup_node));
|
||||
return Status::OK();
|
||||
}
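For context, a minimal user-level sketch of the mappable case this helper serves (the directory and session id below are placeholders, and a running cache server is assumed):

import mindspore.dataset as ds

# Placeholder session id; in practice it comes from the cache_admin tool.
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

# Attaching a cache to a mappable source: during the pre-pass the cache is
# rewritten into a CacheLookupNode (installed as the leaf's sampler) plus a
# CacheMergeNode inserted above the cached node.
data = ds.ImageFolderDataset(dataset_dir="/path/to/images", cache=some_cache)
for _ in data.create_dict_iterator(num_epochs=1):
    pass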

// Helper function to execute non-mappable cache transformation.
// Input:
//   LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
//
// Transformed:
//                                  Sampler
//                                     |
//   LeafNode --> OtherNodes --> CachedNode --> CacheNode
Status CacheTransformPass::InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node,
                                                      std::shared_ptr<SamplerObj> sampler) {
  // Create a cache node using the sampler we saved from the leaf
  auto cache_node = std::make_shared<CacheNode>(nullptr, sampler, cached_node->GetDatasetCache());
  // Insert the cache node to become the cached_node's parent
  RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_node));
  return Status::OK();
}
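A matching sketch for the non-mappable case (file paths and the session id are placeholders): the sampler saved from the leaf ends up on the injected CacheNode rather than on the leaf itself.

import mindspore.dataset as ds

some_cache = ds.DatasetCache(session_id=1, size=0)  # placeholder session id

# Attaching a cache to a non-mappable source: the pre-pass inserts a CacheNode
# above the cached node and hands it the sampler pulled up from the leaf.
data = ds.TFRecordDataset(["/path/to/data.tfrecord"], "/path/to/schema.json", cache=some_cache)
for _ in data.create_dict_iterator(num_epochs=1):
    pass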
}  // namespace dataset

@@ -20,6 +20,8 @@
#include <memory>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pass.h"

namespace mindspore {

@@ -32,11 +34,11 @@ class CacheClient;
/// \class CacheTransformPass cache_transform_pass.h
/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
///     operations
class CacheTransformPass : public TreePass {
class CacheTransformPass : public IRTreePass {
  /// \class CachePass
  /// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache
  ///     transformation. It works in conjunction with the CacheTransformPass
  class CachePass : public NodePass {
  class CachePass : public IRNodePass {
   public:
    /// \brief Constructor
    /// \param[in] transform_pass Raw pointer back to controlling tree pass

@@ -47,138 +49,72 @@ class CacheTransformPass : public TreePass {

    /// \brief Identifies the subtree below this node as a cached descendant tree.
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \param[in,out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
    Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override;

    /// \brief Resets the tracking of the cache within the tree and assigns the operators that
    ///     will be involved in a cache transformation
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \param[in,out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
    Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override;

#ifndef ENABLE_ANDROID

    /// \brief Perform leaf node cache transform identifications
    /// \brief Perform non-mappable leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \param[in,out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<ClueOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<CsvOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *const modified) override;
    Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override;
#endif

    /// \brief Perform leaf node cache transform identifications
    /// \brief Perform non-mappable leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \param[in,out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *const modified) override;
    Status Visit(std::shared_ptr<RandomNode> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \brief Perform mappable leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \param[in,out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) override;
    Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override;

#ifdef ENABLE_PYTHON
    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \param[in,out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) override;
    Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override;
#endif

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) override;

    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) override;

#ifndef ENABLE_ANDROID
    /// \brief Perform leaf node cache transform identifications
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \param[in,out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override;
    Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified) override;
#endif

    /// \brief Getter
    std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs() { return cache_pairs_; }
    std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs() {
      return cache_pairs_;
    }

    /// \brief Getter
    std::vector<std::shared_ptr<DatasetNode>> cached_nodes() { return cached_nodes_; }

    /// \brief Getter
    std::shared_ptr<SamplerObj> sampler() { return sampler_; }

   private:
    /// \brief Common code for mappable leaf setup.
    /// \param[in] node The leaf node performing setup work.
    /// \return Status The status code returned
    Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);

    /// \brief Common code for non-mappable leaf setup.
    /// \param[in] node The leaf node performing setup work.
    /// \return Status The status code returned
    Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);

    /// \brief Assigns the leaf and cache operators that are involved in a cache transformation
    /// \param[in] leaf_op The leaf operator involved in the cache transform
    /// \param[in] cache_op The cache operator involved in the cache transform
    void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);

    bool is_caching_;
    std::shared_ptr<DatasetOp> leaf_op_;
    std::shared_ptr<SamplerRT> sampler_;
    // The two operators that work together to establish the cache transform
    std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
    std::shared_ptr<MappableSourceNode> leaf_node_;
    std::shared_ptr<SamplerObj> sampler_;
    // The two nodes that work together to establish the cache transform
    std::vector<std::shared_ptr<DatasetNode>> cached_nodes_;
    std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs_;
  };

 public:

@@ -189,32 +125,46 @@ class CacheTransformPass : public TreePass {
  ~CacheTransformPass() = default;

  /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
  /// \param[inout] tree The tree to operate on.
  /// \param[inout] Indicator if the tree was modified.
  /// \param[in,out] tree The tree to operate on.
  /// \param[in,out] Indicator if the tree was modified.
  /// \return Status The status code returned
  Status RunOnTree(ExecutionTree *tree, bool *const modified) override;
  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;

 private:
  /// \brief Helper function to execute the cache transformation.
  /// \brief Helper function to execute mappable cache transformation.
  ///
  /// Input:
  ///   Sampler
  ///     |
  ///   LeafOp --> OtherOps --> CacheOp
  ///   LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
  ///
  /// Transformed:
  ///   Sampler --> CacheLookupOp ---------------->
  ///                     |                       |
  ///                     |                    MergeOp
  ///                     |                       |
  ///   LeafOp --> OtherOps -->
  ///   Sampler --> CacheLookupNode ------------------------->
  ///                     |                                  |
  ///                     |                          CacheMergeNode
  ///                     |                                  |
  ///   LeafNode --> OtherNodes --> CachedNode
  ///
  /// \param[in] leaf_op The leaf node in the transform
  /// \param[in] cache_op The cache op in the transform (will get removed)
  /// \param[in] cache_client The cache client
  /// \param[in] leaf_node The leaf node in the transform
  /// \param[in] cached_node The node with cache attribute which is involved in the cache transform
  /// \return Status The status code returned
  Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
                               std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
  Status InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node,
                                 std::shared_ptr<DatasetNode> cached_node);

  /// \brief Helper function to execute non-mappable cache transformation.
  ///
  /// Input:
  ///   LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
  ///
  /// Transformed:
  ///                                  Sampler
  ///                                     |
  ///   LeafNode --> OtherNodes --> CachedNode --> CacheNode
  ///
  /// \param[in] cached_node The node with cache attribute which is involved in the cache transform
  /// \param[in] sampler The sampler saved for non-mappable leaf nodes during the CachePass
  /// \return Status The status code returned
  Status InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node, std::shared_ptr<SamplerObj> sampler);
};
}  // namespace dataset
}  // namespace mindspore

@@ -24,6 +24,7 @@
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"

@@ -53,6 +54,7 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
  actions.emplace_back(std::make_unique<NodeRemovalPass>());
  actions.emplace_back(std::make_unique<EpochCtrlPass>());
  if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
  actions.emplace_back(std::make_unique<CacheTransformPass>());
  // Vector of flags for each action
  std::vector<bool> modified(actions.size(), false);
  // Apply pre-pass actions

@@ -35,7 +35,7 @@ namespace dataset {
// Internal Sampler class forward declaration
class SamplerRT;

class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
class SamplerObj {
 public:
  /// \brief Constructor
  SamplerObj();

@@ -122,7 +122,7 @@ std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement = false, int64_

/// Function to create a Sequential Sampler.
/// \notes Samples the dataset elements sequentially, same as not having a sampler.
/// \param[in] start_index - Index to start sampling at (dafault to start at first id).
/// \param[in] start_index - Index to start sampling at (default to start at first id).
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \return Shared pointer to the current Sampler.
std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);
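As a quick, hedged illustration of these defaults (the dataset path below is a placeholder), the Python sampler API exposes roughly the same two knobs:

import mindspore.dataset as ds

# start_index=0 starts at the first id; a small num_samples limits the draw,
# while omitting it samples all elements, mirroring the C++ defaults above.
sampler = ds.SequentialSampler(start_index=0, num_samples=4)
data = ds.ImageFolderDataset(dataset_dir="/path/to/images", sampler=sampler)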

@@ -465,24 +465,21 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
  rc = ccbuilder.Build(&myClient);
  ASSERT_TRUE(rc.IsOk());

  // In a mappable dataset, it uses a complex interaction of cache lookup op and cache merge op.
  // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
  // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and
  // replace it with the required tree structures for cache lookup op and cache merge op.

  std::shared_ptr<CacheOp> myCacheOp;
  rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  std::shared_ptr<CacheLookupOp> myLookupOp;
  rc = CacheLookupOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(seq_sampler).Build(&myLookupOp);
  std::shared_ptr<CacheMergeOp> myMergeOp;
  rc = CacheMergeOp::Builder().SetNumWorkers(4).SetClient(myClient).Build(&myMergeOp);

  std::shared_ptr<ImageFolderOp> so;
  ImageFolderOp::Builder builder;
  builder.SetSampler(std::move(seq_sampler))
    .SetOpConnectorSize(3)
  builder.SetOpConnectorSize(3)
    .SetNumWorkers(3)
    .SetRowsPerBuffer(2)
    .SetExtensions({".jpg", ".JPEG"})
    .SetRecursive(true)
    .SetImageFolderDir(datasets_root_path_ + "/testPK/data");
  rc = builder.Build(&so);
  so->SetSampler(myLookupOp);
  ASSERT_TRUE(rc.IsOk());

  // RepeatOp

@@ -495,7 +492,9 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
  rc = myTree->AssociateNode(so);
  ASSERT_TRUE(rc.IsOk());

  rc = myTree->AssociateNode(myCacheOp);
  rc = myTree->AssociateNode(myLookupOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myMergeOp);
  ASSERT_TRUE(rc.IsOk());

  rc = myTree->AssociateNode(myRepeatOp);

@@ -503,9 +502,11 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
  rc = myTree->AssignRoot(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());

  rc = myRepeatOp->AddChild(myCacheOp);
  rc = myRepeatOp->AddChild(myMergeOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(so);
  rc = myMergeOp->AddChild(myLookupOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myMergeOp->AddChild(so);
  ASSERT_TRUE(rc.IsOk());

  rc = myTree->Prepare(1);

@@ -532,119 +533,3 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
  rc = myClient->DestroyCache();
  ASSERT_TRUE(rc.IsOk());
}

//// Simple test with a repeated cache op over random data producer.
//// The difference in this one is that you do not add the sampler to the cache op directly.
//// Instead, the sampler is added as part of the leaf op construction. Then, the prepare
//// phase will pull this up from the leaf and into the cache.
//// It removes the sampler from the leaf op, which doesn't make sense there anyway for
//// the RandomDataOp which doesn't support sampling without a cache.
////
////    RepeatOp
////        |
////     CacheOp
////        |
////  RandomDataOp
////
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
  // Clear the rc of the master thread if any
  (void)TaskManager::GetMasterThreadRc();
  Status rc;
  int32_t rank = 0;  // not used
  MS_LOG(INFO) << "UT test TestCacheInheritSampler";

  session_id_type env_session;
  rc = GetSessionFromEnv(&env_session);
  ASSERT_TRUE(rc.IsOk());

  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);

  // Start with an empty execution tree
  auto myTree = std::make_shared<ExecutionTree>();

  // Create a schema using the C api's
  std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();

  // 2 columns. First column is an "image" 640,480,3
  TensorShape c1Shape({640, 480, 3});
  ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
                   rank,  // not used
                   &c1Shape);

  // Column 2 will just be a scalar label number
  TensorShape c2Shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);

  testSchema->AddColumn(c1);
  testSchema->AddColumn(c2);

  // RandomDataOp
  std::shared_ptr<RandomDataOp> myRandomDataOp;
  rc = RandomDataOp::Builder()
         .SetRowsPerBuffer(2)
         .SetNumWorkers(4)
         .SetDataSchema(std::move(testSchema))
         .SetTotalRows(10)
         .SetSampler(std::move(seq_sampler))
         .Build(&myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());

  // CacheOp
  CacheClient::Builder ccbuilder;
  // use arbitrary session of 1, size of 0, spilling is true
  ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
  std::shared_ptr<CacheClient> myClient;
  rc = ccbuilder.Build(&myClient);
  ASSERT_TRUE(rc.IsOk());
  std::shared_ptr<CacheOp> myCacheOp;
  rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myCacheOp);
  ASSERT_TRUE(rc.IsOk());

  // RepeatOp
  uint32_t numRepeats = 4;
  std::shared_ptr<RepeatOp> myRepeatOp;
  rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());

  // Assign tree relations and root
  rc = myRepeatOp->AddChild(myCacheOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(myRandomDataOp);
  ASSERT_TRUE(rc.IsOk());
  rc = myTree->AssignRoot(myRepeatOp);
  ASSERT_TRUE(rc.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration";
  rc = myTree->Prepare(1);
  ASSERT_TRUE(rc.IsOk());

  std::cout << *myClient << std::endl;

  rc = myTree->Launch();
  ASSERT_TRUE(rc.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator dI(myTree);
  TensorRow tensorList;
  rc = dI.FetchNextTensorRow(&tensorList);
  ASSERT_TRUE(rc.IsOk());
  int rowCount = 0;
  while (!tensorList.empty()) {
    // Don't display these rows, just count them
    MS_LOG(INFO) << "Row fetched #: " << rowCount;
    rc = dI.FetchNextTensorRow(&tensorList);
    ASSERT_TRUE(rc.IsOk());
    rowCount++;
  }
  ASSERT_EQ(rowCount, 40);
  rc = myClient->DestroyCache();
  ASSERT_TRUE(rc.IsOk());
}

@@ -315,6 +315,9 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_long_file_list"
HandleRcExit $? 0 0

PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1
HandleRcExit $? 0 0

for i in $(seq 1 3)
do
  test_name="test_cache_nomap_multiple_cache${i}"

@@ -216,7 +216,7 @@ def test_cache_map_failure1():
          |
        Cache
          |
     ImageFolder
        Coco

    """
    logger.info("Test cache failure 1")

@@ -227,8 +227,9 @@ def test_cache_map_failure1():

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    # This DATA_DIR has 6 images in it
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
                         cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

@@ -302,7 +303,7 @@ def test_cache_map_failure3():
          |
        Batch
          |
     ImageFolder
        Mnist
    """
    logger.info("Test cache failure 3")
    if "SESSION_ID" in os.environ:

@@ -312,8 +313,7 @@ def test_cache_map_failure3():

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
    ds1 = ds1.batch(2)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

@@ -342,7 +342,7 @@ def test_cache_map_failure4():
          |
       Filter
          |
     ImageFolder
       CelebA

    """
    logger.info("Test cache failure 4")

@@ -353,8 +353,8 @@ def test_cache_map_failure4():

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
    ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])

    decode_op = c_vision.Decode()

@@ -382,7 +382,7 @@ def test_cache_map_failure5():
          |
     Map(decode, randomCrop)
          |
     ImageFolder
      Manifest

    """
    logger.info("Test cache failure 5")

@@ -393,8 +393,8 @@ def test_cache_map_failure5():

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    # This dataset has 4 records
    data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

@@ -505,7 +505,7 @@ def test_cache_map_failure8():
          |
       Repeat
          |
     ImageFolder
      Cifar10
    """

    logger.info("Test cache failure 8")

@@ -516,8 +516,7 @@ def test_cache_map_failure8():

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

@@ -545,7 +544,7 @@ def test_cache_map_failure9():
          |
        Take
          |
     ImageFolder
      Cifar100

    """
    logger.info("Test cache failure 9")

@@ -556,8 +555,7 @@ def test_cache_map_failure9():

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
    ds1 = ds1.take(2)

    decode_op = c_vision.Decode()

@@ -587,7 +585,7 @@ def test_cache_map_failure10():
          |
        Skip
          |
     ImageFolder
        VOC

    """
    logger.info("Test cache failure 10")

@@ -598,8 +596,8 @@ def test_cache_map_failure10():

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
    ds1 = ds1.skip(1)

    decode_op = c_vision.Decode()

@@ -1913,6 +1913,217 @@ def test_cache_nomap_long_file_list():
    logger.info("test_cache_nomap_long_file_list Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure1():
    """
    Test nested cache (failure)

        Repeat
          |
        Cache
          |
     Map(decode)
          |
        Cache
          |
      TFRecord

    """
    logger.info("Test cache nomap failure 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        ds1.get_batch_size()
    assert "Nested cache operations" in str(e.value)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Nested cache operations" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure1 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure2():
    """
    Test zip under cache (failure)

        repeat
          |
        Cache
          |
     Map(decode)
          |
         Zip
        |    |
    Random  Random

    """
    logger.info("Test cache nomap failure 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    ds1 = ds.RandomDataset(schema=schema)
    ds2 = ds.RandomDataset(schema=schema)
    dsz = ds.zip((ds1, ds2))
    decode_op = c_vision.Decode()
    dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    dsz = dsz.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in dsz.create_dict_iterator():
            num_iter += 1
    assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure2 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure3():
    """
    Test batch under cache (failure)

        repeat
          |
        Cache
          |
     Map(resize)
          |
        Batch
          |
        Clue
    """
    logger.info("Test cache nomap failure 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train')
    ds1 = ds1.batch(2)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure4():
    """
    Test filter under cache (failure)

        repeat
          |
        Cache
          |
     Map(decode)
          |
        Filter
          |
         CSV

    """
    logger.info("Test cache nomap failure 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'])
    ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure5():
    """
    Test Map with non-deterministic TensorOps under cache (failure)

        repeat
          |
        Cache
          |
     Map(decode, randomCrop)
          |
      TextFile

    """
    logger.info("Test cache nomap failure 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    data = ds.TextFileDataset(TEXT_FILE_DATA_DIR)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

    data = data.map(input_columns=["image"], operations=decode_op)
    data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache)
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator():
            num_iter += 1
    assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure5 Ended.\n')


if __name__ == '__main__':
    test_cache_nomap_basic1()
    test_cache_nomap_basic2()