!6905 [MD] Enable cache for more leaf datasets

Merge pull request !6905 from lixiachen/CacheOp_dev
mindspore-ci-bot 2020-09-30 16:49:23 +08:00 committed by Gitee
commit 6f77ec45f1
23 changed files with 1332 additions and 107 deletions

View File

@ -1075,7 +1075,7 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
std::shared_ptr<ClueOp> clue_op =
std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(clue_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
@ -1256,7 +1256,7 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(csv_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
@ -1502,7 +1502,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
// Create and initialize TextFileOp
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {

View File

@ -1345,6 +1345,9 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
std::string err_msg = "Error: No dataset files specified for manifest";
RETURN_STATUS_UNEXPECTED(err_msg);
}
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<ManifestOp::Builder> builder = std::make_shared<ManifestOp::Builder>();
(void)builder->SetManifestFile(ToString(args["dataset_file"]));
@ -1354,7 +1357,8 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@ -1365,12 +1369,27 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetDecode(ToBool(value));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<ManifestOp> manifest_op;
RETURN_IF_NOT_OK(builder->Build(&manifest_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(manifest_op));
*top = manifest_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, manifest_op, &cache_op));
*top = cache_op;
*bottom = manifest_op;
}
return Status::OK();
}
@ -1380,6 +1399,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified.");
CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified.");
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
(void)builder->SetDir(ToString(args["dataset_dir"]));
(void)builder->SetTask(ToString(args["task"]));
@ -1389,7 +1410,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@ -1398,12 +1420,26 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
(void)builder->SetDecode(ToBool(value));
} else if (key == "class_indexing") {
(void)builder->SetClassIndex(ToStringMap(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<VOCOp> voc_op;
RETURN_IF_NOT_OK(builder->Build(&voc_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(voc_op));
*top = voc_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, voc_op, &cache_op));
*top = cache_op;
*bottom = voc_op;
}
return Status::OK();
}
@ -1425,6 +1461,8 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
RETURN_STATUS_UNEXPECTED(err_msg);
}
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<CocoOp::Builder> builder = std::make_shared<CocoOp::Builder>();
(void)builder->SetDir(ToString(args["dataset_dir"]));
(void)builder->SetFile(ToString(args["annotation_file"]));
@ -1434,19 +1472,35 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
std::shared_ptr<CocoOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<CocoOp> coco_op;
RETURN_IF_NOT_OK(builder->Build(&coco_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(coco_op));
*top = coco_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, coco_op, &cache_op));
*top = cache_op;
*bottom = coco_op;
}
return Status::OK();
}
@ -1458,6 +1512,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
RETURN_STATUS_UNEXPECTED(err_msg);
}
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>();
(void)builder->SetCifarDir(ToString(args["dataset_dir"]));
@ -1467,22 +1523,38 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
(void)builder->SetCifarType(true);
std::shared_ptr<CifarOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<CifarOp> cifar_op;
RETURN_IF_NOT_OK(builder->Build(&cifar_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op));
*top = cifar_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op));
*top = cache_op;
*bottom = cifar_op;
}
return Status::OK();
}
@ -1494,6 +1566,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
RETURN_STATUS_UNEXPECTED(err_msg);
}
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>();
(void)builder->SetCifarDir(ToString(args["dataset_dir"]));
@ -1503,22 +1577,37 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
(void)builder->SetCifarType(false);
std::shared_ptr<CifarOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<CifarOp> cifar_op;
RETURN_IF_NOT_OK(builder->Build(&cifar_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op));
*top = cifar_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op));
*top = cache_op;
*bottom = cifar_op;
}
return Status::OK();
}
@ -1609,6 +1698,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
RETURN_STATUS_UNEXPECTED(err_msg);
}
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<MnistOp::Builder> builder = std::make_shared<MnistOp::Builder>();
(void)builder->SetDir(ToString(args["dataset_dir"]));
@ -1618,19 +1709,35 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
std::shared_ptr<MnistOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<MnistOp> mnist_op;
RETURN_IF_NOT_OK(builder->Build(&mnist_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(mnist_op));
*top = mnist_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, mnist_op, &cache_op));
*top = cache_op;
*bottom = mnist_op;
}
return Status::OK();
}
@ -1642,6 +1749,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<CelebAOp::Builder> builder = std::make_shared<CelebAOp::Builder>();
if (builder == nullptr) {
std::string err_msg = "Create celebaop builder failed";
@ -1653,7 +1762,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@ -1664,13 +1774,28 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
(void)builder->SetExtensions(ToStringSet(value));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
std::shared_ptr<CelebAOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<CelebAOp> celeba_op;
RETURN_IF_NOT_OK(builder->Build(&celeba_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(celeba_op));
*top = celeba_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, celeba_op, &cache_op));
*top = cache_op;
*bottom = celeba_op;
}
return Status::OK();
}
@ -1678,6 +1803,9 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
std::shared_ptr<DatasetOp> *bottom) {
// Required arguments
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
if (!args["dataset_files"].is_none()) {
files_list = ToStringVector(args["dataset_files"]);
@ -1693,7 +1821,8 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "shuffle_global") {
@ -1705,16 +1834,35 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetNumDevices(num_devices);
} else if (key == "shard_id") {
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
}
}
}
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
// because TextFileOp is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
if (sampler) {
(void)builder->SetSampler(std::move(sampler));
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
std::shared_ptr<TextFileOp> txt_op;
RETURN_IF_NOT_OK(builder->Build(&txt_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op));
*top = txt_op;
if (shuffle_required) {
if (!cache_client && shuffle_required) {
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t shuffle_size = 0;
int64_t num_rows = 0;
@ -1729,6 +1877,15 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
*bottom = txt_op;
}
// Add a cache op over this op if required and update the output subtree (top/bottom)
if (cache_client) {
// Note, it is not allowed to have both shuffle and cache
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, txt_op, &cache_op));
*top = cache_op;
*bottom = txt_op;
}
return Status::OK();
}
@ -1829,6 +1986,10 @@ Status DEPipeline::ParseBuildSentencePieceVocabOp(const py::dict &args, std::sha
Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
if (!args["dataset_files"].is_none()) {
files_list = ToStringVector(args["dataset_files"]);
@ -1844,7 +2005,8 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "shuffle_global") {
@ -1866,16 +2028,35 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
}
}
(void)builder->SetColsKeyMap(map_dict);
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
}
}
}
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
// because ClueOp is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
if (sampler) {
(void)builder->SetSampler(std::move(sampler));
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
std::shared_ptr<ClueOp> clue_op;
RETURN_IF_NOT_OK(builder->Build(&clue_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op));
*top = clue_op;
if (shuffle_required) {
if (!cache_client && shuffle_required) {
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t shuffle_size = 0;
int64_t num_rows = 0;
@ -1890,6 +2071,15 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
*bottom = clue_op;
}
// Add a cache op over this op if required and update the output subtree (top/bottom)
if (cache_client) {
// Note, it is not allowed to have both shuffle and cache
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, clue_op, &cache_op));
*top = cache_op;
*bottom = clue_op;
}
return Status::OK();
}
@ -1921,6 +2111,9 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num
Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
if (!args["dataset_files"].is_none()) {
files_list = ToStringVector(args["dataset_files"]);
@ -1938,7 +2131,8 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "shuffle_global") {
@ -1971,16 +2165,35 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
} else if (key == "column_names") {
col_names = ToStringVector(value);
(void)builder->SetColumName(col_names);
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
}
}
}
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
// because CsvOp is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
if (sampler) {
(void)builder->SetSampler(std::move(sampler));
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
std::shared_ptr<CsvOp> csv_op;
RETURN_IF_NOT_OK(builder->Build(&csv_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op));
*top = csv_op;
if (shuffle_required) {
if (!cache_client && shuffle_required) {
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t shuffle_size = 0;
int64_t num_rows = 0;
@ -1995,6 +2208,15 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
*bottom = csv_op;
}
// Add a cache op over this op if required and update the output subtree (top/bottom)
if (cache_client) {
// Note, it is not allowed to have both shuffle and cache
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, csv_op, &cache_op));
*top = cache_op;
*bottom = csv_op;
}
return Status::OK();
}

View File

@ -70,13 +70,12 @@ Status AlbumOp::Builder::SanityCheck() {
AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
: ParallelOp(num_wkrs, queue_size),
: ParallelOp(num_wkrs, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir),
decode_(do_decode),
extensions_(exts),
data_schema_(std::move(data_schema)),
sampler_(std::move(sampler)),
row_cnt_(0),
buf_cnt_(0),
sampler_ind_(0),

View File

@ -284,7 +284,6 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
std::set<std::string> extensions_; // extensions allowed
std::unordered_map<std::string, int32_t> col_name_map_;
std::unique_ptr<DataSchema> data_schema_;
std::shared_ptr<Sampler> sampler_;
int64_t row_cnt_;
int64_t buf_cnt_;
int64_t sampler_ind_;

View File

@ -25,13 +25,18 @@
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
: builder_device_id_(0),
builder_num_devices_(1),
builder_num_samples_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -68,7 +73,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
builder_device_id_);
builder_device_id_, std::move(builder_sampler_));
RETURN_IF_NOT_OK(clue_op->Init());
*op = std::move(clue_op);
@ -88,8 +93,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
@ -539,5 +544,21 @@ Status ClueOp::ComputeColMap() {
}
return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the clue op need to be reset to their defaults so
// that this clue op will produce the full set of data into the cache.
void ClueOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
shuffle_files_ = false;
num_samples_ = 0;
}
// Visitor accept method for NodePass
Status ClueOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<ClueOp>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -20,6 +20,7 @@
#include <map>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>
@ -122,6 +123,14 @@ class ClueOp : public ParallelOp {
// @return - a string vector
std::vector<std::string> split(const std::string &s, char delim);
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@ -133,12 +142,13 @@ class ClueOp : public ParallelOp {
std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_;
std::map<std::string, std::string> builder_cols_to_keyword_;
std::shared_ptr<Sampler> builder_sampler_;
};
// Constructor of ClueOp
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id);
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
// Default destructor
~ClueOp() = default;
@ -173,6 +183,17 @@ class ClueOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return clue_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the clue op need to be reset to their defaults so
/// that this clue op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

View File

@ -124,7 +124,7 @@ Status CocoOp::Builder::SanityCheck() {
CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size),
: ParallelOp(num_workers, queue_size, std::move(sampler)),
decode_(decode),
row_cnt_(0),
buf_cnt_(0),
@ -132,7 +132,6 @@ CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path,
image_folder_path_(image_folder_path),
annotation_path_(annotation_path),
rows_per_buffer_(rows_per_buffer),
sampler_(std::move(sampler)),
data_schema_(std::move(data_schema)) {
io_block_queues_.Init(num_workers_, queue_size);
}

View File

@ -206,6 +206,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "CocoOp"; }
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
@ -324,7 +328,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
std::string annotation_path_;
TaskType task_type_;
int32_t rows_per_buffer_;
std::shared_ptr<Sampler> sampler_;
std::unique_ptr<DataSchema> data_schema_;
WaitPost wp_;

View File

@ -22,12 +22,17 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
CsvOp::Builder::Builder()
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
: builder_device_id_(0),
builder_num_devices_(1),
builder_num_samples_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -59,7 +64,8 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_,
std::move(builder_sampler_));
RETURN_IF_NOT_OK(csv_op->Init());
*op = std::move(csv_op);
@ -70,8 +76,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim),
column_default_list_(column_default),
@ -889,5 +895,21 @@ Status CsvOp::ComputeColMap() {
}
return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the csv op need to be reset to their defaults so
// that this csv op will produce the full set of data into the cache.
void CsvOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
shuffle_files_ = false;
num_samples_ = 0;
}
// Visitor accept method for NodePass
Status CsvOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CsvOp>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -240,6 +240,14 @@ class CsvOp : public ParallelOp {
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@ -253,6 +261,7 @@ class CsvOp : public ParallelOp {
char builder_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
std::vector<std::string> builder_column_name_list_;
std::shared_ptr<Sampler> builder_sampler_;
};
// Constructor of CsvOp
@ -261,7 +270,8 @@ class CsvOp : public ParallelOp {
CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
std::shared_ptr<Sampler> sampler);
// Default destructor
~CsvOp() = default;
@ -297,6 +307,17 @@ class CsvOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return csv_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the csv op need to be reset to their defaults so
/// that this csv op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

View File

@ -29,6 +29,7 @@
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@ -499,5 +500,21 @@ Status TextFileOp::ComputeColMap() {
}
return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the text file op need to be reset to their defaults so
// that this text file op will produce the full set of data into the cache.
void TextFileOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
shuffle_files_ = false;
total_rows_ = 0;
}
// Visitor accept method for NodePass
Status TextFileOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -188,6 +188,17 @@ class TextFileOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return text_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
/// that this text file op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

View File

@ -212,6 +212,7 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten
folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension);
RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image));
RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation));
trow->setId(row_id);
trow->push_back(std::move(image));
trow->insert(trow->end(), annotation.begin(), annotation.end());
}

View File

@ -45,6 +45,9 @@
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#ifdef ENABLE_PYTHON
@ -260,6 +263,21 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);

View File

@ -81,6 +81,12 @@ class CacheMergeOp;
class CacheLookupOp;
class BuildSentencePieceVocabOp;
class ClueOp;
class CsvOp;
class TextFileOp;
#endif
#ifdef ENABLE_PYTHON
@ -211,6 +217,12 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);

View File

@ -36,6 +36,9 @@
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif
#ifdef ENABLE_PYTHON
@ -141,6 +144,36 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
if (is_caching_) {
// If we are a ClueOp in a caching tree, then change our config so that it becomes a basic
// ClueOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
if (is_caching_) {
// If we are a CsvOp in a caching tree, then change our config so that it becomes a basic
// CsvOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
if (is_caching_) {
// If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic
// TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#endif
// Perform leaf node cache transform identification
@ -163,34 +196,22 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, b
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache.");
}
return Status::OK();
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache.");
}
return Status::OK();
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache.");
}
return Status::OK();
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache.");
}
return Status::OK();
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#ifndef ENABLE_ANDROID
@ -214,18 +235,12 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> nod
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache.");
}
return Status::OK();
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache.");
}
return Status::OK();
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#endif

View File

@ -65,6 +65,24 @@ class CacheTransformPass : public TreePass {
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
#endif
/// \brief Perform leaf node cache transform identifications

View File

@ -2969,6 +2969,8 @@ class MnistDataset(MappableDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -2988,7 +2990,7 @@ class MnistDataset(MappableDataset):
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, num_shards=None, shard_id=None):
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
@ -2998,6 +3000,7 @@ class MnistDataset(MappableDataset):
self.shuffle_level = shuffle
self.num_shards = num_shards
self.shard_id = shard_id
self.cache = cache
def get_args(self):
args = super().get_args()
@ -3008,6 +3011,7 @@ class MnistDataset(MappableDataset):
args["sampler"] = self.sampler
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
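(For illustration only, not part of the diff.) With the cache transform pass now routing mappable leaves (MnistOp, CifarOp, CocoOp, CelebAOp, ManifestOp, VOCOp) through MappableCacheLeafSetup, such a dataset can be cached directly from Python via the new cache argument. The sketch below assumes a cache server and session already exist; the DatasetCache arguments (session_id, size, spilling), the session id value and the dataset path are placeholders, not values taken from this commit.

import mindspore.dataset as ds

# Hypothetical session id, e.g. obtained beforehand from a running cache server.
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=False)

# The new 'cache' argument makes the parser place a CacheOp placeholder over the
# MnistOp leaf; the cache transform pre-pass later builds the real caching subtree.
data = ds.MnistDataset(dataset_dir="/path/to/mnist", usage="train", cache=some_cache)
for _ in data.create_dict_iterator():
    pass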
@ -3872,6 +3876,8 @@ class ManifestDataset(MappableDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -3897,7 +3903,8 @@ class ManifestDataset(MappableDataset):
@check_manifestdataset
def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None):
shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None,
cache=None):
super().__init__(num_parallel_workers)
self.dataset_file = dataset_file
@ -3913,6 +3920,7 @@ class ManifestDataset(MappableDataset):
self.shuffle_level = shuffle
self.num_shards = num_shards
self.shard_id = shard_id
self.cache = cache
def get_args(self):
args = super().get_args()
@ -3925,6 +3933,7 @@ class ManifestDataset(MappableDataset):
args["decode"] = self.decode
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -4055,6 +4064,8 @@ class Cifar10Dataset(MappableDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -4082,7 +4093,7 @@ class Cifar10Dataset(MappableDataset):
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, num_shards=None, shard_id=None):
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
@ -4092,6 +4103,7 @@ class Cifar10Dataset(MappableDataset):
self.num_shards = num_shards
self.shard_id = shard_id
self.shuffle_level = shuffle
self.cache = cache
def get_args(self):
args = super().get_args()
@ -4102,6 +4114,7 @@ class Cifar10Dataset(MappableDataset):
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["shuffle"] = self.shuffle_level
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -4202,6 +4215,8 @@ class Cifar100Dataset(MappableDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -4226,7 +4241,7 @@ class Cifar100Dataset(MappableDataset):
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, num_shards=None, shard_id=None):
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
@ -4236,6 +4251,7 @@ class Cifar100Dataset(MappableDataset):
self.num_shards = num_shards
self.shard_id = shard_id
self.shuffle_level = shuffle
self.cache = cache
def get_args(self):
args = super().get_args()
@ -4246,6 +4262,7 @@ class Cifar100Dataset(MappableDataset):
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["shuffle"] = self.shuffle_level
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -4630,6 +4647,8 @@ class VOCDataset(MappableDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Raises:
RuntimeError: If xml of Annotations is an invalid format.
@ -4667,7 +4686,8 @@ class VOCDataset(MappableDataset):
@check_vocdataset
def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None,
num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None,
cache=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.task = task
@ -4679,6 +4699,7 @@ class VOCDataset(MappableDataset):
self.shuffle_level = shuffle
self.num_shards = num_shards
self.shard_id = shard_id
self.cache = cache
def get_args(self):
args = super().get_args()
@ -4692,6 +4713,7 @@ class VOCDataset(MappableDataset):
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -4838,6 +4860,8 @@ class CocoDataset(MappableDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -4873,7 +4897,7 @@ class CocoDataset(MappableDataset):
@check_cocodataset
def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None,
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.annotation_file = annotation_file
@ -4884,6 +4908,7 @@ class CocoDataset(MappableDataset):
self.shuffle_level = shuffle
self.num_shards = num_shards
self.shard_id = shard_id
self.cache = cache
def get_args(self):
args = super().get_args()
@ -4896,6 +4921,7 @@ class CocoDataset(MappableDataset):
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -4993,6 +5019,8 @@ class CelebADataset(MappableDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Examples:
>>> import mindspore.dataset as ds
@ -5003,7 +5031,7 @@ class CelebADataset(MappableDataset):
@check_celebadataset
def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False,
extensions=None, num_samples=None, num_shards=None, shard_id=None):
extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
@ -5015,6 +5043,7 @@ class CelebADataset(MappableDataset):
self.num_shards = num_shards
self.shard_id = shard_id
self.shuffle_level = shuffle
self.cache = cache
if usage != "all":
dir = os.path.realpath(self.dataset_dir)
@ -5033,6 +5062,7 @@ class CelebADataset(MappableDataset):
args["usage"] = self.usage
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -5142,6 +5172,8 @@ class CLUEDataset(SourceDataset):
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
Examples:
>>> import mindspore.dataset as ds
@ -5152,7 +5184,7 @@ class CLUEDataset(SourceDataset):
@check_cluedataset
def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
@ -5293,6 +5325,15 @@ class CLUEDataset(SourceDataset):
self.num_shards = num_shards
self.shard_id = shard_id
# The clue dataset does not directly support a sampler. It provides sampling arguments
# (shuffle, num_samples, num_shards, shard_id), and it DOES support sampling if a cache exists
# somewhere above it in the pipeline. If there is no cache above it, then this sampler is not used.
sampler_shuffle = self.shuffle_files
sampler = None
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
non_mappable=True)
self.cache = cache
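To make the comment above concrete, here is a minimal sketch with placeholder path and session id (a running cache server is assumed). With a cache over the leaf, the shard arguments are expected to be honoured by a row-based distributed sampler chosen during tree prepare, rather than by file-based sharding in the leaf itself.
>>> import mindspore.dataset as ds
>>> some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
>>> clue_d = ds.CLUEDataset("/path/to/afqmc/train.json", task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache)
>>> # each shard now sees roughly one third of the rows instead of one third of the files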
def get_args(self):
args = super().get_args()
args["dataset_files"] = self.dataset_files
@ -5304,6 +5345,8 @@ class CLUEDataset(SourceDataset):
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["cols_to_keyword"] = self.cols_to_keyword
args["sampler"] = self.sampler
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -5359,6 +5402,9 @@ class CSVDataset(SourceDataset):
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used).
The cache feature is under development and is not recommended.
Examples:
>>> import mindspore.dataset as ds
@ -5369,7 +5415,7 @@ class CSVDataset(SourceDataset):
@check_csvdataset
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
@ -5394,6 +5440,15 @@ class CSVDataset(SourceDataset):
self.num_shards = num_shards
self.shard_id = shard_id
self.cache = cache
# The CSV dataset does not directly support a sampler. It provides sampling arguments
# (shuffle, num_samples, num_shards, shard_id), and it DOES support sampling if a cache exists
# somewhere above it in the pipeline. If there is no cache above it, then this sampler is not used.
sampler_shuffle = self.shuffle_files
sampler = None
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
non_mappable=True)
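The same sampler substitution applies here. A minimal sketch with placeholder path and session id, caching above a map as in the tests added for this change:
>>> import mindspore.dataset as ds
>>> some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
>>> csv_d = ds.CSVDataset("/path/to/1.csv", column_defaults=["1", "2", "3", "4"], column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
>>> csv_d = csv_d.map((lambda x: x), ["col1"], cache=some_cache)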
def get_args(self):
args = super().get_args()
args["dataset_files"] = self.dataset_files
@ -5407,6 +5462,8 @@ class CSVDataset(SourceDataset):
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["sampler"] = self.sampler
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -5457,6 +5514,9 @@ class TextFileDataset(SourceDataset):
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used).
The cache feature is under development and is not recommended.
Examples:
>>> import mindspore.dataset as ds
>>>
@ -5466,7 +5526,7 @@ class TextFileDataset(SourceDataset):
@check_textfiledataset
def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None,
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
@ -5488,6 +5548,15 @@ class TextFileDataset(SourceDataset):
self.num_shards = num_shards
self.shard_id = shard_id
self.cache = cache
# The text file dataset does not directly support a sampler. It provides sampling arguments
# (shuffle, num_samples, num_shards, shard_id), and it DOES support sampling if a cache exists
# somewhere above it in the pipeline. If there is no cache above it, then this sampler is not used.
sampler_shuffle = self.shuffle_files
sampler = None
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
non_mappable=True)
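Likewise for the text file leaf. A minimal sketch with placeholder path and session id, caching the output of a tokenizer map:
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.text as text
>>> some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
>>> text_d = ds.TextFileDataset("/path/to/1.txt", num_samples=2)
>>> text_d = text_d.map(operations=text.PythonTokenizer(lambda line: line.split() or [""]), cache=some_cache)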
def get_args(self):
args = super().get_args()
args["dataset_files"] = self.dataset_files
@ -5498,6 +5567,8 @@ class TextFileDataset(SourceDataset):
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["sampler"] = self.sampler
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):

View File

@ -83,6 +83,9 @@ def check_mnist_cifar_dataset(method):
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@ -110,6 +113,9 @@ def check_manifestdataset(method):
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@ -180,6 +186,9 @@ def check_vocdataset(method):
validate_dataset_param_value(nreq_param_dict, param_dict, dict)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@ -216,6 +225,9 @@ def check_cocodataset(method):
raise ValueError("CocoDataset doesn't support PKSampler")
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@ -252,6 +264,9 @@ def check_celebadataset(method):
if sampler is not None and isinstance(sampler, samplers.PKSampler):
raise ValueError("CelebADataset does not support PKSampler.")
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@ -843,6 +858,9 @@ def check_cluedataset(method):
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@ -886,6 +904,9 @@ def check_csvdataset(method):
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@ -905,6 +926,9 @@ def check_textfiledataset(method):
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method

View File

@ -103,6 +103,24 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_coco" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_mnist" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_celeba" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_manifest" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_cifar" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_voc" 1
HandleRcExit $? 0 0
# Run two parallel pipelines (sharing cache)
for i in $(seq 1 2)
do
@ -282,6 +300,15 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_clue" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_csv" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1
HandleRcExit $? 0 0
for i in $(seq 1 3)
do
test_name="test_cache_nomap_multiple_cache${i}"

View File

@ -17,6 +17,7 @@ Testing cache operator with mappable datasets
"""
import os
import pytest
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
@ -26,7 +27,13 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
NO_IMAGE_DIR = "../data/dataset/testRandomData/"
MNIST_DATA_DIR = "../data/dataset/testMnistData/"
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
VOC_DATA_DIR = "../data/dataset/testVOC2012/"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/"
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/"
MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord"
GENERATE_GOLDEN = False
@ -443,7 +450,7 @@ def test_cache_map_failure5():
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure6():
"""
Test no-cache-supporting leaf ops with Map under cache (failure)
Test no-cache-supporting MindRecord leaf with Map under cache (failure)
repeat
|
@ -451,7 +458,7 @@ def test_cache_map_failure6():
|
Map(resize)
|
Coco
MindRecord
"""
logger.info("Test cache failure 6")
@ -461,22 +468,66 @@ def test_cache_map_failure6():
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
data = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
num_readers = 1
# The dataset has 5 records
data = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, num_readers)
resize_op = c_vision.Resize((224, 224))
data = data.map(input_columns=["image"], operations=resize_op, cache=some_cache)
data = data.map(input_columns=["img_data"], operations=resize_op, cache=some_cache)
data = data.repeat(4)
with pytest.raises(RuntimeError) as e:
num_iter = 0
for _ in data.create_dict_iterator():
num_iter += 1
assert "There is currently no support for CocoOp under cache" in str(e.value)
assert "There is currently no support for MindRecordOp under cache" in str(e.value)
assert num_iter == 0
logger.info('test_cache_failure6 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure7():
"""
Test no-cache-supporting Generator leaf with Map under cache (failure)
repeat
|
Cache
|
Map(lambda x: x)
|
Generator
"""
def generator_1d():
for i in range(64):
yield (np.array(i),)
logger.info("Test cache failure 7")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
data = ds.GeneratorDataset(generator_1d, ["data"])
data = data.map((lambda x: x), ["data"], cache=some_cache)
data = data.repeat(4)
with pytest.raises(RuntimeError) as e:
num_iter = 0
for _ in data.create_dict_iterator():
num_iter += 1
assert "There is currently no support for GeneratorOp under cache" in str(e.value)
assert num_iter == 0
logger.info('test_cache_failure7 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_parameter_check():
"""
@ -1236,6 +1287,421 @@ def test_cache_map_epoch_ctrl3():
logger.info("test_cache_map_epoch_ctrl3 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_coco1():
"""
Test mappable coco leaf with cache op right over the leaf
cache
|
Coco
"""
logger.info("Test cache map coco1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 6 records
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 6
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_coco1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_coco2():
"""
Test mappable coco leaf with the cache op later in the tree above the map(resize)
cache
|
Map(resize)
|
Coco
"""
logger.info("Test cache map coco2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 6 records
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
resize_op = c_vision.Resize((224, 224))
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 6
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_coco2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_mnist1():
"""
Test mappable mnist leaf with cache op right over the leaf
cache
|
Mnist
"""
logger.info("Test cache map mnist1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 10
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_mnist1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_mnist2():
"""
Test mappable mnist leaf with the cache op later in the tree above the map(resize)
cache
|
Map(resize)
|
Mnist
"""
logger.info("Test cache map mnist2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
resize_op = c_vision.Resize((224, 224))
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 10
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_mnist2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_celeba1():
"""
Test mappable celeba leaf with cache op right over the leaf
cache
|
CelebA
"""
logger.info("Test cache map celeba1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 4 records
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 4
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_celeba1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_celeba2():
"""
Test mappable celeba leaf with the cache op later in the tree above the map(resize)
cache
|
Map(resize)
|
CelebA
"""
logger.info("Test cache map celeba2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 4 records
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
resize_op = c_vision.Resize((224, 224))
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 4
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_celeba2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_manifest1():
"""
Test mappable manifest leaf with cache op right over the leaf
cache
|
Manifest
"""
logger.info("Test cache map manifest1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 4 records
ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 4
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_manifest1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_manifest2():
"""
Test mappable manifest leaf with the cache op later in the tree above the map(resize)
cache
|
Map(resize)
|
Manifest
"""
logger.info("Test cache map manifest2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 4 records
ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
resize_op = c_vision.Resize((224, 224))
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 4
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_manifest2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar1():
"""
Test mappable cifar10 leaf with cache op right over the leaf
cache
|
Cifar10
"""
logger.info("Test cache map cifar1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 10
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_cifar1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar2():
"""
Test mappable cifar100 leaf with the cache op later in the tree above the map(resize)
cache
|
Map(resize)
|
Cifar100
"""
logger.info("Test cache map cifar2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
resize_op = c_vision.Resize((224, 224))
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 10
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_cifar2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_voc1():
"""
Test mappable voc leaf with cache op right over the leaf
cache
|
VOC
"""
logger.info("Test cache map voc1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 9 records
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 9
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_voc1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_voc2():
"""
Test mappable voc leaf with the cache op later in the tree above the map(resize)
cache
|
Map(resize)
|
VOC
"""
logger.info("Test cache map voc2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 9 records
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
resize_op = c_vision.Resize((224, 224))
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 9
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_voc2 Ended.\n")
if __name__ == '__main__':
test_cache_map_basic1()
test_cache_map_basic2()

View File

@ -20,22 +20,26 @@ import itertools
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
DATA_DIR2 = ["../data/dataset/testTextTFRecord/text.tfrecord"]
TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"]
SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json"
DATA_DIR3 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
SCHEMA_DIR3 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
DATA_DIR4 = "../data/dataset/testImageNetData/train/"
IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/"
CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"
GENERATE_GOLDEN = False
@ -1310,7 +1314,7 @@ def test_cache_nomap_multiple_cache1():
eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 12 records in it
train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3)
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
decode_op = c_vision.Decode()
train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)
@ -1359,7 +1363,7 @@ def test_cache_nomap_multiple_cache2():
image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)
# This dataset has 3 records in it only
text_dataset = ds.TFRecordDataset(DATA_DIR2, SCHEMA_DIR2, cache=text_cache)
text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache)
num_epoch = 5
image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)
@ -1404,7 +1408,7 @@ def test_cache_nomap_multiple_cache3():
tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache)
# This DATA_DIR only has 2 images in it
image_dataset = ds.ImageFolderDataset(dataset_dir=DATA_DIR4)
image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR)
image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)
num_epoch = 5
@ -1443,7 +1447,7 @@ def test_cache_nomap_multiple_cache_train():
train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This dataset has 12 records in it
train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3)
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
decode_op = c_vision.Decode()
train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)
@ -1497,6 +1501,239 @@ def test_cache_nomap_multiple_cache_eval():
logger.info("test_cache_nomap_multiple_cache_eval Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue1():
"""
A clue dataset (a non-mappable dataset) with a cache over it just after the leaf
In this one, the clue dataset is given a sharding configuration; however, since a cache is
used, the tree prepare phase should undo the sharding configuration and instead choose a
distributed sampler with the same shard config.
Cache
|
CLUE
"""
logger.info("Test cache nomap clue 1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# With only 3 records sharded into 3, we expect only 1 record returned for this shard.
# However, the sharding will be done by the sampler, not by the clue leaf node.
# In this case, it is row-based sharding, not the file-based sharding that would happen if
# there were no cache.
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 1
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_nomap_clue1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue2():
"""
A clue dataset (a non-mappable dataset) with a cache over it after the map
In this one, a num_samples argument is given
Cache
|
map(lambda x: x)
|
CLUE
"""
logger.info("Test cache nomap clue 2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 2
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_nomap_clue2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv1():
"""
A csv dataset (a non-mappable dataset) with a cache over it just after the leaf
In this one, the csv dataset is given a sharding configuration; however, since a cache is
used, the tree prepare phase should undo the sharding configuration and instead choose a
distributed sampler with the same shard config.
Cache
|
CSV
"""
logger.info("Test cache nomap csv 1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# With only 3 records sharded into 3, we expect only 1 record returned for this shard.
# However, the sharding will be done by the sampler, not by the csv leaf node.
# In this case, it is row-based sharding, not the file-based sharding that would happen if
# there were no cache.
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 1
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_nomap_csv1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv2():
"""
A csv dataset (a non-mappable dataset) with a cache over it after the map
In this one, a num_samples argument is given
Cache
|
map(lambda x: x)
|
CSV
"""
logger.info("Test cache nomap csv 2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 2
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_nomap_csv2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile1():
"""
A text file dataset (a non-mappable dataset) with a cache over it just after the leaf
In this one, the text file dataset is given a sharding configuration; however, since a cache is
used, the tree prepare phase should undo the sharding configuration and instead choose a
distributed sampler with the same shard config.
Cache
|
TextFile
"""
logger.info("Test cache nomap textfile 1")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# With only 3 records sharded into 3, we expect only 1 record returned for this shard.
# However, the sharding will be done by the sampler, not by the text file leaf node.
# In this case, it is row-based sharding, not the file-based sharding that would happen if
# there were no cache.
ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 1
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_nomap_textfile1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile2():
"""
A text file dataset (a non-mappable dataset) with a cache over it after the map
In this one, a num_samples argument is given
Cache
|
Map(tokenizer)
|
TextFile
"""
def my_tokenizer(line):
words = line.split()
if not words:
return [""]
return words
logger.info("Test cache nomap textfile 2")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
tokenizer = text.PythonTokenizer(my_tokenizer)
ds1 = ds1.map(operations=tokenizer, cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 2
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_nomap_textfile2 Ended.\n")
if __name__ == '__main__':
test_cache_nomap_basic1()
test_cache_nomap_basic2()

View File

@ -40,8 +40,9 @@ def test_textline_dataset_all_file():
assert count == 5
def test_textline_dataset_num_samples_zero():
data = ds.TextFileDataset(DATA_FILE, num_samples=0)
def test_textline_dataset_num_samples_none():
# Do not provide a num_samples argument, so it defaults to None
data = ds.TextFileDataset(DATA_FILE)
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
@ -208,7 +209,7 @@ def test_textline_dataset_exceptions():
if __name__ == "__main__":
test_textline_dataset_one_file()
test_textline_dataset_all_file()
test_textline_dataset_num_samples_zero()
test_textline_dataset_num_samples_none()
test_textline_dataset_shuffle_false4()
test_textline_dataset_shuffle_false1()
test_textline_dataset_shuffle_files4()