Implement save/load of autotune dataset pipeline configuration #27943

harshvardhangupta 2021-12-13 14:46:39 -05:00
parent 4aa82cc21e
commit bd61adbb72
71 changed files with 588 additions and 74 deletions

View File

@@ -49,10 +49,8 @@ class ParallelOp : public DatasetOp {
       epoch_sync_flag_(false),
       next_worker_id_(0) {
     // reduce excessive memory usage with high parallelism
-    // when num_workers > 4, reduce op_connector_size to have similar total size if there were only 4 workers
     constexpr int32_t worker_limit = 4;
     if (num_workers_ > worker_limit) {
-      oc_queue_size_ = std::max(1, op_connector_size * worker_limit / num_workers_);
       worker_connector_size_ = std::max(1, op_connector_size * worker_limit / num_workers_);
     }
   }
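
Note: as a worked example of the retained sizing rule, with op_connector_size = 16 and num_workers_ = 8, worker_connector_size_ becomes std::max(1, 16 * 4 / 8) = 8, so the eight workers together buffer roughly what four workers would at the default queue size.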

View File

@@ -67,7 +67,7 @@ class ClueOp : public NonMappableLeafOp {
   // Op name getter
   // @return Name of the current Op
-  std::string Name() const override { return "ClueOp"; }
+  std::string Name() const override { return "CLUEOp"; }
  private:
   // Reads a clue file and loads the data into multiple TensorRows.

View File

@@ -181,7 +181,7 @@ class CsvOp : public NonMappableLeafOp {
   /// Op name getter
   /// @return Name of the current Op
-  std::string Name() const override { return "CsvOp"; }
+  std::string Name() const override { return "CSVOp"; }
   // DatasetName name getter
   // \return DatasetName of the current Op

View File

@@ -61,6 +61,8 @@ std::shared_ptr<DatasetNode> BatchNode::Copy() {
 #else
   auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_);
 #endif
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -155,6 +157,7 @@ Status BatchNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
 Status BatchNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["batch_size"] = batch_size_;
   args["drop_remainder"] = drop_remainder_;
 #ifdef ENABLE_PYTHON
@@ -170,12 +173,14 @@ Status BatchNode::to_json(nlohmann::json *out_json) {
 Status BatchNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                             std::shared_ptr<DatasetNode> *result) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kBatchNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kBatchNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "batch_size", kBatchNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "drop_remainder", kBatchNode));
   int32_t batch_size = json_obj["batch_size"];
   bool drop_remainder = json_obj["drop_remainder"];
   *result = std::make_shared<BatchNode>(ds, batch_size, drop_remainder);
   (*result)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (*result)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
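
For context, a minimal standalone sketch of the save/validate/restore round trip that BatchNode::to_json/from_json implement after this commit. It uses plain nlohmann::json and is not MindSpore code; only the key names mirror the diff above.

    #include <cstdint>
    #include <iostream>
    #include <nlohmann/json.hpp>

    int main() {
      // "Save": the fields BatchNode::to_json emits after this commit.
      nlohmann::json args;
      args["num_parallel_workers"] = 8;
      args["connector_queue_size"] = 16;  // newly serialized field
      args["batch_size"] = 32;
      args["drop_remainder"] = true;

      // "Load": every field is checked before being read back,
      // mirroring the ValidateParamInJson calls in from_json.
      for (const char *key : {"num_parallel_workers", "connector_queue_size",
                              "batch_size", "drop_remainder"}) {
        if (!args.contains(key)) {
          std::cerr << "missing param in JSON: " << key << '\n';
          return 1;
        }
      }
      int32_t queue_size = args["connector_queue_size"];
      std::cout << "restored connector_queue_size = " << queue_size << '\n';
      return 0;
    }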

View File

@@ -228,6 +228,11 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
   return shared_from_this();
 }
+
+std::shared_ptr<DatasetNode> DatasetNode::SetConnectorQueueSize(int32_t connector_queue_size) {
+  connector_que_size_ = connector_queue_size;
+  return shared_from_this();
+}
 std::shared_ptr<DatasetNode> DatasetNode::SetDatasetCache(const std::shared_ptr<DatasetCache> &cache) {
   cache_ = cache;
   return shared_from_this();
@@ -657,6 +662,7 @@ Status DatasetNode::ValidateParams() {
 Status DatasetNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   *out_json = args;
   return Status::OK();
 }
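
As an aside, DatasetNode's setters are fluent: they return the node via shared_from_this() so calls can chain, and callers that ignore the result cast it to (void), as several from_json hunks below do. A standalone sketch of the pattern (illustrative names, not the MindSpore class):

    #include <cstdint>
    #include <memory>

    class Node : public std::enable_shared_from_this<Node> {
     public:
      // Each setter returns the node as a shared_ptr, enabling chaining.
      std::shared_ptr<Node> SetNumWorkers(int32_t num_workers) {
        num_workers_ = num_workers;
        return shared_from_this();
      }
      std::shared_ptr<Node> SetConnectorQueueSize(int32_t queue_size) {
        connector_queue_size_ = queue_size;
        return shared_from_this();
      }

     private:
      int32_t num_workers_ = 1;
      int32_t connector_queue_size_ = 16;
    };

    int main() {
      // shared_from_this() requires the object to already be owned by a shared_ptr.
      auto node = std::make_shared<Node>();
      node->SetNumWorkers(8)->SetConnectorQueueSize(32);  // chained calls
      (void)node->SetNumWorkers(4);  // return value deliberately discarded
      return 0;
    }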

View File

@@ -303,6 +303,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
   /// \brief Getter of the number of workers
   int32_t NumWorkers() const { return num_workers_; }
+
+  /// \brief Getter of the connector queue size
+  int32_t ConnectorQueueSize() { return connector_que_size_; }
   /// \brief Getter of dataset cache
   std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; }
@@ -311,6 +314,8 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
   /// \return Shared pointer to the original object
   std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);
+
+  std::shared_ptr<DatasetNode> SetConnectorQueueSize(int32_t connector_queue_size);
   /// \brief Setter function for DatasetCache
   /// \param[in] cache Shared pointer to DatasetCache
   /// \return Shared pointer to the original object

View File

@@ -50,6 +50,8 @@ std::shared_ptr<DatasetNode> MapNode::Copy() {
   std::vector<std::shared_ptr<TensorOperation>> operations = operations_;
   auto node = std::make_shared<MapNode>(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_,
                                         callbacks_, offload_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -80,8 +82,6 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
       RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache.");
     }
   }
-  // This parameter will be removed with next rebase
-  std::vector<std::string> col_orders;
   auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
   if (!callbacks_.empty()) {
@@ -156,6 +156,7 @@ Status MapNode::to_json(nlohmann::json *out_json) {
   RETURN_UNEXPECTED_IF_NULL(out_json);
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["input_columns"] = input_columns_;
   args["output_columns"] = output_columns_;
   args["project_columns"] = project_columns_;
@@ -192,6 +193,7 @@ Status MapNode::to_json(nlohmann::json *out_json) {
 Status MapNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
                           std::shared_ptr<DatasetNode> *result) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kMapNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kMapNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "input_columns", kMapNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "output_columns", kMapNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "project_columns", kMapNode));
@@ -203,6 +205,7 @@ Status MapNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode>
   RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(json_obj["operations"], &operations));
   *result = std::make_shared<MapNode>(ds, operations, input_columns, output_columns, project_columns);
   (*result)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (*result)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif

View File

@@ -45,6 +45,8 @@ AGNewsNode::AGNewsNode(const std::string &dataset_dir, int64_t num_samples, Shuf
 std::shared_ptr<DatasetNode> AGNewsNode::Copy() {
   auto node =
     std::make_shared<AGNewsNode>(dataset_dir_, num_samples_, shuffle_, usage_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -131,6 +133,7 @@ Status AGNewsNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
 Status AGNewsNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;

View File

@@ -45,6 +45,8 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch
 std::shared_ptr<DatasetNode> AlbumNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -132,6 +134,7 @@ Status AlbumNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["decode"] = decode_;
   args["data_schema"] = schema_path_;
@@ -148,6 +151,7 @@ Status AlbumNode::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status AlbumNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kAlbumNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kAlbumNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kAlbumNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "data_schema", kAlbumNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "column_names", kAlbumNode));
@@ -163,6 +167,7 @@ Status AlbumNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
   (void)((*ds)->SetNumWorkers(json_obj["num_parallel_workers"]));
+  (void)((*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]));
   return Status::OK();
 }
 #endif

View File

@@ -44,6 +44,8 @@ AmazonReviewNode::AmazonReviewNode(const std::string &dataset_dir, const std::st
 std::shared_ptr<DatasetNode> AmazonReviewNode::Copy() {
   auto node =
     std::make_shared<AmazonReviewNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -133,6 +135,7 @@ Status AmazonReviewNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
 Status AmazonReviewNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;

View File

@@ -103,6 +103,7 @@ Status Caltech256Node::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["decode"] = decode_;
   if (cache_ != nullptr) {
@@ -117,6 +118,7 @@ Status Caltech256Node::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status Caltech256Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCaltech256Node));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCaltech256Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCaltech256Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kCaltech256Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCaltech256Node));
@@ -127,7 +129,8 @@ Status Caltech256Node::from_json(nlohmann::json json_obj, std::shared_ptr<Datase
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<Caltech256Node>(dataset_dir, decode, sampler, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif

View File

@@ -46,6 +46,8 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
 std::shared_ptr<DatasetNode> CelebANode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -176,6 +178,7 @@ Status CelebANode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["decode"] = decode_;
   args["extensions"] = extensions_;
@@ -192,6 +195,7 @@ Status CelebANode::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status CelebANode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCelebANode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCelebANode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCelebANode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kCelebANode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCelebANode));
@@ -206,7 +210,8 @@ Status CelebANode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNod
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extension, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif

View File

@@ -38,6 +38,8 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us
 std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -110,6 +112,7 @@ Status Cifar100Node::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {
@@ -124,6 +127,7 @@ Status Cifar100Node::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status Cifar100Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCifar100Node));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCifar100Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCifar100Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kCifar100Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCifar100Node));
@@ -134,7 +138,8 @@ Status Cifar100Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif

View File

@@ -38,6 +38,8 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag
 std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -111,6 +113,7 @@ Status Cifar10Node::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {
@@ -125,6 +128,7 @@ Status Cifar10Node::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status Cifar10Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCifar10Node));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCifar10Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCifar10Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kCifar10Node));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCifar10Node));
@@ -135,7 +139,8 @@ Status Cifar10Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNo
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif

View File

@@ -40,6 +40,8 @@ CityscapesNode::CityscapesNode(const std::string &dataset_dir, const std::string
 std::shared_ptr<DatasetNode> CityscapesNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<CityscapesNode>(dataset_dir_, usage_, quality_mode_, task_, decode_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -127,6 +129,7 @@ Status CityscapesNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["quality_mode"] = quality_mode_;

View File

@@ -40,6 +40,8 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task,
 std::shared_ptr<DatasetNode> CLUENode::Copy() {
   auto node =
     std::make_shared<CLUENode>(dataset_files_, task_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -223,6 +225,7 @@ Status CLUENode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
 Status CLUENode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_files_;
   args["task"] = task_;
   args["usage"] = usage_;
@@ -241,6 +244,7 @@ Status CLUENode::to_json(nlohmann::json *out_json) {
 Status CLUENode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCLUENode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCLUENode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCLUENode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "task", kCLUENode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kCLUENode));
@@ -258,7 +262,8 @@ Status CLUENode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode>
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<CLUENode>(dataset_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent

View File

@@ -34,6 +34,8 @@ void CMUArcticNode::Print(std::ostream &out) const { out << Name(); }
 std::shared_ptr<DatasetNode> CMUArcticNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<CMUArcticNode>(dataset_dir_, name_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -102,6 +104,7 @@ Status CMUArcticNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["name"] = name_;
   if (cache_ != nullptr) {

View File

@@ -46,6 +46,8 @@ std::shared_ptr<DatasetNode> CocoNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node =
     std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_, extra_metadata_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -179,6 +181,7 @@ Status CocoNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["annotation_file"] = annotation_file_;
   args["task"] = task_;
@@ -196,6 +199,7 @@ Status CocoNode::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status CocoNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCocoNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCocoNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCocoNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "annotation_file", kCocoNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "task", kCocoNode));
@@ -212,7 +216,8 @@ Status CocoNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode>
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   bool extra_metadata = json_obj["extra_metadata"];
   *ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache, extra_metadata);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif

View File

@@ -46,6 +46,8 @@ CoNLL2000Node::CoNLL2000Node(const std::string &dataset_dir, const std::string &
 std::shared_ptr<DatasetNode> CoNLL2000Node::Copy() {
   auto node =
     std::make_shared<CoNLL2000Node>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -142,6 +144,7 @@ Status CoNLL2000Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &s
 Status CoNLL2000Node::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;

View File

@@ -49,6 +49,8 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
 std::shared_ptr<DatasetNode> CSVNode::Copy() {
   auto node = std::make_shared<CSVNode>(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_,
                                         shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -163,6 +165,7 @@ Status CSVNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
 Status CSVNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_files"] = dataset_files_;
   args["field_delim"] = std::string(1, field_delim_);
   args["column_names"] = column_names_;
@@ -181,6 +184,7 @@ Status CSVNode::to_json(nlohmann::json *out_json) {
 Status CSVNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCSVNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kCSVNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_files", kCSVNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "field_delim", kCSVNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "column_names", kCSVNode));
@@ -200,7 +204,8 @@ Status CSVNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode>
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<CSVNode>(dataset_files, field_delim.c_str()[0], column_defaults, column_names, num_samples,
                                   shuffle, num_shards, shard_id, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
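
Side note on the CSV hunks above: field_delim is a char, so to_json stores it as a one-character string and from_json pulls the first character back out. A tiny standalone sketch of that round trip (illustrative only):

    #include <cassert>
    #include <string>

    int main() {
      char field_delim = ',';
      std::string serialized(1, field_delim);  // to_json: char -> one-char JSON string
      char restored = serialized.c_str()[0];   // from_json: first char of that string
      assert(restored == field_delim);
      return 0;
    }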

View File

@@ -46,6 +46,8 @@ DBpediaNode::DBpediaNode(const std::string &dataset_dir, const std::string &usag
 std::shared_ptr<DatasetNode> DBpediaNode::Copy() {
   auto node =
     std::make_shared<DBpediaNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -163,6 +165,7 @@ Status DBpediaNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
 Status DBpediaNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;

View File

@@ -41,6 +41,8 @@ DIV2KNode::DIV2KNode(const std::string &dataset_dir, const std::string &usage, c
 std::shared_ptr<DatasetNode> DIV2KNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<DIV2KNode>(dataset_dir_, usage_, downgrade_, scale_, decode_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -136,6 +138,7 @@ Status DIV2KNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["downgrade"] = downgrade_;

View File

@@ -33,6 +33,8 @@ EMnistNode::EMnistNode(const std::string &dataset_dir, const std::string &name,
 std::shared_ptr<DatasetNode> EMnistNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<EMnistNode>(dataset_dir_, name_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -105,6 +107,7 @@ Status EMnistNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["name"] = name_;
   args["usage"] = usage_;

View File

@@ -46,6 +46,8 @@ EnWik9Node::EnWik9Node(const std::string &dataset_dir, int32_t num_samples, Shuf
 std::shared_ptr<DatasetNode> EnWik9Node::Copy() {
   auto node = std::make_shared<EnWik9Node>(dataset_dir_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -135,6 +137,7 @@ Status EnWik9Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
 Status EnWik9Node::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["num_samples"] = num_samples_;
   args["shuffle"] = shuffle_;

View File

@@ -39,6 +39,8 @@ FakeImageNode::FakeImageNode(int32_t num_images, const std::vector<int32_t> &ima
 std::shared_ptr<DatasetNode> FakeImageNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<FakeImageNode>(num_images_, image_size_, num_classes_, base_seed_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -119,7 +121,7 @@ Status FakeImageNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["num_images"] = num_images_;
   args["image_size"] = image_size_;
   args["num_classes"] = num_classes_;

View File

@@ -33,6 +33,8 @@ FashionMnistNode::FashionMnistNode(const std::string &dataset_dir, const std::st
 std::shared_ptr<DatasetNode> FashionMnistNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<FashionMnistNode>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -100,6 +102,7 @@ Status FashionMnistNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {

View File

@@ -42,6 +42,8 @@ FlickrNode::FlickrNode(const std::string &dataset_dir, const std::string &annota
 std::shared_ptr<DatasetNode> FlickrNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<FlickrNode>(dataset_dir_, annotation_file_, decode_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -133,6 +135,7 @@ Status FlickrNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["annotation_file"] = annotation_file_;
   args["decode"] = decode_;
@@ -147,6 +150,7 @@ Status FlickrNode::to_json(nlohmann::json *out_json) {
 Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kFlickrNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kFlickrNode));
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
@@ -158,7 +162,8 @@ Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNod
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, sampler, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 }  // namespace dataset

View File

@@ -55,6 +55,8 @@ std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
   } else {
     node = std::make_shared<GeneratorNode>(generator_function_, schema_, source_len_, sampler_, num_parallel_workers_);
   }
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }

View File

@@ -34,6 +34,8 @@ void GTZANNode::Print(std::ostream &out) const { out << Name(); }
 std::shared_ptr<DatasetNode> GTZANNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<GTZANNode>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -96,6 +98,7 @@ Status GTZANNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {

View File

@@ -48,6 +48,8 @@ std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node =
     std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -115,6 +117,7 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["recursive"] = recursive_;
   args["decode"] = decode_;
@@ -132,6 +135,7 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status ImageFolderNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kImageFolderNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kImageFolderNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kImageFolderNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kImageFolderNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kImageFolderNode));
@@ -154,7 +158,8 @@ Status ImageFolderNode::from_json(nlohmann::json json_obj, std::shared_ptr<Datas
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extension, class_indexing, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif

View File

@@ -38,6 +38,8 @@ IMDBNode::IMDBNode(const std::string &dataset_dir, const std::string &usage, std
 std::shared_ptr<DatasetNode> IMDBNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<IMDBNode>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -106,6 +108,7 @@ Status IMDBNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {
@@ -121,6 +124,7 @@ Status IMDBNode::to_json(nlohmann::json *out_json) {
 Status IMDBNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_UNEXPECTED_IF_NULL(ds);
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kIMDBNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kIMDBNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kIMDBNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kIMDBNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kIMDBNode));
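
Each of these guards reports both the op constant and the missing key, so a stale configuration written before this commit, one with no connector_queue_size entry, fails deserialization with a precise message instead of crashing later. A plausible shape for such a guard, given here as a hedged standalone sketch rather than the actual MindSpore helper:

#include <string>
#include <nlohmann/json.hpp>

// Hedged sketch of a ValidateParamInJson-style check: it only verifies key
// presence and names the op and key on failure; the real helper may do more.
struct SimpleStatus {
  bool ok;
  std::string msg;
};

SimpleStatus ValidateParam(const nlohmann::json &json_obj, const std::string &param_name,
                           const std::string &op_name) {
  if (!json_obj.contains(param_name)) {
    return {false, "Failed to find parameter '" + param_name + "' of " + op_name + " in the JSON config"};
  }
  return {true, ""};
}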


@@ -59,6 +59,8 @@ IWSLT2016Node::IWSLT2016Node(const std::string &dataset_dir, const std::string &
 std::shared_ptr<DatasetNode> IWSLT2016Node::Copy() {
   auto node = std::make_shared<IWSLT2016Node>(dataset_dir_, usage_, language_pair_, valid_set_, test_set_, num_samples_,
                                               shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -157,6 +159,7 @@ Status IWSLT2016Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &s
 Status IWSLT2016Node::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["language_pair"] = language_pair_;


@@ -57,6 +57,8 @@ IWSLT2017Node::IWSLT2017Node(const std::string &dataset_dir, const std::string &
 std::shared_ptr<DatasetNode> IWSLT2017Node::Copy() {
   auto node = std::make_shared<IWSLT2017Node>(dataset_dir_, usage_, language_pair_, num_samples_, shuffle_, num_shards_,
                                               shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -150,6 +152,7 @@ Status IWSLT2017Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &s
 Status IWSLT2017Node::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["language_pair"] = language_pair_;


@@ -33,6 +33,8 @@ KMnistNode::KMnistNode(const std::string &dataset_dir, const std::string &usage,
 std::shared_ptr<DatasetNode> KMnistNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<KMnistNode>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -100,6 +102,7 @@ Status KMnistNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {


@@ -34,6 +34,8 @@ void LibriTTSNode::Print(std::ostream &out) const { out << Name(); }
 std::shared_ptr<DatasetNode> LibriTTSNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<LibriTTSNode>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -107,6 +109,7 @@ Status LibriTTSNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {


@@ -30,6 +30,8 @@ LJSpeechNode::LJSpeechNode(const std::string &dataset_dir, std::shared_ptr<Sampl
 std::shared_ptr<DatasetNode> LJSpeechNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<LJSpeechNode>(dataset_dir_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -104,6 +106,7 @@ Status LJSpeechNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   if (cache_ != nullptr) {
     nlohmann::json cache_args;


@@ -45,6 +45,8 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
 std::shared_ptr<DatasetNode> ManifestNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -138,6 +140,7 @@ Status ManifestNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_file"] = dataset_file_;
   args["usage"] = usage_;
   args["class_indexing"] = class_index_;
@@ -155,6 +158,7 @@ Status ManifestNode::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status ManifestNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kManifestNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kManifestNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_file", kManifestNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kManifestNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kManifestNode));
@@ -175,7 +179,8 @@ Status ManifestNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif


@@ -37,6 +37,8 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr
 std::shared_ptr<DatasetNode> MnistNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -62,7 +64,6 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
   auto op = std::make_shared<MnistOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema),
                                       std::move(sampler_rt));
   op->SetTotalRepeats(GetTotalRepeats());
@@ -104,6 +105,7 @@ Status MnistNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {
@@ -118,6 +120,7 @@ Status MnistNode::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status MnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kMnistNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kMnistNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kMnistNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kMnistNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kMnistNode));
@@ -129,6 +132,7 @@ Status MnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);
   (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif


@@ -50,6 +50,8 @@ void Multi30kNode::Print(std::ostream &out) const {
 std::shared_ptr<DatasetNode> Multi30kNode::Copy() {
   auto node = std::make_shared<Multi30kNode>(dataset_dir_, usage_, language_pair_, num_samples_, shuffle_, num_shards_,
                                              shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -139,6 +141,7 @@ Status Multi30kNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
 Status Multi30kNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["num_samples"] = num_samples_;
   args["shuffle"] = shuffle_;


@@ -49,6 +49,8 @@ PennTreebankNode::PennTreebankNode(const std::string &dataset_dir, const std::st
 std::shared_ptr<DatasetNode> PennTreebankNode::Copy() {
   auto node =
     std::make_shared<PennTreebankNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -131,6 +133,7 @@ Status PennTreebankNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
 Status PennTreebankNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;


@@ -33,6 +33,8 @@ PhotoTourNode::PhotoTourNode(const std::string &dataset_dir, const std::string &
 std::shared_ptr<DatasetNode> PhotoTourNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<PhotoTourNode>(dataset_dir_, name_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -116,6 +118,7 @@ Status PhotoTourNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["name"] = name_;
   args["usage"] = usage_;


@@ -38,6 +38,8 @@ Places365Node::Places365Node(const std::string &dataset_dir, const std::string &
 std::shared_ptr<DatasetNode> Places365Node::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<Places365Node>(dataset_dir_, usage_, small_, decode_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -103,6 +105,7 @@ Status Places365Node::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["small"] = small_;


@@ -40,6 +40,8 @@ QMnistNode::QMnistNode(const std::string &dataset_dir, const std::string &usage,
 std::shared_ptr<DatasetNode> QMnistNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<QMnistNode>(dataset_dir_, usage_, compat_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -114,6 +116,7 @@ Status QMnistNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["compat"] = compat_;
@@ -129,6 +132,7 @@ Status QMnistNode::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status QMnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kQMnistNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kQMnistNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kQMnistNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kQMnistNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "compat", kQMnistNode));
@@ -141,7 +145,8 @@ Status QMnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNod
   std::shared_ptr<DatasetCache> cache = nullptr;
   RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
   *ds = std::make_shared<QMnistNode>(dataset_dir, usage, compat, sampler, cache);
-  (*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
+  (void)(*ds)->SetConnectorQueueSize(json_obj["connector_queue_size"]);
   return Status::OK();
 }
 #endif


@@ -33,6 +33,8 @@ SBUNode::SBUNode(const std::string &dataset_dir, bool decode, const std::shared_
 std::shared_ptr<DatasetNode> SBUNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<SBUNode>(dataset_dir_, decode_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -108,6 +110,7 @@ Status SBUNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["decode"] = decode_;
   if (cache_ != nullptr) {


@@ -31,6 +31,8 @@ SemeionNode::SemeionNode(const std::string &dataset_dir, const std::shared_ptr<S
 std::shared_ptr<DatasetNode> SemeionNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<SemeionNode>(dataset_dir_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -98,6 +100,7 @@ Status SemeionNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   if (cache_ != nullptr) {
     nlohmann::json cache_args;


@@ -46,6 +46,8 @@ SogouNewsNode::SogouNewsNode(const std::string &dataset_dir, const std::string &
 std::shared_ptr<DatasetNode> SogouNewsNode::Copy() {
   auto node =
     std::make_shared<SogouNewsNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -136,6 +138,7 @@ Status SogouNewsNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &s
 Status SogouNewsNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;


@@ -28,6 +28,8 @@ SpeechCommandsNode::SpeechCommandsNode(const std::string &dataset_dir, const std
 std::shared_ptr<DatasetNode> SpeechCommandsNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<SpeechCommandsNode>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -104,6 +106,7 @@ Status SpeechCommandsNode::to_json(nlohmann::json *out_json) {
   args["sampler"] = sampler_args;
   args["usage"] = usage_;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   if (cache_ != nullptr) {
     nlohmann::json cache_args;


@@ -46,6 +46,8 @@ void SQuADNode::Print(std::ostream &out) const {
 std::shared_ptr<DatasetNode> SQuADNode::Copy() {
   auto node = std::make_shared<SQuADNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -137,6 +139,7 @@ Status SQuADNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
 Status SQuADNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;


@@ -33,6 +33,8 @@ STL10Node::STL10Node(const std::string &dataset_dir, const std::string &usage, s
 std::shared_ptr<DatasetNode> STL10Node::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<STL10Node>(dataset_dir_, usage_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -107,6 +109,7 @@ Status STL10Node::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   if (cache_ != nullptr) {


@@ -37,6 +37,8 @@ TedliumNode::TedliumNode(const std::string &dataset_dir, const std::string &rele
 std::shared_ptr<DatasetNode> TedliumNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<TedliumNode>(dataset_dir_, release_, usage_, extensions_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -136,6 +138,7 @@ Status TedliumNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["release"] = release_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;


@@ -43,6 +43,8 @@ TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_s
 std::shared_ptr<DatasetNode> TextFileNode::Copy() {
   auto node = std::make_shared<TextFileNode>(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -132,6 +134,7 @@ Status TextFileNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
 Status TextFileNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_files"] = dataset_files_;
   args["num_samples"] = num_samples_;
   args["shuffle"] = shuffle_;
@@ -148,6 +151,7 @@ Status TextFileNode::to_json(nlohmann::json *out_json) {
 Status TextFileNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kTextFileNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kTextFileNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_files", kTextFileNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_samples", kTextFileNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shuffle", kTextFileNode));


@@ -41,6 +41,8 @@ std::shared_ptr<DatasetNode> TFRecordNode::Copy() {
     node = std::make_shared<TFRecordNode>(dataset_files_, schema_path_, columns_list_, num_samples_, shuffle_,
                                           num_shards_, shard_id_, shard_equal_rows_, cache_);
   }
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -177,6 +179,7 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
   RETURN_UNEXPECTED_IF_NULL(out_json);
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_files"] = dataset_files_;
   args["columns_list"] = columns_list_;
   args["num_samples"] = num_samples_;
@@ -206,6 +209,7 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
 Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kTFRecordNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kTFRecordNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_files", kTFRecordNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "columns_list", kTFRecordNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_samples", kTFRecordNode));


@@ -44,6 +44,8 @@ UDPOSNode::UDPOSNode(const std::string &dataset_dir, const std::string &usage, i
 std::shared_ptr<DatasetNode> UDPOSNode::Copy() {
   auto node = std::make_shared<UDPOSNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -136,6 +138,7 @@ Status UDPOSNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
 Status UDPOSNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;


@@ -45,6 +45,8 @@ USPSNode::USPSNode(const std::string &dataset_dir, const std::string &usage, int
 std::shared_ptr<DatasetNode> USPSNode::Copy() {
   auto node = std::make_shared<USPSNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -125,6 +127,7 @@ Status USPSNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
 Status USPSNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;


@@ -48,6 +48,8 @@ std::shared_ptr<DatasetNode> VOCNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node =
     std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_, extra_metadata_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -167,6 +169,7 @@ Status VOCNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["task"] = task_;
   args["usage"] = usage_;
@@ -185,6 +188,7 @@ Status VOCNode::to_json(nlohmann::json *out_json) {
 #ifndef ENABLE_ANDROID
 Status VOCNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kVOCNode));
+  RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "connector_queue_size", kVOCNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kVOCNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "task", kVOCNode));
   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kVOCNode));


@@ -36,6 +36,8 @@ WIDERFaceNode::WIDERFaceNode(const std::string &dataset_dir, const std::string &
 std::shared_ptr<DatasetNode> WIDERFaceNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<WIDERFaceNode>(dataset_dir_, usage_, decode_, sampler, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -116,6 +118,7 @@ Status WIDERFaceNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["decode"] = decode_;
   args["usage"] = usage_;


@@ -49,6 +49,8 @@ WikiTextNode::WikiTextNode(const std::string &dataset_dir, const std::string &us
 std::shared_ptr<DatasetNode> WikiTextNode::Copy() {
   auto node =
     std::make_shared<WikiTextNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -131,6 +133,7 @@ Status WikiTextNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
 Status WikiTextNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;


@@ -47,6 +47,8 @@ YahooAnswersNode::YahooAnswersNode(const std::string &dataset_dir, const std::st
 std::shared_ptr<DatasetNode> YahooAnswersNode::Copy() {
   auto node =
     std::make_shared<YahooAnswersNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -170,6 +172,7 @@ Status YahooAnswersNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
 Status YahooAnswersNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;


@@ -47,6 +47,8 @@ YelpReviewNode::YelpReviewNode(const std::string &dataset_dir, const std::string
 std::shared_ptr<DatasetNode> YelpReviewNode::Copy() {
   auto node =
     std::make_shared<YelpReviewNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  node->SetNumWorkers(num_workers_);
+  node->SetConnectorQueueSize(connector_que_size_);
   return node;
 }
@@ -134,6 +136,7 @@ Status YelpReviewNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &
 Status YelpReviewNode::to_json(nlohmann::json *out_json) {
   nlohmann::json args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   args["usage"] = usage_;
   args["num_samples"] = num_samples_;


@@ -102,6 +102,7 @@ Status YesNoNode::to_json(nlohmann::json *out_json) {
   RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
   args["sampler"] = sampler_args;
   args["num_parallel_workers"] = num_workers_;
+  args["connector_queue_size"] = connector_que_size_;
   args["dataset_dir"] = dataset_dir_;
   if (cache_ != nullptr) {
     nlohmann::json cache_args;


@@ -48,10 +48,11 @@ Status DeepCopyPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modifi
   // Clone a new copy of this node
   std::shared_ptr<DatasetNode> new_node = node->Copy();
-  // Temporary fix to set the num_workers to each cloned node.
+  // Temporary fix to set the num_workers and connector_queue_size on each cloned node.
   // This can be improved by adding a new method in the base class DatasetNode to transfer the properties to
   // the cloned node. Each derived class's Copy() will need to include this method.
   new_node->SetNumWorkers(node->NumWorkers());
+  new_node->SetConnectorQueueSize(node->ConnectorQueueSize());
   // This method below assumes a DFS walk and from the first child to the last child.
   // Future: A more robust implementation that does not depend on the above assumption.
   RETURN_IF_NOT_OK(parent_->AppendChild(new_node));
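
The comment in this hunk already names the cleaner long-term design: one base-class method that transfers every tunable property, so a future knob is added in a single place instead of in this pass and in every Copy() override. A hedged illustration of that idea follows; CopyTunablePropertiesTo is a hypothetical name, not an existing MindSpore API, and the class is deliberately simplified (for example, the real setters return a value that some call sites explicitly discard with a (void) cast).

#include <cstdint>

// Simplified stand-in for DatasetNode, just enough to show the pattern.
class DatasetNode {
 public:
  virtual ~DatasetNode() = default;
  int32_t NumWorkers() const { return num_workers_; }
  int32_t ConnectorQueueSize() const { return connector_que_size_; }
  void SetNumWorkers(int32_t num_workers) { num_workers_ = num_workers; }
  void SetConnectorQueueSize(int32_t queue_size) { connector_que_size_ = queue_size; }

  // Hypothetical helper: transfer every tunable property in one call, so that
  // adding a new knob later only touches this method.
  void CopyTunablePropertiesTo(DatasetNode *other) const {
    other->SetNumWorkers(num_workers_);
    other->SetConnectorQueueSize(connector_que_size_);
  }

 private:
  int32_t num_workers_{1};
  int32_t connector_que_size_{16};
};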


@@ -21,10 +21,11 @@
 #include <memory>
 #include <utility>
 #include <vector>
+#include <string>
 #ifndef ENABLE_ANDROID
 #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
+#include "minddata/dataset/engine/serdes.h"
 #endif
 #include "minddata/dataset/util/task_manager.h"

 namespace mindspore {
@@ -53,8 +54,22 @@ Status AutoTune::Main() {
   } else if (step_gap_ == 0) {
     mode_ = AutoTuneMode::kAutoTuneModeEpoch;
   }
+  const bool nodes_offloaded = !tree_adapter_->GetOffloadJson().empty();
+  if (nodes_offloaded) {
+    // When nodes are offloaded they are removed from the optimized IR tree.
+    // Serializing the optimized IR Tree and then deserializing will not work.
+    MS_LOG(WARNING) << "Some nodes have been offloaded. AutoTune is unable to write the autotune configuration to "
+                       "disk. Disable offload to prevent this from happening.";
+  }
+  bool output_final_config = save_autoconfig_ && !nodes_offloaded;
+  bool output_intermediate_config = save_intermediate_autoconfig_ && output_final_config;
   Status rc;
+  int loop_cnt = 0;
   while (!this_thread::is_interrupted() && !(tree_adapter_->tree_->isFinished())) {
+#ifndef ENABLE_ANDROID
+    auto last_epoch = cur_epoch_;
+    auto last_step = cur_step_;
+#endif
     if (mode_ == AutoTuneMode::kAutoTuneModeEpoch) {
       rc = RunIterationEpoch();
     } else if (mode_ == AutoTuneMode::kAutoTuneModeStep) {
@@ -65,6 +80,16 @@ Status AutoTune::Main() {
       RETURN_IF_NOT_OK(profiling_manager_->Stop());
       break;
     }
+#ifndef ENABLE_ANDROID
+    if (last_epoch != cur_epoch_ || last_step != cur_step_) {
+      if (output_intermediate_config &&
+          (SaveAutotuneConfig(tree_adapter_->tree_->GetUniqueId() + "_autotune_" + std::to_string(loop_cnt) + ".json")
+             .IsError())) {
+        MS_LOG(WARNING) << "Failed to write current iteration autotune configuration to disk";
+      }
+      ++loop_cnt;
+    }
+#endif
     rc = cv_.WaitFor(&_lock, GlobalContext::config_manager()->monitor_sampling_interval());
     // the thread may be interrupted for tree termination when waiting (we should not report error in this case)
     if (rc.IsError() && rc != StatusCode::kMDInterrupted) {
@@ -79,11 +104,43 @@ Status AutoTune::Main() {
                << "mindspore.dataset.config.set_num_parallel_workers";
   MS_LOG(INFO) << "Suggest to choose maximum prefetch_size from tuned result and set by global setting API: "
                << "mindspore.dataset.config.set_prefetch_size";
+#ifndef ENABLE_ANDROID
+  if (output_final_config && (SaveAutotuneConfig(autotune_json_filepath_).IsError())) {
+    MS_LOG(WARNING) << "Failed to write final autotune configuration to disk";
+  }
+#endif
   return Status::OK();
 }

-void AutoTune::PrintTreeConfiguration() {
-  ExecutionTree *tree = tree_adapter_->tree_.get();
+#ifndef ENABLE_ANDROID
+Status AutoTune::SaveAutotuneConfig(const std::string &file_name) {
+  RETURN_IF_NOT_OK(SetAutotuneConfigJson());
+  // The Execution Tree is built by visiting the optimized IR Tree in DFS order.
+  // So we visit the optimized IR tree in DFS order and try to match each IR node with its corresponding dataset op.
+  RETURN_IF_NOT_OK(Serdes::UpdateOptimizedIRTreeJSON(&autotune_config_json_, ops_));
+  RETURN_IF_NOT_OK(Serdes::SaveJSONToFile(autotune_config_json_, file_name));
+  return Status::OK();
+}
+
+Status AutoTune::SetAutotuneConfigJson() {
+  if (autotune_config_json_.empty()) {
+    nlohmann::json out_json;
+    RETURN_IF_NOT_OK(Serdes::SaveToJSON(tree_adapter_->RootIRNode(), "", &out_json));
+    // We do not want to serialize TransferNode/DeviceQueueOp
+    if (out_json["op_type"] == kTransferNode) {
+      CHECK_FAIL_RETURN_UNEXPECTED(
+        out_json["children"].size() == 1,
+        "Expected Transfer node to have exactly 1 child but it has " + std::to_string(out_json["children"].size()));
+      out_json = out_json["children"][0];
+    }
+    autotune_config_json_ = std::move(out_json);
+  }
+  return Status::OK();
+}
+#endif
+
+void AutoTune::PrintTreeConfiguration() const {
+  ExecutionTree const *tree = tree_adapter_->tree_.get();
   for (auto itr = tree->begin(); itr != tree->end(); itr++) {
     if (!itr->inlined() && itr->Name() != "DeviceQueueOp") {
       MS_LOG(INFO) << itr->NameWithID() << " num_parallel_workers: " << itr->NumWorkers()
@@ -106,7 +163,7 @@ Status AutoTune::LaunchThread() {
 }

 Status AutoTune::CollectOpsInfo() {
-  ExecutionTree *tree = tree_adapter_->tree_.get();
+  ExecutionTree const *tree = tree_adapter_->tree_.get();
   RETURN_UNEXPECTED_IF_NULL(tree);
   for (auto itr = tree->begin(); itr != tree->end(); ++itr) {
     ops_[itr->id()] = itr.get();
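
The file that SaveAutotuneConfig writes is just the serialized optimized IR tree with the two tuned knobs rewritten in place, which is what lets it be deserialized back into a pipeline later. The following is an illustrative sketch of that JSON's shape, built with nlohmann::json; the op_type values and numbers are invented for the example, and real files also carry each op's dataset-specific fields.

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  // One leaf source op and one parent op, each carrying the knobs that
  // Serdes::UpdateOptimizedIRTreeJSON rewrites before the file is saved.
  nlohmann::json leaf = {{"op_type", "ImageFolderDataset"},
                         {"num_parallel_workers", 4},
                         {"connector_queue_size", 16},
                         {"children", nlohmann::json::array()}};
  nlohmann::json root = {{"op_type", "Batch"},
                         {"num_parallel_workers", 8},
                         {"connector_queue_size", 32},
                         {"children", nlohmann::json::array({leaf})}};
  std::cout << root.dump(2) << std::endl;
  return 0;
}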


@@ -51,7 +51,18 @@ class AutoTune {
   Status Main();

   /// \brief Helper to print the tree configuration
-  void PrintTreeConfiguration();
+  void PrintTreeConfiguration() const;
+
+#ifndef ENABLE_ANDROID
+  /// \brief Serialize the dataset pipeline and save the AutoTune configuration (workers and queue size) to a JSON file
+  /// \param file_name Name of the file
+  /// \return Status object
+  Status SaveAutotuneConfig(const std::string &file_name);
+
+  /// \brief Initialize autotune_config_json_ by serializing the optimized IR tree, if not already done
+  /// \return Status code
+  Status SetAutotuneConfigJson();
+#endif

   /// Function to collect info from the tree
   /// \return Status code
@@ -195,8 +206,14 @@ class AutoTune {
   /// True if should save AutoTune configuration
   bool save_autoconfig_;

+  /// Flag to enable saving of intermediate autotune config to disk
+  bool save_intermediate_autoconfig_{false};
+
   /// Filepath name of the final AutoTune Configuration JSON file
   std::string autotune_json_filepath_;

+  /// Serialized JSON of the optimized IR tree that holds the updated configuration (workers and queue size)
+  nlohmann::json autotune_config_json_;
 };
 }  // namespace dataset
 }  // namespace mindspore


@ -30,6 +30,13 @@ std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<Tensor
Status Serdes::SaveToJSON(std::shared_ptr<DatasetNode> node, const std::string &filename, nlohmann::json *out_json) { Status Serdes::SaveToJSON(std::shared_ptr<DatasetNode> node, const std::string &filename, nlohmann::json *out_json) {
RETURN_UNEXPECTED_IF_NULL(node); RETURN_UNEXPECTED_IF_NULL(node);
RETURN_UNEXPECTED_IF_NULL(out_json); RETURN_UNEXPECTED_IF_NULL(out_json);
// If an optimized IR Tree is sent (use-case for MD AutoTune), ignore Top and EpochCtrl nodes
if (node->Name() == "Top" || node->Name() == "EpochCtrl") {
CHECK_FAIL_RETURN_UNEXPECTED(
node->Children().size() == 1,
"Expected " + node->Name() + " to have exactly 1 child but it has " + std::to_string(node->Children().size()));
return SaveToJSON(node->Children()[0], filename, out_json);
}
// Dump attributes of current node to json string // Dump attributes of current node to json string
nlohmann::json args; nlohmann::json args;
RETURN_IF_NOT_OK(node->to_json(&args)); RETURN_IF_NOT_OK(node->to_json(&args));
@ -55,7 +62,7 @@ Status Serdes::SaveToJSON(std::shared_ptr<DatasetNode> node, const std::string &
return Status::OK(); return Status::OK();
} }
Status Serdes::SaveJSONToFile(nlohmann::json json_string, const std::string &file_name) { Status Serdes::SaveJSONToFile(const nlohmann::json &json_string, const std::string &file_name) {
try { try {
std::optional<std::string> dir = ""; std::optional<std::string> dir = "";
std::optional<std::string> local_file_name = ""; std::optional<std::string> local_file_name = "";
@ -116,7 +123,7 @@ Status Serdes::ConstructPipeline(nlohmann::json json_obj, std::shared_ptr<Datase
RETURN_IF_NOT_OK(CreateNode(child_ds, json_obj, ds)); RETURN_IF_NOT_OK(CreateNode(child_ds, json_obj, ds));
} else { } else {
std::vector<std::shared_ptr<DatasetNode>> datasets; std::vector<std::shared_ptr<DatasetNode>> datasets;
for (auto child_json_obj : json_obj["children"]) { for (const auto &child_json_obj : json_obj["children"]) {
RETURN_IF_NOT_OK(ConstructPipeline(child_json_obj, &child_ds)); RETURN_IF_NOT_OK(ConstructPipeline(child_json_obj, &child_ds));
datasets.push_back(child_ds); datasets.push_back(child_ds);
} }
@ -380,6 +387,64 @@ Status Serdes::ParseMindIRPreprocess(const std::vector<std::string> &map_json_st
return Status::OK(); return Status::OK();
} }
Status Serdes::UpdateOptimizedIRTreeJSON(nlohmann::json *serialized_json,
const std::map<int32_t, std::shared_ptr<DatasetOp>> &op_map) {
RETURN_UNEXPECTED_IF_NULL(serialized_json);
int32_t op_id = 0;
return RecurseUpdateOptimizedIRTreeJSON(serialized_json, &op_id, op_map);
}
bool IsDatasetOpMatchIRNode(std::string_view ir_node_name, std::string_view dataset_op_name) {
// Helper function to match IR Node name to its dataset op name
if (ir_node_name == kSyncWaitNode) {
return dataset_op_name == kBarrierOp;
} else if (ir_node_name == kCifar10Node || ir_node_name == kCifar100Node) {
return dataset_op_name == "CifarOp";
} else if (ir_node_name == kMindDataNode) {
return dataset_op_name == "MindRecordOp";
} else if (ir_node_name == kRandomNode) {
return dataset_op_name == "RandomDataOp";
} else if (ir_node_name == kTFRecordNode) {
return dataset_op_name == "TFReaderOp";
} else if (ir_node_name == kIWSLT2016Node || ir_node_name == kIWSLT2017Node) {
return dataset_op_name == "IWSLTOp";
} else {
// Generic way of matching, special cases handled above. Special cases will evolve over time.
return ir_node_name.substr(0, ir_node_name.find("Dataset")) ==
dataset_op_name.substr(0, dataset_op_name.find("Op"));
}
}
Status Serdes::RecurseUpdateOptimizedIRTreeJSON(nlohmann::json *serialized_json, int32_t *op_id,
const std::map<int32_t, std::shared_ptr<DatasetOp>> &op_map) {
RETURN_UNEXPECTED_IF_NULL(serialized_json);
RETURN_UNEXPECTED_IF_NULL(op_id);
std::string ir_node_name = (*serialized_json)["op_type"];
MS_LOG(INFO) << "Visiting IR Node: " << ir_node_name;
// Each IR Node should have a corresponding dataset node in the execution tree but the reverse is not necessarily true
while (!IsDatasetOpMatchIRNode(ir_node_name, op_map.find(*op_id)->second->Name())) {
// During the construction of execution tree, extra dataset nodes may have been inserted
// Skip dataset ops unless we get to the expected node
MS_LOG(INFO) << "\tSkipping dataset op: " << op_map.find(*op_id)->second->NameWithID();
++(*op_id);
CHECK_FAIL_RETURN_UNEXPECTED(static_cast<size_t>(*op_id) < op_map.size(), "op_id is out of bounds");
}
MS_LOG(INFO) << "\tMatch found for IR Node: " << ir_node_name
<< " with dataset op: " << op_map.find(*op_id)->second->NameWithID();
if (!op_map.find(*op_id)->second->inlined() && serialized_json->contains("num_parallel_workers") &&
serialized_json->contains("connector_queue_size")) {
(*serialized_json)["num_parallel_workers"] = op_map.find(*op_id)->second->NumWorkers();
(*serialized_json)["connector_queue_size"] = op_map.find(*op_id)->second->ConnectorCapacity();
}
++(*op_id);
auto num_children = (*serialized_json)["children"].size();
for (size_t i = 0; i < num_children; ++i) {
RETURN_IF_NOT_OK(RecurseUpdateOptimizedIRTreeJSON(&(*serialized_json)["children"][i], op_id, op_map));
}
return Status::OK();
}
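
The intended call pattern, as a hedged sketch: assuming a TreeAdapter that has already been compiled and run (this mirrors the unit test added later in this commit), the caller serializes the optimized IR tree, builds the op id to DatasetOp map from the execution tree, and then patches the JSON:

// Sketch only; assumes this runs inside a Status-returning function and that
// `tree_adapter` is a compiled TreeAdapter (see the unit test in this commit).
nlohmann::json at_json;
RETURN_IF_NOT_OK(Serdes::SaveToJSON(tree_adapter->RootIRNode(), "", &at_json));
// Build the op id -> DatasetOp mapping from the execution tree.
std::map<int32_t, std::shared_ptr<DatasetOp>> op_map;
const ExecutionTree *tree = tree_adapter->GetExecutionTree();
for (auto itr = tree->begin(); itr != tree->end(); ++itr) {
  op_map[itr->id()] = itr.get();
}
// Overwrite num_parallel_workers and connector_queue_size with the tuned values.
RETURN_IF_NOT_OK(Serdes::UpdateOptimizedIRTreeJSON(&at_json, op_map));
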
// In the current stage, there is a cyclic dependency between libmindspore.so and c_dataengine.so, // In the current stage, there is a cyclic dependency between libmindspore.so and c_dataengine.so,
// we make a C function here and dlopen'd by libmindspore.so to avoid linking explicitly, // we make a C function here and dlopen'd by libmindspore.so to avoid linking explicitly,
// will be fixed after decoupling libmindspore.so into multiple submodules // will be fixed after decoupling libmindspore.so into multiple submodules

View File

@ -154,6 +154,13 @@ class Serdes {
/// \return Status The status code returned /// \return Status The status code returned
static Status SaveToJSON(std::shared_ptr<DatasetNode> node, const std::string &filename, nlohmann::json *out_json); static Status SaveToJSON(std::shared_ptr<DatasetNode> node, const std::string &filename, nlohmann::json *out_json);
/// \brief Function to update the parameters [num_parallel_workers, connector_queue_size] in the serialized JSON
/// object of the optimized IR tree
/// \param[in, out] serialized_json The serialized JSON object of the optimized IR tree
/// \param[in] op_map An ID to DatasetOp mapping
static Status UpdateOptimizedIRTreeJSON(nlohmann::json *serialized_json,
const std::map<int32_t, std::shared_ptr<DatasetOp>> &op_map);
/// \brief function to de-serialize JSON file to IR tree /// \brief function to de-serialize JSON file to IR tree
/// \param[in] json_filepath input path of json file /// \param[in] json_filepath input path of json file
/// \param[out] ds The deserialized dataset /// \param[out] ds The deserialized dataset
@ -185,13 +192,13 @@ class Serdes {
static Status ParseMindIRPreprocess(const std::vector<std::string> &map_json_string, static Status ParseMindIRPreprocess(const std::vector<std::string> &map_json_string,
std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph); std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph);
protected:
/// \brief Helper function to save JSON to a file /// \brief Helper function to save JSON to a file
/// \param[in] json_string The JSON string to be saved to the file /// \param[in] json_string The JSON string to be saved to the file
/// \param[in] file_name The file name /// \param[in] file_name The file name
/// \return Status The status code returned /// \return Status The status code returned
static Status SaveJSONToFile(nlohmann::json json_string, const std::string &file_name); static Status SaveJSONToFile(const nlohmann::json &json_string, const std::string &file_name);
protected:
/// \brief Function to determine type of the node - dataset node if no dataset exists or operation node /// \brief Function to determine type of the node - dataset node if no dataset exists or operation node
/// \param[in] child_ds children datasets that is already created /// \param[in] child_ds children datasets that is already created
/// \param[in] json_obj json object to read out type of the node /// \param[in] json_obj json object to read out type of the node
@ -221,6 +228,14 @@ class Serdes {
static std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)> static std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)>
InitializeFuncPtr(); InitializeFuncPtr();
/// \brief Helper function to perform recursive DFS on the optimized IR tree and to match each IR node with its
/// corresponding dataset op
/// \param [in, out] serialized_json The serialized JSON object of the optimized IR tree
/// \param [in, out] op_id The op id in the execution tree from which to continue the IR node to DatasetOp matching
/// \param [in] op_map An ID to DatasetOp mapping
static Status RecurseUpdateOptimizedIRTreeJSON(nlohmann::json *serialized_json, int32_t *op_id,
const std::map<int32_t, std::shared_ptr<DatasetOp>> &op_map);
private: private:
static std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)> static std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)>
func_ptr_; func_ptr_;

View File

@ -168,14 +168,11 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) {
RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op)); RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op));
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
// Note: We will gradually move the pre pass, optimizer pass, and post pass
// on ExecutionTree to perform on IR tree.
// Prepare the tree // Prepare the tree
RETURN_IF_NOT_OK(tree_->Prepare()); RETURN_IF_NOT_OK(tree_->Prepare());
// After the tree is prepared, the col_name_id_map can safely be obtained // After the tree is prepared, the col_name_id_map can safely be obtained
column_name_map_ = tree_->root()->column_name_id_map(); column_name_map_ = tree_->root()->column_name_id_map();
return Status::OK(); return Status::OK();
} }
@ -219,7 +216,6 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
RETURN_IF_NOT_OK(Build(root_ir_)); RETURN_IF_NOT_OK(Build(root_ir_));
tree_state_ = kCompileStateReady; tree_state_ = kCompileStateReady;
return Status::OK(); return Status::OK();
} }

View File

@ -62,6 +62,8 @@ class TreeAdapter {
// Return the root node of the IR after cloned from the parsed IR tree // Return the root node of the IR after cloned from the parsed IR tree
std::shared_ptr<DatasetNode> RootIRNode() const { return root_ir_; } std::shared_ptr<DatasetNode> RootIRNode() const { return root_ir_; }
const ExecutionTree *GetExecutionTree() const { return tree_.get(); }
// This is the main method TreeConsumer uses to interact with TreeAdapter // This is the main method TreeConsumer uses to interact with TreeAdapter
// 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared) // 1. GetNext will Launch() the ExeTree on its first call by iterator (tree is already prepared)
// 2. GetNext will return empty row when eoe/eof is obtained // 2. GetNext will return empty row when eoe/eof is obtained
@ -87,7 +89,6 @@ class TreeAdapter {
// Return Offload Json // Return Offload Json
nlohmann::json GetOffloadJson(); nlohmann::json GetOffloadJson();
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
/// \brief Setter for Profiling Manager /// \brief Setter for Profiling Manager
Status SetProfilingManagerPtr(std::shared_ptr<ProfilingManager> profiling_manager, Status SetProfilingManagerPtr(std::shared_ptr<ProfilingManager> profiling_manager,
@ -119,8 +120,8 @@ class TreeAdapter {
std::unordered_map<std::string, int32_t> column_name_map_; std::unordered_map<std::string, int32_t> column_name_map_;
std::shared_ptr<DatasetNode> root_ir_; std::shared_ptr<DatasetNode> root_ir_;
std::unique_ptr<ExecutionTree> tree_; // current connector capacity of root op, used for profiling std::unique_ptr<ExecutionTree> tree_;
bool optimize_; // Flag to enable optional optimization pass bool optimize_; // Flag to enable optional optimization pass
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
std::shared_ptr<ProfilingManager> profiling_manager_; // Profiling manager std::shared_ptr<ProfilingManager> profiling_manager_; // Profiling manager
std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data

View File

@ -45,7 +45,7 @@ def create_model():
return model_simple return model_simple
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1): def create_dataset(data_path, batch_size=32, num_parallel_workers=1):
""" """
Create dataset for train or test Create dataset for train or test
""" """
@ -71,12 +71,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers) mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers) mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers) mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
# Apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds return mnist_ds
@ -85,29 +80,44 @@ def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.forked @pytest.mark.forked
def test_autotune_train_simple_model(): def test_autotune_train_simple_model(tmp_path):
""" """
Feature: Dataset AutoTune Feature: Dataset AutoTune
Description: Test Dataset AutoTune for Training of a Simple Model Description: Test Dataset AutoTune for training of a simple model and deserialization of the written AutoTune config file
Expectation: Training completes successfully Expectation: Training and data deserialization complete successfully
""" """
original_seed = ds.config.get_seed() original_seed = ds.config.get_seed()
set_seed(1) set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_context(enable_graph_kernel=True) context.set_context(enable_graph_kernel=True)
at_config_filename = "test_autotune_train_simple_model_at_config.json"
# Enable Dataset AutoTune # Enable Dataset AutoTune
original_autotune = ds.config.get_enable_autotune() original_autotune = ds.config.get_enable_autotune()
ds.config.set_enable_autotune(True) ds.config.set_enable_autotune(True, str(tmp_path) + at_config_filename)
ds_train = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32, 1) ds_train = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32)
model = create_model() model = create_model()
print("Start Training.") print("Start training.")
epoch_size = 10 epoch_size = 10
start_time = time.time()
model.train(epoch_size, ds_train) model.train(epoch_size, ds_train)
print("Training is finished.") print("Training finished. Took {}s".format(time.time() - start_time))
ds.config.set_enable_autotune(False)
ds_train_deserialized = ds.deserialize(json_filepath=str(tmp_path) + at_config_filename)
num = 0
for data1, data2 in zip(ds_train.create_dict_iterator(num_epochs=1, output_numpy=True),
ds_train_deserialized.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(data1['image'], data2['image'])
np.testing.assert_array_equal(data1['label'], data2['label'])
num += 1
assert num == 1875
# Restore settings # Restore settings
ds.config.set_enable_autotune(original_autotune) ds.config.set_enable_autotune(original_autotune)
@ -188,5 +198,5 @@ def test_autotune_pymultiproc_train_simple_model():
if __name__ == "__main__": if __name__ == "__main__":
test_autotune_train_simple_model() test_autotune_train_simple_model("")
test_autotune_pymultiproc_train_simple_model() test_autotune_pymultiproc_train_simple_model()

View File

@ -21,24 +21,15 @@
#include "minddata/dataset/include/dataset/transforms.h" #include "minddata/dataset/include/dataset/transforms.h"
// IR non-leaf nodes // IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#include "minddata/dataset/engine/tree_modifier.h" #include "minddata/dataset/engine/tree_modifier.h"
#include "minddata/dataset/engine/serdes.h"
using namespace mindspore::dataset; using namespace mindspore::dataset;
using mindspore::dataset::Tensor; using mindspore::dataset::Tensor;
class MindDataTestTreeAdapter : public UT::DatasetOpTesting { class MindDataTestTreeAdapter : public UT::DatasetOpTesting {};
protected:
};
TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) { TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestSimpleTreeAdapter."; MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestSimpleTreeAdapter.";
@ -148,6 +139,121 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos); EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
} }
// Feature: Test for serializing and deserializing an optimized IR tree after the tree has been modified with
// TreeModifier, i.e. indirectly through AutoTune.
// Description: Create a simple tree, modify its workers and queue sizes, serialize the optimized IR tree, obtain a new
// tree via deserialization, and then compare the serialized output of the new optimized IR tree with the first.
// Expectation: No failures.
TEST_F(MindDataTestTreeAdapter, TestOptimizedTreeSerializeDeserializeForAutoTune) {
MS_LOG(INFO) << "Doing MindDataTestTreeAdapter-TestOptimizedTreeSerializeDeserializeForAutoTune.";
// Create a CSVDataset, with single CSV file
std::string train_file = datasets_root_path_ + "/testCSV/1.csv";
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse);
ASSERT_NE(ds, nullptr);
ds = ds->Project({"col1"});
ASSERT_NE(ds, nullptr);
ds = ds->Repeat(2);
ASSERT_NE(ds, nullptr);
auto to_number = std::make_shared<text::ToNumber>(mindspore::DataType::kNumberTypeInt32);
ASSERT_NE(to_number, nullptr);
ds = ds->Map({to_number}, {"col1"}, {"col1"});
ds->SetNumWorkers(1);
ds = ds->Batch(1);
ds->SetNumWorkers(1);
// Create a tree adapter and compile the IR Tree
auto tree_adapter1 = std::make_shared<TreeAdapter>();
ASSERT_OK(tree_adapter1->Compile(ds->IRNode(), 1));
// Change num_parallel_workers and connector_queue_size for some ops
auto tree_modifier = std::make_unique<TreeModifier>(tree_adapter1.get());
tree_modifier->AddChangeRequest(1, std::make_shared<ChangeNumWorkersRequest>(10));
tree_modifier->AddChangeRequest(1, std::make_shared<ResizeConnectorRequest>(20));
tree_modifier->AddChangeRequest(0, std::make_shared<ResizeConnectorRequest>(100));
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>(10));
std::vector<int32_t> expected_result = {1, 5, 9, 1, 5, 9};
TensorRow row;
uint64_t i = 0;
ASSERT_OK(tree_adapter1->GetNext(&row));
while (!row.empty()) {
auto tensor = row[0];
int32_t num;
ASSERT_OK(tensor->GetItemAt(&num, {0}));
EXPECT_EQ(num, expected_result[i]);
ASSERT_OK(tree_adapter1->GetNext(&row));
i++;
}
// Expect 6 samples
EXPECT_EQ(i, 6);
// Serialize the optimized IR Tree
nlohmann::json out_json;
ASSERT_OK(Serdes::SaveToJSON(tree_adapter1->RootIRNode(), "", &out_json));
// Check that updated values of num_parallel_workers and connector_queue_size are not reflected in the json
EXPECT_EQ(out_json["op_type"], "Batch");
EXPECT_NE(out_json["num_parallel_workers"], 10);
EXPECT_NE(out_json["connector_queue_size"], 100);
EXPECT_EQ(out_json["children"][0]["op_type"], "Map");
EXPECT_NE(out_json["children"][0]["num_parallel_workers"], 10);
EXPECT_NE(out_json["children"][0]["connector_queue_size"], 20);
// Create an op_id to dataset op mapping
std::map<int32_t, std::shared_ptr<DatasetOp>> op_mapping;
auto tree = tree_adapter1->GetExecutionTree();
ASSERT_NE(tree, nullptr);
for (auto itr = tree->begin(); itr != tree->end(); ++itr) {
op_mapping[itr->id()] = itr.get();
}
// Update the serialized JSON object of the optimized IR tree
ASSERT_OK(Serdes::UpdateOptimizedIRTreeJSON(&out_json, op_mapping));
// Check that updated values of num_parallel_workers and connector_queue_size are reflected in the json now
EXPECT_EQ(out_json["op_type"], "Batch");
EXPECT_EQ(out_json["num_parallel_workers"], 10);
EXPECT_EQ(out_json["connector_queue_size"], 100);
EXPECT_EQ(out_json["children"][0]["op_type"], "Map");
EXPECT_EQ(out_json["children"][0]["num_parallel_workers"], 10);
EXPECT_EQ(out_json["children"][0]["connector_queue_size"], 20);
// Deserialize the above updated serialized optimized IR Tree
std::shared_ptr<DatasetNode> deserialized_node;
ASSERT_OK(Serdes::ConstructPipeline(out_json, &deserialized_node));
// Create a new tree adapter and compile the IR Tree obtained from deserialization above
auto tree_adapter2 = std::make_shared<TreeAdapter>();
ASSERT_OK(tree_adapter2->Compile(deserialized_node, 1));
// Serialize the new optimized IR Tree
nlohmann::json out_json1;
ASSERT_OK(Serdes::SaveToJSON(tree_adapter2->RootIRNode(), "", &out_json1));
// Ensure that both the serialized outputs are equal
EXPECT_TRUE(out_json == out_json1);
i = 0;
ASSERT_OK(tree_adapter2->GetNext(&row));
while (!row.empty()) {
auto tensor = row[0];
int32_t num;
ASSERT_OK(tensor->GetItemAt(&num, {0}));
EXPECT_EQ(num, expected_result[i]);
ASSERT_OK(tree_adapter2->GetNext(&row));
i++;
}
// Expect 6 samples
EXPECT_EQ(i, 6);
}
// Feature: Basic test for TreeModifier // Feature: Basic test for TreeModifier
// Description: Create simple tree and modify the tree by adding workers, change queue size and then removing workers // Description: Create simple tree and modify the tree by adding workers, change queue size and then removing workers
// Expectation: No failures. // Expectation: No failures.
@ -193,7 +299,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeModifier) {
uint64_t i = 0; uint64_t i = 0;
ASSERT_OK(tree_adapter->GetNext(&row)); ASSERT_OK(tree_adapter->GetNext(&row));
while (row.size() != 0) { while (!row.empty()) {
auto tensor = row[0]; auto tensor = row[0];
int32_t num; int32_t num;
ASSERT_OK(tensor->GetItemAt(&num, {0})); ASSERT_OK(tensor->GetItemAt(&num, {0}));
@ -232,7 +338,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeModifierMindRecord) {
// Iterate the dataset and collect the file_names in the dataset // Iterate the dataset and collect the file_names in the dataset
ASSERT_OK(tree_adapter->GetNext(&row)); ASSERT_OK(tree_adapter->GetNext(&row));
uint64_t i = 0; uint64_t i = 0;
while (row.size() != 0) { while (!row.empty()) {
auto tensor = row[0]; auto tensor = row[0];
std::string_view sv; std::string_view sv;
ASSERT_OK(tensor->GetItemAt(&sv, {})); ASSERT_OK(tensor->GetItemAt(&sv, {}));
@ -255,7 +361,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeModifierMindRecord) {
i = 0; i = 0;
ASSERT_OK(tree_adapter2->GetNext(&row)); ASSERT_OK(tree_adapter2->GetNext(&row));
while (row.size() != 0) { while (!row.empty()) {
auto tensor = row[0]; auto tensor = row[0];
std::string_view sv; std::string_view sv;
ASSERT_OK(tensor->GetItemAt(&sv, {})); ASSERT_OK(tensor->GetItemAt(&sv, {}));
@ -278,7 +384,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeModifierMindRecord) {
i = 0; i = 0;
ASSERT_OK(tree_adapter3->GetNext(&row)); ASSERT_OK(tree_adapter3->GetNext(&row));
while (row.size() != 0) { while (!row.empty()) {
auto tensor = row[0]; auto tensor = row[0];
std::string_view sv; std::string_view sv;
ASSERT_OK(tensor->GetItemAt(&sv, {})); ASSERT_OK(tensor->GetItemAt(&sv, {}));
@ -290,4 +396,4 @@ TEST_F(MindDataTestTreeAdapter, TestTreeModifierMindRecord) {
} }
// Expect 20 samples // Expect 20 samples
EXPECT_EQ(i, 20); EXPECT_EQ(i, 20);
} }

View File

@ -1 +1 @@
{"callback":[],"children":[{"callback":[],"children":[{"children":[],"columns_list":["image","label"],"dataset_files":["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"],"num_parallel_workers":8,"num_samples":0,"num_shards":1,"op_type":"TFRecordDataset","schema_file_path":"../data/dataset/test_tf_file_3_images/datasetSchema.json","shard_equal_rows":false,"shard_id":0,"shuffle":0,"shuffle_files":false,"shuffle_global":false}],"input_columns":["image"],"num_parallel_workers":8,"op_type":"Map","operations":[{"python_module":"mindspore.dataset.transforms.py_transforms","tensor_op_name":"Compose","tensor_op_params":{"random":false,"transforms":[{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"Decode","tensor_op_params":{"random":false}},{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"CenterCrop","tensor_op_params":{"random":false,"size":[32,32]}},{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"ToTensor","tensor_op_params":{"output_type":"float32","random":false}}]}}],"output_columns":["image"],"project_columns":[]}],"input_columns":["image"],"num_parallel_workers":8,"op_type":"Map","operations":[{"python_module":"mindspore.dataset.transforms.py_transforms","tensor_op_name":"RandomApply","tensor_op_params":{"prob":0.5,"transforms":[{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"RandomColorAdjust","tensor_op_params":{"brightness":[1,1],"contrast":[1,1],"hue":[0,0],"saturation":[1,1]}},{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"FiveCrop","tensor_op_params":{"random":false,"size":1}},{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"Grayscale","tensor_op_params":{"num_output_channels":1,"random":false}},{"python_module":"mindspore.dataset.transforms.py_transforms","tensor_op_name":"OneHotOp","tensor_op_params":{"num_classes":1,"random":false,"smoothing_rate":0.0}}]}}],"output_columns":["image"],"project_columns":[]} 
{"callback":[],"children":[{"callback":[],"children":[{"children":[],"columns_list":["image","label"],"dataset_files":["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"],"num_parallel_workers":8,"connector_queue_size":8,"num_samples":0,"num_shards":1,"op_type":"TFRecordDataset","schema_file_path":"../data/dataset/test_tf_file_3_images/datasetSchema.json","shard_equal_rows":false,"shard_id":0,"shuffle":0,"shuffle_files":false,"shuffle_global":false}],"input_columns":["image"],"num_parallel_workers":8,"connector_queue_size":8,"op_type":"Map","operations":[{"python_module":"mindspore.dataset.transforms.py_transforms","tensor_op_name":"Compose","tensor_op_params":{"random":false,"transforms":[{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"Decode","tensor_op_params":{"random":false}},{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"CenterCrop","tensor_op_params":{"random":false,"size":[32,32]}},{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"ToTensor","tensor_op_params":{"output_type":"float32","random":false}}]}}],"output_columns":["image"],"project_columns":[]}],"input_columns":["image"],"num_parallel_workers":8,"connector_queue_size":8,"op_type":"Map","operations":[{"python_module":"mindspore.dataset.transforms.py_transforms","tensor_op_name":"RandomApply","tensor_op_params":{"prob":0.5,"transforms":[{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"RandomColorAdjust","tensor_op_params":{"brightness":[1,1],"contrast":[1,1],"hue":[0,0],"saturation":[1,1]}},{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"FiveCrop","tensor_op_params":{"random":false,"size":1}},{"python_module":"mindspore.dataset.vision.py_transforms","tensor_op_name":"Grayscale","tensor_op_params":{"num_output_channels":1,"random":false}},{"python_module":"mindspore.dataset.transforms.py_transforms","tensor_op_name":"OneHotOp","tensor_op_params":{"num_classes":1,"random":false,"smoothing_rate":0.0}}]}}],"output_columns":["image"],"project_columns":[]}

View File

@ -16,12 +16,15 @@
Test Dataset AutoTune's Save and Load Configuration support Test Dataset AutoTune's Save and Load Configuration support
""" """
import filecmp import filecmp
import numpy as np import numpy as np
import pytest import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c_transforms import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.dataset.vision.c_transforms as c_vision
MNIST_DATA_DIR = "../data/dataset/testMnistData" MNIST_DATA_DIR = "../data/dataset/testMnistData"
DATA_DIR = "../data/dataset/testPK/data"
@pytest.mark.forked @pytest.mark.forked
@ -36,6 +39,7 @@ class TestAutotuneSaveLoad:
Description: Test save final config with GeneratorDataset pipeline: Generator -> Shuffle -> Batch Description: Test save final config with GeneratorDataset pipeline: Generator -> Shuffle -> Batch
Expectation: pipeline runs successfully Expectation: pipeline runs successfully
""" """
original_autotune = ds.config.get_enable_autotune()
ds.config.set_enable_autotune(True, str(tmp_path) + "test_autotune_generator_atfinal.json") ds.config.set_enable_autotune(True, str(tmp_path) + "test_autotune_generator_atfinal.json")
source = [(np.array([x]),) for x in range(1024)] source = [(np.array([x]),) for x in range(1024)]
@ -50,17 +54,18 @@ class TestAutotuneSaveLoad:
for _ in itr: for _ in itr:
pass pass
ds.config.set_enable_autotune(False) ds.config.set_enable_autotune(original_autotune)
@staticmethod @staticmethod
def skip_test_autotune_mnist_pipeline(tmp_path): def test_autotune_mnist_pipeline(tmp_path):
""" """
Feature: Autotuning Feature: Autotuning
Description: Test save final config with Mnist pipeline: Mnist -> Batch -> Map Description: Test save final config with Mnist pipeline: Mnist -> Batch -> Map
Expectation: pipeline runs successfully Expectation: pipeline runs successfully
""" """
original_autotune = ds.config.get_enable_autotune()
ds.config.set_enable_autotune(True, str(tmp_path) + "test_autotune_mnist_pipeline_atfinal.json") ds.config.set_enable_autotune(True, str(tmp_path) + "test_autotune_mnist_pipeline_atfinal.json")
original_seed = ds.config.get_seed()
ds.config.set_seed(1) ds.config.set_seed(1)
data1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=100) data1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=100)
@ -74,7 +79,7 @@ class TestAutotuneSaveLoad:
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass pass
ds.config.set_enable_autotune(False) ds.config.set_enable_autotune(original_autotune)
# Confirm final AutoTune config file is identical to the serialized file. # Confirm final AutoTune config file is identical to the serialized file.
assert filecmp.cmp(str(tmp_path) + "test_autotune_mnist_pipeline_atfinal.json", assert filecmp.cmp(str(tmp_path) + "test_autotune_mnist_pipeline_atfinal.json",
@ -91,6 +96,8 @@ class TestAutotuneSaveLoad:
num += 1 num += 1
assert num == 10 assert num == 10
ds.config.set_seed(original_seed)
@staticmethod @staticmethod
def test_autotune_save_overwrite_generator(tmp_path): def test_autotune_save_overwrite_generator(tmp_path):
""" """
@ -102,7 +109,7 @@ class TestAutotuneSaveLoad:
source = [(np.array([x]),) for x in range(1024)] source = [(np.array([x]),) for x in range(1024)]
at_final_json_filename = "test_autotune_save_overwrite_generator_atfinal.json" at_final_json_filename = "test_autotune_save_overwrite_generator_atfinal.json"
original_autotune = ds.config.get_enable_autotune()
ds.config.set_enable_autotune(True, str(tmp_path) + at_final_json_filename) ds.config.set_enable_autotune(True, str(tmp_path) + at_final_json_filename)
data1 = ds.GeneratorDataset(source, ["data"]) data1 = ds.GeneratorDataset(source, ["data"])
@ -120,20 +127,22 @@ class TestAutotuneSaveLoad:
for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True): for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
pass pass
ds.config.set_enable_autotune(False) ds.config.set_enable_autotune(original_autotune)
@staticmethod @staticmethod
def skip_test_autotune_save_overwrite_mnist(tmp_path): def test_autotune_save_overwrite_mnist(tmp_path):
""" """
Feature: Autotuning Feature: Autotuning
Description: Test set_enable_autotune and existing json_filepath is overwritten Description: Test set_enable_autotune and existing json_filepath is overwritten
Expectation: set_enable_autotune() executes successfully with file-exist warning produced. Expectation: set_enable_autotune() executes successfully with file-exist warning produced.
Execution of 2nd pipeline overwrites AutoTune configuration file of 1st pipeline. Execution of 2nd pipeline overwrites AutoTune configuration file of 1st pipeline.
""" """
original_seed = ds.config.get_seed()
ds.config.set_seed(1) ds.config.set_seed(1)
at_final_json_filename = "test_autotune_save_overwrite_mnist_atfinal.json" at_final_json_filename = "test_autotune_save_overwrite_mnist_atfinal.json"
# Pipeline#1 # Pipeline#1
original_autotune = ds.config.get_enable_autotune()
ds.config.set_enable_autotune(True, str(tmp_path) + at_final_json_filename) ds.config.set_enable_autotune(True, str(tmp_path) + at_final_json_filename)
data1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=100) data1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=100)
@ -170,3 +179,41 @@ class TestAutotuneSaveLoad:
# Confirm the serialized files for the 2 different pipelines are different # Confirm the serialized files for the 2 different pipelines are different
assert not filecmp.cmp(str(tmp_path) + "test_autotune_save_overwrite_mnist_serialized1.json", assert not filecmp.cmp(str(tmp_path) + "test_autotune_save_overwrite_mnist_serialized1.json",
str(tmp_path) + "test_autotune_save_overwrite_mnist_serialized2.json") str(tmp_path) + "test_autotune_save_overwrite_mnist_serialized2.json")
ds.config.set_seed(original_seed)
ds.config.set_enable_autotune(original_autotune)
@staticmethod
def test_autotune_warning_with_offload(tmp_path, capfd):
"""
Feature: Autotuning
Description: Test autotune config saving with offload=True
Expectation: AutoTune should not write the config file and should print a log message instead
"""
original_seed = ds.config.get_seed()
ds.config.set_seed(1)
at_final_json_filename = "test_autotune_warning_with_offload_config.json"
config_path = tmp_path / at_final_json_filename
original_autotune = ds.config.get_enable_autotune()
ds.config.set_enable_autotune(True, str(config_path))
# Dataset with offload activated.
dataset = ds.ImageFolderDataset(DATA_DIR)
dataset = dataset.map(operations=[c_vision.Decode()], input_columns="image")
dataset = dataset.map(operations=[c_vision.HWC2CHW()], input_columns="image", offload=True)
dataset = dataset.batch(8, drop_remainder=True)
for _ in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
pass
_, err = capfd.readouterr()
assert "Some nodes have been offloaded. AutoTune is unable to write the autotune configuration to disk. " \
"Disable offload to prevent this from happening." in err
with pytest.raises(FileNotFoundError):
with open(config_path) as _:
pass
ds.config.set_enable_autotune(original_autotune)
ds.config.set_seed(original_seed)