diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index ee27a7f2f14..c70b83c3f2f 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -1075,7 +1075,7 @@ std::vector> CLUEDataset::Build() { std::shared_ptr clue_op = std::make_shared(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, - sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_); + sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr); RETURN_EMPTY_IF_ERROR(clue_op->Init()); if (shuffle_ == ShuffleMode::kGlobal) { // Inject ShuffleOp @@ -1256,7 +1256,7 @@ std::vector> CSVDataset::Build() { std::shared_ptr csv_op = std::make_shared( sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, - num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_); + num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr); RETURN_EMPTY_IF_ERROR(csv_op->Init()); if (shuffle_ == ShuffleMode::kGlobal) { // Inject ShuffleOp @@ -1502,7 +1502,7 @@ std::vector> TextFileDataset::Build() { // Create and initalize TextFileOp std::shared_ptr text_file_op = std::make_shared( num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, - connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr)); + connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr); RETURN_EMPTY_IF_ERROR(text_file_op->Init()); if (shuffle_ == ShuffleMode::kGlobal) { diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc index 5ae21482637..68cba69db3b 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc @@ -1345,6 +1345,9 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr cache_client = nullptr; std::shared_ptr builder = std::make_shared(); (void)builder->SetManifestFile(ToString(args["dataset_file"])); @@ -1354,7 +1357,8 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptrSetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); std::shared_ptr sampler = create().cast>(); @@ -1365,12 +1369,27 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptrSetDecode(ToBool(value)); } else if (key == "usage") { (void)builder->SetUsage(ToString(value)); + } else if (key == "cache") { + cache_client = value.cast>(); } } } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; + std::shared_ptr manifest_op; + RETURN_IF_NOT_OK(builder->Build(&manifest_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(manifest_op)); + *top = manifest_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, manifest_op, &cache_op)); + *top = cache_op; + *bottom = manifest_op; + } + return Status::OK(); } @@ -1380,6 +1399,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr * CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified."); CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified."); + int num_workers = 0; + std::shared_ptr cache_client = nullptr; std::shared_ptr builder = std::make_shared(); (void)builder->SetDir(ToString(args["dataset_dir"])); (void)builder->SetTask(ToString(args["task"])); @@ -1389,7 +1410,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr * py::handle value = arg.second; if (!value.is_none()) { if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); std::shared_ptr sampler = create().cast>(); @@ -1398,12 +1420,26 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr * (void)builder->SetDecode(ToBool(value)); } else if (key == "class_indexing") { (void)builder->SetClassIndex(ToStringMap(value)); + } else if (key == "cache") { + cache_client = value.cast>(); } } } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; + std::shared_ptr voc_op; + RETURN_IF_NOT_OK(builder->Build(&voc_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(voc_op)); + *top = voc_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, voc_op, &cache_op)); + *top = cache_op; + *bottom = voc_op; + } return Status::OK(); } @@ -1425,6 +1461,8 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr RETURN_STATUS_UNEXPECTED(err_msg); } + int num_workers = 0; + std::shared_ptr cache_client = nullptr; std::shared_ptr builder = std::make_shared(); (void)builder->SetDir(ToString(args["dataset_dir"])); (void)builder->SetFile(ToString(args["annotation_file"])); @@ -1434,19 +1472,35 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr py::handle value = arg.second; if (!value.is_none()) { if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "decode") { (void)builder->SetDecode(ToBool(value)); + } else if (key == "cache") { + cache_client = value.cast>(); } } } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; + std::shared_ptr coco_op; + RETURN_IF_NOT_OK(builder->Build(&coco_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(coco_op)); + *top = coco_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, coco_op, &cache_op)); + *top = cache_op; + *bottom = coco_op; + } + return Status::OK(); } @@ -1458,6 +1512,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr cache_client = nullptr; std::shared_ptr builder = std::make_shared(); (void)builder->SetCifarDir(ToString(args["dataset_dir"])); @@ -1467,22 +1523,38 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptrSetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "usage") { (void)builder->SetUsage(ToString(value)); + } else if (key == "cache") { + cache_client = value.cast>(); } } } (void)builder->SetCifarType(true); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; + std::shared_ptr cifar_op; + RETURN_IF_NOT_OK(builder->Build(&cifar_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op)); + *top = cifar_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op)); + *top = cache_op; + *bottom = cifar_op; + } + return Status::OK(); } @@ -1494,6 +1566,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr cache_client = nullptr; std::shared_ptr builder = std::make_shared(); (void)builder->SetCifarDir(ToString(args["dataset_dir"])); @@ -1503,22 +1577,37 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptrSetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "usage") { (void)builder->SetUsage(ToString(value)); + } else if (key == "cache") { + cache_client = value.cast>(); } } } (void)builder->SetCifarType(false); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; + std::shared_ptr cifar_op; + RETURN_IF_NOT_OK(builder->Build(&cifar_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op)); + *top = cifar_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op)); + *top = cache_op; + *bottom = cifar_op; + } return Status::OK(); } @@ -1609,6 +1698,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr RETURN_STATUS_UNEXPECTED(err_msg); } + int num_workers = 0; + std::shared_ptr cache_client = nullptr; std::shared_ptr builder = std::make_shared(); (void)builder->SetDir(ToString(args["dataset_dir"])); @@ -1618,19 +1709,35 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr py::handle value = arg.second; if (!value.is_none()) { if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); std::shared_ptr sampler = create().cast>(); (void)builder->SetSampler(std::move(sampler)); } else if (key == "usage") { (void)builder->SetUsage(ToString(value)); + } else if (key == "cache") { + cache_client = value.cast>(); } } } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; + std::shared_ptr mnist_op; + RETURN_IF_NOT_OK(builder->Build(&mnist_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(mnist_op)); + *top = mnist_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, mnist_op, &cache_op)); + *top = cache_op; + *bottom = mnist_op; + } + return Status::OK(); } @@ -1642,6 +1749,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr cache_client = nullptr; std::shared_ptr builder = std::make_shared(); if (builder == nullptr) { std::string err_msg = "Create celebaop builder failed"; @@ -1653,7 +1762,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptrSetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "sampler") { auto create = py::reinterpret_borrow(value).attr("create"); std::shared_ptr sampler = create().cast>(); @@ -1664,13 +1774,28 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptrSetExtensions(ToStringSet(value)); } else if (key == "usage") { (void)builder->SetUsage(ToString(value)); + } else if (key == "cache") { + cache_client = value.cast>(); } } } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; + std::shared_ptr celeba_op; + RETURN_IF_NOT_OK(builder->Build(&celeba_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(celeba_op)); + *top = celeba_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, celeba_op, &cache_op)); + *top = cache_op; + *bottom = celeba_op; + } + return Status::OK(); } @@ -1678,6 +1803,9 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *bottom) { // Required arguments std::vector files_list; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { files_list = ToStringVector(args["dataset_files"]); @@ -1693,7 +1821,8 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptrSetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "shuffle_files") { (void)builder->SetShuffleFiles(ToBool(value)); } else if (key == "shuffle_global") { @@ -1705,16 +1834,35 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptrSetNumDevices(num_devices); } else if (key == "shard_id") { (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); } } } + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because TextFileOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. + if (sampler) { + (void)builder->SetSampler(std::move(sampler)); + } else if (cache_client) { + int64_t num_samples = 0; + int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder->SetSampler(std::move(sampler)); + } + std::shared_ptr txt_op; RETURN_IF_NOT_OK(builder->Build(&txt_op)); RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); *top = txt_op; - if (shuffle_required) { + if (!cache_client && shuffle_required) { std::shared_ptr shuffle_op = nullptr; int64_t shuffle_size = 0; int64_t num_rows = 0; @@ -1729,6 +1877,15 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, txt_op, &cache_op)); + *top = cache_op; + *bottom = txt_op; + } + return Status::OK(); } @@ -1829,6 +1986,10 @@ Status DEPipeline::ParseBuildSentencePieceVocabOp(const py::dict &args, std::sha Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom) { std::vector files_list; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; + std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { files_list = ToStringVector(args["dataset_files"]); @@ -1844,7 +2005,8 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr py::handle value = arg.second; if (!value.is_none()) { if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "shuffle_files") { (void)builder->SetShuffleFiles(ToBool(value)); } else if (key == "shuffle_global") { @@ -1866,16 +2028,35 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr } } (void)builder->SetColsKeyMap(map_dict); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); } } } + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because ClueOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. + if (sampler) { + (void)builder->SetSampler(std::move(sampler)); + } else if (cache_client) { + int64_t num_samples = 0; + int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder->SetSampler(std::move(sampler)); + } + std::shared_ptr clue_op; RETURN_IF_NOT_OK(builder->Build(&clue_op)); RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); *top = clue_op; - if (shuffle_required) { + if (!cache_client && shuffle_required) { std::shared_ptr shuffle_op = nullptr; int64_t shuffle_size = 0; int64_t num_rows = 0; @@ -1890,6 +2071,15 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *bottom = clue_op; } + // Add a cache op over this op if required and update the output subtree (top/bottom) + if (cache_client) { + // Note, it is not allowed to have both shuffle and cache + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, clue_op, &cache_op)); + *top = cache_op; + *bottom = clue_op; + } + return Status::OK(); } @@ -1921,6 +2111,9 @@ Status DEPipeline::AddCacheOp(std::shared_ptr cache_client, int num Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom) { std::vector files_list; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { files_list = ToStringVector(args["dataset_files"]); @@ -1938,7 +2131,8 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr * py::handle value = arg.second; if (!value.is_none()) { if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); } else if (key == "shuffle_files") { (void)builder->SetShuffleFiles(ToBool(value)); } else if (key == "shuffle_global") { @@ -1971,16 +2165,35 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr * } else if (key == "column_names") { col_names = ToStringVector(value); (void)builder->SetColumName(col_names); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); } } } + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because CsvOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. + if (sampler) { + (void)builder->SetSampler(std::move(sampler)); + } else if (cache_client) { + int64_t num_samples = 0; + int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder->SetSampler(std::move(sampler)); + } + std::shared_ptr csv_op; RETURN_IF_NOT_OK(builder->Build(&csv_op)); RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op)); *top = csv_op; - if (shuffle_required) { + if (!cache_client && shuffle_required) { std::shared_ptr shuffle_op = nullptr; int64_t shuffle_size = 0; int64_t num_rows = 0; @@ -1995,6 +2208,15 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr * *bottom = csv_op; } + // Add a cache op over this op if required and update the output subtree (top/bottom) + if (cache_client) { + // Note, it is not allowed to have both shuffle and cache + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, csv_op, &cache_op)); + *top = cache_op; + *bottom = csv_op; + } + return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc index e673d20b21f..407b841557b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc @@ -70,13 +70,12 @@ Status AlbumOp::Builder::SanityCheck() { AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, const std::set &exts, std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_wkrs, queue_size), + : ParallelOp(num_wkrs, queue_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), folder_path_(file_dir), decode_(do_decode), extensions_(exts), data_schema_(std::move(data_schema)), - sampler_(std::move(sampler)), row_cnt_(0), buf_cnt_(0), sampler_ind_(0), diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h index 3ef4e7bf894..88af87062bf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h @@ -284,7 +284,6 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { std::set extensions_; // extensions allowed std::unordered_map col_name_map_; std::unique_ptr data_schema_; - std::shared_ptr sampler_; int64_t row_cnt_; int64_t buf_cnt_; int64_t sampler_ind_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index 8a327cbdcd6..bd110096764 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -25,13 +25,18 @@ #include "minddata/dataset/util/task_manager.h" #include "minddata/dataset/engine/jagged_connector.h" #include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/util/random.h" namespace mindspore { namespace dataset { ClueOp::Builder::Builder() - : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + : builder_device_id_(0), + builder_num_devices_(1), + builder_num_samples_(0), + builder_shuffle_files_(false), + builder_sampler_(nullptr) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -68,7 +73,7 @@ Status ClueOp::Builder::Build(std::shared_ptr *op) { std::shared_ptr clue_op = std::make_shared( builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, - builder_device_id_); + builder_device_id_, std::move(builder_sampler_)); RETURN_IF_NOT_OK(clue_op->Init()); *op = std::move(clue_op); @@ -88,8 +93,8 @@ std::vector ClueOp::Builder::split(const std::string &s, char delim ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_device, int32_t device_id) - : ParallelOp(num_workers, op_connector_size), + bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), num_rows_per_shard_(0), all_num_rows_(0), @@ -539,5 +544,21 @@ Status ClueOp::ComputeColMap() { } return Status::OK(); } + +// Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing +// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so +// that this clue op will produce the full set of data into the cache. +void ClueOp::MakeSimpleProducer() { + device_id_ = 0; + num_devices_ = 1; + shuffle_files_ = false; + num_samples_ = 0; +} + +// Visitor accept method for NodePass +Status ClueOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h index eea6b5aa7ec..719bb045376 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -122,6 +123,14 @@ class ClueOp : public ParallelOp { // @return - the a string vector std::vector split(const std::string &s, char delim); + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + private: int32_t builder_device_id_; int32_t builder_num_devices_; @@ -133,12 +142,13 @@ class ClueOp : public ParallelOp { std::vector builder_clue_files_list_; bool builder_shuffle_files_; std::map builder_cols_to_keyword_; + std::shared_ptr builder_sampler_; }; // Constructor of ClueOp ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id); + bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); // Default destructor ~ClueOp() = default; @@ -173,6 +183,17 @@ class ClueOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return clue_files_list_; } + /// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so + /// that this clue op will produce the full set of data into the cache. + void MakeSimpleProducer(); + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index f7819bd1646..16061eb79e2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -124,7 +124,7 @@ Status CocoOp::Builder::SanityCheck() { CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size), + : ParallelOp(num_workers, queue_size, std::move(sampler)), decode_(decode), row_cnt_(0), buf_cnt_(0), @@ -132,7 +132,6 @@ CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, image_folder_path_(image_folder_path), annotation_path_(annotation_path), rows_per_buffer_(rows_per_buffer), - sampler_(std::move(sampler)), data_schema_(std::move(data_schema)) { io_block_queues_.Init(num_workers_, queue_size); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h index ff18c6b8c2c..7f5202429d0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -206,6 +206,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp { /// \return Status of the node visit Status Accept(NodePass *p, bool *modified) override; + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "CocoOp"; } + private: // Initialize Sampler, calls sampler->Init() within // @return Status - The error code return @@ -324,7 +328,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp { std::string annotation_path_; TaskType task_type_; int32_t rows_per_buffer_; - std::shared_ptr sampler_; std::unique_ptr data_schema_; WaitPost wp_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 65a7847c796..a6ccecc3d93 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -22,12 +22,17 @@ #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/engine/jagged_connector.h" #include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/random.h" namespace mindspore { namespace dataset { CsvOp::Builder::Builder() - : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + : builder_device_id_(0), + builder_num_devices_(1), + builder_num_samples_(0), + builder_shuffle_files_(false), + builder_sampler_(nullptr) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -59,7 +64,8 @@ Status CsvOp::Builder::Build(std::shared_ptr *op) { std::shared_ptr csv_op = std::make_shared( builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_, builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, - builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_); + builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_, + std::move(builder_sampler_)); RETURN_IF_NOT_OK(csv_op->Init()); *op = std::move(csv_op); @@ -70,8 +76,8 @@ CsvOp::CsvOp(const std::vector &csv_files_list, char field_delim, const std::vector> &column_default, const std::vector &column_name, int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, - int32_t num_device, int32_t device_id) - : ParallelOp(num_workers, op_connector_size), + int32_t num_device, int32_t device_id, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), csv_files_list_(std::move(csv_files_list)), field_delim_(field_delim), column_default_list_(column_default), @@ -889,5 +895,21 @@ Status CsvOp::ComputeColMap() { } return Status::OK(); } + +// Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing +// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so +// that this csv op will produce the full set of data into the cache. +void CsvOp::MakeSimpleProducer() { + device_id_ = 0; + num_devices_ = 1; + shuffle_files_ = false; + num_samples_ = 0; +} + +// Visitor accept method for NodePass +Status CsvOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h index 1417d9c8b09..9092ede8995 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -240,6 +240,14 @@ class CsvOp : public ParallelOp { return *this; } + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + private: int32_t builder_device_id_; int32_t builder_num_devices_; @@ -253,6 +261,7 @@ class CsvOp : public ParallelOp { char builder_field_delim_; std::vector> builder_column_default_list_; std::vector builder_column_name_list_; + std::shared_ptr builder_sampler_; }; // Constructor of CsvOp @@ -261,7 +270,8 @@ class CsvOp : public ParallelOp { CsvOp(const std::vector &csv_files_list, char field_delim, const std::vector> &column_default, const std::vector &column_name, int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, - int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); + int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, + std::shared_ptr sampler); // Default destructor ~CsvOp() = default; @@ -297,6 +307,17 @@ class CsvOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return csv_files_list_; } + /// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so + /// that this csv op will produce the full set of data into the cache. + void MakeSimpleProducer(); + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 8b72a971508..a35103203db 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -29,6 +29,7 @@ #include "minddata/dataset/util/random.h" #include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" namespace mindspore { namespace dataset { @@ -499,5 +500,21 @@ Status TextFileOp::ComputeColMap() { } return Status::OK(); } + +// Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing +// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so +// that this text file op will produce the full set of data into the cache. +void TextFileOp::MakeSimpleProducer() { + device_id_ = 0; + num_devices_ = 1; + shuffle_files_ = false; + total_rows_ = 0; +} + +// Visitor accept method for NodePass +Status TextFileOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h index 9dfb4ac2ae6..30a87dffdb9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -188,6 +188,17 @@ class TextFileOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return text_files_list_; } + /// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so + /// that this text file op will produce the full set of data into the cache. + void MakeSimpleProducer(); + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index b92e73620e6..711b282f39f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -212,6 +212,7 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation)); + trow->setId(row_id); trow->push_back(std::move(image)); trow->insert(trow->end(), annotation.begin(), annotation.end()); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc index 1ee5d2b68c4..a4bcc7cbec0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -45,6 +45,9 @@ #include "minddata/dataset/engine/datasetops/source/random_data_op.h" #ifndef ENABLE_ANDROID #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" #endif #include "minddata/dataset/engine/datasetops/source/voc_op.h" #ifdef ENABLE_PYTHON @@ -260,6 +263,21 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) return RunOnNode(std::static_pointer_cast(node), modified); } +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return PreRunOnNode(std::static_pointer_cast(node), modified); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h index 33433f6e490..efb9b6eb049 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -81,6 +81,12 @@ class CacheMergeOp; class CacheLookupOp; class BuildSentencePieceVocabOp; + +class ClueOp; + +class CsvOp; + +class TextFileOp; #endif #ifdef ENABLE_PYTHON @@ -211,6 +217,12 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc index 953aeab5376..a2728e4a8c3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc @@ -36,6 +36,9 @@ #ifndef ENABLE_ANDROID #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" #endif #ifdef ENABLE_PYTHON @@ -141,6 +144,36 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node } return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); } + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_caching_) { + // If we are a ClueOp in a caching tree, then change our config so that it becomes a basic + // ClueOp that parses all files. Selection of data will come from the sampler on the cache instead. + node->MakeSimpleProducer(); + } + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_caching_) { + // If we are a CsvOp in a caching tree, then change our config so that it becomes a basic + // CsvOp that parses all files. Selection of data will come from the sampler on the cache instead. + node->MakeSimpleProducer(); + } + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache transform identification +Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_caching_) { + // If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic + // TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead. + node->MakeSimpleProducer(); + } + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} #endif // Perform leaf node cache transform identification @@ -163,34 +196,22 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, b // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache."); - } - return Status::OK(); + return MappableCacheLeafSetup(std::static_pointer_cast(node)); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache."); - } - return Status::OK(); + return MappableCacheLeafSetup(std::static_pointer_cast(node)); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache."); - } - return Status::OK(); + return MappableCacheLeafSetup(std::static_pointer_cast(node)); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache."); - } - return Status::OK(); + return MappableCacheLeafSetup(std::static_pointer_cast(node)); } #ifndef ENABLE_ANDROID @@ -214,18 +235,12 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr nod // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache."); - } - return Status::OK(); + return MappableCacheLeafSetup(std::static_pointer_cast(node)); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache."); - } - return Status::OK(); + return MappableCacheLeafSetup(std::static_pointer_cast(node)); } #endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h index 89525d07e89..346f4dd62d3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h @@ -65,6 +65,24 @@ class CacheTransformPass : public TreePass { /// \param[inout] modified Indicator if the node was changed at all /// \return Status The error code return Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; #endif /// \brief Perform leaf node cache tranform identifications diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 51aa0200064..013ced32bd5 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2969,6 +2969,8 @@ class MnistDataset(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Raises: RuntimeError: If sampler and shuffle are specified at the same time. @@ -2988,7 +2990,7 @@ class MnistDataset(MappableDataset): @check_mnist_cifar_dataset def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, num_shards=None, shard_id=None): + shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir @@ -2998,6 +3000,7 @@ class MnistDataset(MappableDataset): self.shuffle_level = shuffle self.num_shards = num_shards self.shard_id = shard_id + self.cache = cache def get_args(self): args = super().get_args() @@ -3008,6 +3011,7 @@ class MnistDataset(MappableDataset): args["sampler"] = self.sampler args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -3872,6 +3876,8 @@ class ManifestDataset(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Raises: RuntimeError: If sampler and shuffle are specified at the same time. @@ -3897,7 +3903,8 @@ class ManifestDataset(MappableDataset): @check_manifestdataset def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None): + shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, + cache=None): super().__init__(num_parallel_workers) self.dataset_file = dataset_file @@ -3913,6 +3920,7 @@ class ManifestDataset(MappableDataset): self.shuffle_level = shuffle self.num_shards = num_shards self.shard_id = shard_id + self.cache = cache def get_args(self): args = super().get_args() @@ -3925,6 +3933,7 @@ class ManifestDataset(MappableDataset): args["decode"] = self.decode args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -4055,6 +4064,8 @@ class Cifar10Dataset(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Raises: RuntimeError: If sampler and shuffle are specified at the same time. @@ -4082,7 +4093,7 @@ class Cifar10Dataset(MappableDataset): @check_mnist_cifar_dataset def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, num_shards=None, shard_id=None): + shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir @@ -4092,6 +4103,7 @@ class Cifar10Dataset(MappableDataset): self.num_shards = num_shards self.shard_id = shard_id self.shuffle_level = shuffle + self.cache = cache def get_args(self): args = super().get_args() @@ -4102,6 +4114,7 @@ class Cifar10Dataset(MappableDataset): args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id args["shuffle"] = self.shuffle_level + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -4202,6 +4215,8 @@ class Cifar100Dataset(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Raises: RuntimeError: If sampler and shuffle are specified at the same time. @@ -4226,7 +4241,7 @@ class Cifar100Dataset(MappableDataset): @check_mnist_cifar_dataset def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, num_shards=None, shard_id=None): + shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir @@ -4236,6 +4251,7 @@ class Cifar100Dataset(MappableDataset): self.num_shards = num_shards self.shard_id = shard_id self.shuffle_level = shuffle + self.cache = cache def get_args(self): args = super().get_args() @@ -4246,6 +4262,7 @@ class Cifar100Dataset(MappableDataset): args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id args["shuffle"] = self.shuffle_level + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -4630,6 +4647,8 @@ class VOCDataset(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Raises: RuntimeError: If xml of Annotations is an invalid format. @@ -4667,7 +4686,8 @@ class VOCDataset(MappableDataset): @check_vocdataset def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None, - num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): + num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, + cache=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir self.task = task @@ -4679,6 +4699,7 @@ class VOCDataset(MappableDataset): self.shuffle_level = shuffle self.num_shards = num_shards self.shard_id = shard_id + self.cache = cache def get_args(self): args = super().get_args() @@ -4692,6 +4713,7 @@ class VOCDataset(MappableDataset): args["shuffle"] = self.shuffle_level args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -4838,6 +4860,8 @@ class CocoDataset(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Raises: RuntimeError: If sampler and shuffle are specified at the same time. @@ -4873,7 +4897,7 @@ class CocoDataset(MappableDataset): @check_cocodataset def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None, - shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): + shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir self.annotation_file = annotation_file @@ -4884,6 +4908,7 @@ class CocoDataset(MappableDataset): self.shuffle_level = shuffle self.num_shards = num_shards self.shard_id = shard_id + self.cache = cache def get_args(self): args = super().get_args() @@ -4896,6 +4921,7 @@ class CocoDataset(MappableDataset): args["shuffle"] = self.shuffle_level args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -4993,6 +5019,8 @@ class CelebADataset(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Examples: >>> import mindspore.dataset as ds @@ -5003,7 +5031,7 @@ class CelebADataset(MappableDataset): @check_celebadataset def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, - extensions=None, num_samples=None, num_shards=None, shard_id=None): + extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) @@ -5015,6 +5043,7 @@ class CelebADataset(MappableDataset): self.num_shards = num_shards self.shard_id = shard_id self.shuffle_level = shuffle + self.cache = cache if usage != "all": dir = os.path.realpath(self.dataset_dir) @@ -5033,6 +5062,7 @@ class CelebADataset(MappableDataset): args["usage"] = self.usage args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -5142,6 +5172,8 @@ class CLUEDataset(SourceDataset): num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. Examples: >>> import mindspore.dataset as ds @@ -5152,7 +5184,7 @@ class CLUEDataset(SourceDataset): @check_cluedataset def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, - num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_files = self._find_files(dataset_files) self.dataset_files.sort() @@ -5293,6 +5325,15 @@ class CLUEDataset(SourceDataset): self.num_shards = num_shards self.shard_id = shard_id + # The clue dataset does not directly support a sampler. It has provided sampling arguments + # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in + # the pipeline contains a cache. If there is no cache above it, then this sampler is not used. + sampler_shuffle = self.shuffle_files + sampler = None + self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, + non_mappable=True) + self.cache = cache + def get_args(self): args = super().get_args() args["dataset_files"] = self.dataset_files @@ -5304,6 +5345,8 @@ class CLUEDataset(SourceDataset): args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id args["cols_to_keyword"] = self.cols_to_keyword + args["sampler"] = self.sampler + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -5359,6 +5402,9 @@ class CSVDataset(SourceDataset): num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. + Examples: >>> import mindspore.dataset as ds @@ -5369,7 +5415,7 @@ class CSVDataset(SourceDataset): @check_csvdataset def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, - num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_files = self._find_files(dataset_files) self.dataset_files.sort() @@ -5394,6 +5440,15 @@ class CSVDataset(SourceDataset): self.num_shards = num_shards self.shard_id = shard_id + self.cache = cache + # The CSV dataset does not directly support a sampler. It has provided sampling arguments + # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in + # the pipeline contains a cache. If there is no cache above it, then this sampler is not used. + sampler_shuffle = self.shuffle_files + sampler = None + self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, + non_mappable=True) + def get_args(self): args = super().get_args() args["dataset_files"] = self.dataset_files @@ -5407,6 +5462,8 @@ class CSVDataset(SourceDataset): args["shuffle"] = self.shuffle_level args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["sampler"] = self.sampler + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -5457,6 +5514,9 @@ class TextFileDataset(SourceDataset): num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). + The cache feature is under development and is not recommended. + Examples: >>> import mindspore.dataset as ds >>> @@ -5466,7 +5526,7 @@ class TextFileDataset(SourceDataset): @check_textfiledataset def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, - shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_files = self._find_files(dataset_files) self.dataset_files.sort() @@ -5488,6 +5548,15 @@ class TextFileDataset(SourceDataset): self.num_shards = num_shards self.shard_id = shard_id + self.cache = cache + # The text file dataset does not directly support a sampler. It has provided sampling arguments + # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in + # the pipeline contains a cache. If there is no cache above it, then this sampler is not used. + sampler_shuffle = self.shuffle_files + sampler = None + self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, + non_mappable=True) + def get_args(self): args = super().get_args() args["dataset_files"] = self.dataset_files @@ -5498,6 +5567,8 @@ class TextFileDataset(SourceDataset): args["shuffle"] = self.shuffle_level args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["sampler"] = self.sampler + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b311a93f5ec..b026967ec1b 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -83,6 +83,9 @@ def check_mnist_cifar_dataset(method): check_sampler_shuffle_shard_options(param_dict) + cache = param_dict.get('cache') + check_cache_option(cache) + return method(self, *args, **kwargs) return new_method @@ -110,6 +113,9 @@ def check_manifestdataset(method): check_sampler_shuffle_shard_options(param_dict) + cache = param_dict.get('cache') + check_cache_option(cache) + return method(self, *args, **kwargs) return new_method @@ -180,6 +186,9 @@ def check_vocdataset(method): validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) + cache = param_dict.get('cache') + check_cache_option(cache) + return method(self, *args, **kwargs) return new_method @@ -216,6 +225,9 @@ def check_cocodataset(method): raise ValueError("CocoDataset doesn't support PKSampler") check_sampler_shuffle_shard_options(param_dict) + cache = param_dict.get('cache') + check_cache_option(cache) + return method(self, *args, **kwargs) return new_method @@ -252,6 +264,9 @@ def check_celebadataset(method): if sampler is not None and isinstance(sampler, samplers.PKSampler): raise ValueError("CelebADataset does not support PKSampler.") + cache = param_dict.get('cache') + check_cache_option(cache) + return method(self, *args, **kwargs) return new_method @@ -843,6 +858,9 @@ def check_cluedataset(method): validate_dataset_param_value(nreq_param_int, param_dict, int) check_sampler_shuffle_shard_options(param_dict) + cache = param_dict.get('cache') + check_cache_option(cache) + return method(self, *args, **kwargs) return new_method @@ -886,6 +904,9 @@ def check_csvdataset(method): validate_dataset_param_value(nreq_param_int, param_dict, int) check_sampler_shuffle_shard_options(param_dict) + cache = param_dict.get('cache') + check_cache_option(cache) + return method(self, *args, **kwargs) return new_method @@ -905,6 +926,9 @@ def check_textfiledataset(method): validate_dataset_param_value(nreq_param_int, param_dict, int) check_sampler_shuffle_shard_options(param_dict) + cache = param_dict.get('cache') + check_cache_option(cache) + return method(self, *args, **kwargs) return new_method diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh index 1094e104866..a65ff8855ce 100755 --- a/tests/ut/python/cachetests/cachetest_py.sh +++ b/tests/ut/python/cachetests/cachetest_py.sh @@ -103,6 +103,24 @@ HandleRcExit $? 0 0 PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1 HandleRcExit $? 0 0 +PytestCmd "test_cache_map.py" "test_cache_map_coco" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_mnist" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_celeba" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_manifest" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_cifar" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_voc" 1 +HandleRcExit $? 0 0 + # Run two parallel pipelines (sharing cache) for i in $(seq 1 2) do @@ -282,6 +300,15 @@ HandleRcExit $? 0 0 PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1 HandleRcExit $? 0 0 +PytestCmd "test_cache_nomap.py" "test_cache_nomap_clue" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_csv" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1 +HandleRcExit $? 0 0 + for i in $(seq 1 3) do test_name="test_cache_nomap_multiple_cache${i}" diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index e703b7d3099..452fb1271f7 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -17,6 +17,7 @@ Testing cache operator with mappable datasets """ import os import pytest +import numpy as np import mindspore.dataset as ds import mindspore.dataset.vision.c_transforms as c_vision from mindspore import log as logger @@ -26,7 +27,13 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" COCO_DATA_DIR = "../data/dataset/testCOCO/train/" COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" NO_IMAGE_DIR = "../data/dataset/testRandomData/" - +MNIST_DATA_DIR = "../data/dataset/testMnistData/" +CELEBA_DATA_DIR = "../data/dataset/testCelebAData/" +VOC_DATA_DIR = "../data/dataset/testVOC2012/" +MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" +CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/" +CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/" +MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord" GENERATE_GOLDEN = False @@ -443,7 +450,7 @@ def test_cache_map_failure5(): @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_failure6(): """ - Test no-cache-supporting leaf ops with Map under cache (failure) + Test no-cache-supporting MindRecord leaf with Map under cache (failure) repeat | @@ -451,7 +458,7 @@ def test_cache_map_failure6(): | Map(resize) | - Coco + MindRecord """ logger.info("Test cache failure 6") @@ -461,22 +468,66 @@ def test_cache_map_failure6(): raise RuntimeError("Testcase requires SESSION_ID environment variable") some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) - data = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) + + columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] + num_readers = 1 + # The dataset has 5 records + data = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, num_readers) resize_op = c_vision.Resize((224, 224)) - data = data.map(input_columns=["image"], operations=resize_op, cache=some_cache) + data = data.map(input_columns=["img_data"], operations=resize_op, cache=some_cache) data = data.repeat(4) with pytest.raises(RuntimeError) as e: num_iter = 0 for _ in data.create_dict_iterator(): num_iter += 1 - assert "There is currently no support for CocoOp under cache" in str(e.value) + assert "There is currently no support for MindRecordOp under cache" in str(e.value) assert num_iter == 0 logger.info('test_cache_failure6 Ended.\n') +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_failure7(): + """ + Test no-cache-supporting Generator leaf with Map under cache (failure) + + repeat + | + Cache + | + Map(lambda x: x) + | + Generator + + """ + def generator_1d(): + for i in range(64): + yield (np.array(i),) + + logger.info("Test cache failure 7") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + data = ds.GeneratorDataset(generator_1d, ["data"]) + data = data.map((lambda x: x), ["data"], cache=some_cache) + data = data.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in data.create_dict_iterator(): + num_iter += 1 + assert "There is currently no support for GeneratorOp under cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_failure7 Ended.\n') + + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_parameter_check(): """ @@ -1236,6 +1287,421 @@ def test_cache_map_epoch_ctrl3(): logger.info("test_cache_map_epoch_ctrl3 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_coco1(): + """ + Test mappable coco leaf with cache op right over the leaf + + cache + | + Coco + """ + + logger.info("Test cache map coco1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 6 records + ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, + cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 6 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_coco1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_coco2(): + """ + Test mappable coco leaf with the cache op later in the tree above the map(resize) + + cache + | + Map(resize) + | + Coco + """ + + logger.info("Test cache map coco2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 6 records + ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 6 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_coco2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_mnist1(): + """ + Test mappable mnist leaf with cache op right over the leaf + + cache + | + Mnist + """ + + logger.info("Test cache map mnist1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 10 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_mnist1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_mnist2(): + """ + Test mappable mnist leaf with the cache op later in the tree above the map(resize) + + cache + | + Map(resize) + | + Mnist + """ + + logger.info("Test cache map mnist2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) + + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 10 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_mnist2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_celeba1(): + """ + Test mappable celeba leaf with cache op right over the leaf + + cache + | + CelebA + """ + + logger.info("Test cache map celeba1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 4 records + ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 4 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_celeba1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_celeba2(): + """ + Test mappable celeba leaf with the cache op later in the tree above the map(resize) + + cache + | + Map(resize) + | + CelebA + """ + + logger.info("Test cache map celeba2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 4 records + ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 4 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_celeba2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_manifest1(): + """ + Test mappable manifest leaf with cache op right over the leaf + + cache + | + Manifest + """ + + logger.info("Test cache map manifest1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 4 records + ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 4 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_manifest1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_manifest2(): + """ + Test mappable manifest leaf with the cache op later in the tree above the map(resize) + + cache + | + Map(resize) + | + Manifest + """ + + logger.info("Test cache map manifest2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 4 records + ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 4 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_manifest2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_cifar1(): + """ + Test mappable cifar10 leaf with cache op right over the leaf + + cache + | + Cifar10 + """ + + logger.info("Test cache map cifar1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 10 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_cifar1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_cifar2(): + """ + Test mappable cifar100 leaf with the cache op later in the tree above the map(resize) + + cache + | + Map(resize) + | + Cifar100 + """ + + logger.info("Test cache map cifar2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 10 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_cifar2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_voc1(): + """ + Test mappable voc leaf with cache op right over the leaf + + cache + | + VOC + """ + + logger.info("Test cache map voc1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 9 records + ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 9 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_voc1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_voc2(): + """ + Test mappable voc leaf with the cache op later in the tree above the map(resize) + + cache + | + Map(resize) + | + VOC + """ + + logger.info("Test cache map voc2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 9 records + ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 9 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_voc2 Ended.\n") + + if __name__ == '__main__': test_cache_map_basic1() test_cache_map_basic2() diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index 52b3424d651..bb1580d2da9 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -20,22 +20,26 @@ import itertools import pytest import mindspore.common.dtype as mstype import mindspore.dataset as ds +import mindspore.dataset.text as text import mindspore.dataset.vision.c_transforms as c_vision from mindspore import log as logger DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" -DATA_DIR2 = ["../data/dataset/testTextTFRecord/text.tfrecord"] +TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"] SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json" -DATA_DIR3 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] -SCHEMA_DIR3 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" +TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] +TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" -DATA_DIR4 = "../data/dataset/testImageNetData/train/" +IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/" +CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json' +CSV_DATA_DIR = '../data/dataset/testCSV/1.csv' +TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt" GENERATE_GOLDEN = False @@ -1310,7 +1314,7 @@ def test_cache_nomap_multiple_cache1(): eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # This dataset has 12 records in it - train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3) + train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) decode_op = c_vision.Decode() train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) @@ -1359,7 +1363,7 @@ def test_cache_nomap_multiple_cache2(): image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) # This dataset has 3 records in it only - text_dataset = ds.TFRecordDataset(DATA_DIR2, SCHEMA_DIR2, cache=text_cache) + text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache) num_epoch = 5 image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch) @@ -1404,7 +1408,7 @@ def test_cache_nomap_multiple_cache3(): tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache) # This DATA_DIR only has 2 images in it - image_dataset = ds.ImageFolderDataset(dataset_dir=DATA_DIR4) + image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR) image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) num_epoch = 5 @@ -1443,7 +1447,7 @@ def test_cache_nomap_multiple_cache_train(): train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # This dataset has 12 records in it - train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3) + train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) decode_op = c_vision.Decode() train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) @@ -1497,6 +1501,239 @@ def test_cache_nomap_multiple_cache_eval(): logger.info("test_cache_nomap_multiple_cache_eval Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_clue1(): + """ + A clue dataset (a non mappable dataset) with a cache over it just after the leaf + In this one, the clue dataset will be given sharding configuration, however since a cache is + used, the tree prepare should undo the sharding configuration and instead, a distributed + sampler will be chosen with the same shard config. + + Cache + | + CLUE + """ + + logger.info("Test cache nomap clue 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # With only 3 records shard into 3, we expect only 1 record returned for this shard + # However, the sharding will be done by the sampler, not by the clue leaf node + # In this case, it is a row-based sharding, not the file-based sharding that would happen if + # there was not any cache. + ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 1 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_clue1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_clue2(): + """ + A clue dataset (a non mappable dataset) with a cache over it after map + In this one, a num_samples argument is given + + Cache + | + map(lambda x: x) + | + CLUE + """ + + logger.info("Test cache nomap clue 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2) + ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 2 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_clue2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_csv1(): + """ + A csv dataset (a non mappable dataset) with a cache over it just after the leaf + In this one, the csv dataset will be given sharding configuration, however since a cache is + used, the tree prepare should undo the sharding configuration and instead, a distributed + sampler will be chosen with the same shard config. + + Cache + | + CSV + """ + + logger.info("Test cache nomap csv 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # With only 3 records shard into 3, we expect only 1 record returned for this shard + # However, the sharding will be done by the sampler, not by the clue leaf node + # In this case, it is a row-based sharding, not the file-based sharding that would happen if + # there was not any cache. + ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 1 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_csv1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_csv2(): + """ + A csv dataset (a non mappable dataset) with a cache over it after map + In this one, a num_samples argument is given + + Cache + | + map(lambda x: x) + | + CSV + """ + + logger.info("Test cache nomap csv 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2) + ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 2 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_csv2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_textfile1(): + """ + A text file dataset (a non mappable dataset) with a cache over it just after the leaf + In this one, the text file dataset will be given sharding configuration, however since a cache is + used, the tree prepare should undo the sharding configuration and instead, a distributed + sampler will be chosen with the same shard config. + + Cache + | + TextFile + """ + + logger.info("Test cache nomap textfile 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # With only 3 records shard into 3, we expect only 1 record returned for this shard + # However, the sharding will be done by the sampler, not by the clue leaf node + # In this case, it is a row-based sharding, not the file-based sharding that would happen if + # there was not any cache. + ds1 = ds.CSVDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 1 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_textfile1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_textfile2(): + """ + A text file dataset (a non mappable dataset) with a cache over it after map + In this one, a num_samples argument is given + + Cache + | + Map(tokenizer) + | + TextFile + """ + def my_tokenizer(line): + words = line.split() + if not words: + return [""] + return words + + logger.info("Test cache nomap textfile 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2) + tokenizer = text.PythonTokenizer(my_tokenizer) + ds1 = ds1.map(operations=tokenizer, cache=some_cache) + + num_epoch = 4 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 2 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_textfile2 Ended.\n") + + if __name__ == '__main__': test_cache_nomap_basic1() test_cache_nomap_basic2() diff --git a/tests/ut/python/dataset/test_datasets_textfileop.py b/tests/ut/python/dataset/test_datasets_textfileop.py index 74de6363db4..3c2c4cd7ea8 100644 --- a/tests/ut/python/dataset/test_datasets_textfileop.py +++ b/tests/ut/python/dataset/test_datasets_textfileop.py @@ -40,8 +40,9 @@ def test_textline_dataset_all_file(): assert count == 5 -def test_textline_dataset_num_samples_zero(): - data = ds.TextFileDataset(DATA_FILE, num_samples=0) +def test_textline_dataset_num_samples_none(): + # Do not provide a num_samples argument, so it would be None by default + data = ds.TextFileDataset(DATA_FILE) count = 0 for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): logger.info("{}".format(i["text"])) @@ -208,7 +209,7 @@ def test_textline_dataset_exceptions(): if __name__ == "__main__": test_textline_dataset_one_file() test_textline_dataset_all_file() - test_textline_dataset_num_samples_zero() + test_textline_dataset_num_samples_none() test_textline_dataset_shuffle_false4() test_textline_dataset_shuffle_false1() test_textline_dataset_shuffle_files4()