!6905 [MD] Enable cache for more leaf datasets
Merge pull request !6905 from lixiachen/CacheOp_dev
Commit 6f77ec45f1
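For context, a minimal usage sketch of what this change enables from the Python API (hedged: the DatasetCache arguments below, such as session_id and size, are assumptions that depend on how a cache server was started, and the docstrings in this commit note the cache feature is still under development):

import mindspore.dataset as ds

# Cache client for an already-running cache server; the session id is assumed to exist.
test_cache = ds.DatasetCache(session_id=1, size=0, spilling=False)

# Leaf datasets touched by this commit can now take the cache directly, e.g. Cifar10:
data = ds.Cifar10Dataset(dataset_dir="/path/to/cifar10", cache=test_cache)
for _ in data.create_dict_iterator():
    pass  # after the first pass, rows are served from the cache rather than re-read from disk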
@@ -1075,7 +1075,7 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
std::shared_ptr<ClueOp> clue_op =
std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(clue_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp

@@ -1256,7 +1256,7 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(csv_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp

@@ -1502,7 +1502,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
// Create and initialize TextFileOp
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(text_file_op->Init());

if (shuffle_ == ShuffleMode::kGlobal) {
@@ -1345,6 +1345,9 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
std::string err_msg = "Error: No dataset files specified for manifest";
RETURN_STATUS_UNEXPECTED(err_msg);
}

int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<ManifestOp::Builder> builder = std::make_shared<ManifestOp::Builder>();
(void)builder->SetManifestFile(ToString(args["dataset_file"]));

@@ -1354,7 +1357,8 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();

@@ -1365,12 +1369,27 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetDecode(ToBool(value));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<ManifestOp> manifest_op;
RETURN_IF_NOT_OK(builder->Build(&manifest_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(manifest_op));
*top = manifest_op;

// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, manifest_op, &cache_op));
*top = cache_op;
*bottom = manifest_op;
}

return Status::OK();
}
@ -1380,6 +1399,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified.");
|
||||
|
||||
int num_workers = 0;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
|
||||
(void)builder->SetDir(ToString(args["dataset_dir"]));
|
||||
(void)builder->SetTask(ToString(args["task"]));
|
||||
|
@ -1389,7 +1410,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
|
@ -1398,12 +1420,26 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
(void)builder->SetDecode(ToBool(value));
|
||||
} else if (key == "class_indexing") {
|
||||
(void)builder->SetClassIndex(ToStringMap(value));
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*top = op;
|
||||
std::shared_ptr<VOCOp> voc_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&voc_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(voc_op));
|
||||
*top = voc_op;
|
||||
|
||||
// Additionally, add a cache if required.
|
||||
// Note that this cache op is only acting as a place holder for the caching position
|
||||
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
|
||||
// caching logic in the tree.
|
||||
if (cache_client) {
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, voc_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = voc_op;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -1425,6 +1461,8 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
int num_workers = 0;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<CocoOp::Builder> builder = std::make_shared<CocoOp::Builder>();
|
||||
(void)builder->SetDir(ToString(args["dataset_dir"]));
|
||||
(void)builder->SetFile(ToString(args["annotation_file"]));
|
||||
|
@ -1434,19 +1472,35 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (key == "decode") {
|
||||
(void)builder->SetDecode(ToBool(value));
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<CocoOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*top = op;
|
||||
std::shared_ptr<CocoOp> coco_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&coco_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(coco_op));
|
||||
*top = coco_op;
|
||||
|
||||
// Additionally, add a cache if required.
|
||||
// Note that this cache op is only acting as a place holder for the caching position
|
||||
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
|
||||
// caching logic in the tree.
|
||||
if (cache_client) {
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, coco_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = coco_op;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1458,6 +1512,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
int num_workers = 0;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>();
|
||||
(void)builder->SetCifarDir(ToString(args["dataset_dir"]));
|
||||
|
||||
|
@ -1467,22 +1523,38 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (key == "usage") {
|
||||
(void)builder->SetUsage(ToString(value));
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(void)builder->SetCifarType(true);
|
||||
|
||||
std::shared_ptr<CifarOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*top = op;
|
||||
std::shared_ptr<CifarOp> cifar_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&cifar_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op));
|
||||
*top = cifar_op;
|
||||
|
||||
// Additionally, add a cache if required.
|
||||
// Note that this cache op is only acting as a place holder for the caching position
|
||||
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
|
||||
// caching logic in the tree.
|
||||
if (cache_client) {
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = cifar_op;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1494,6 +1566,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
int num_workers = 0;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>();
|
||||
(void)builder->SetCifarDir(ToString(args["dataset_dir"]));
|
||||
|
||||
|
@ -1503,22 +1577,37 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (key == "usage") {
|
||||
(void)builder->SetUsage(ToString(value));
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(void)builder->SetCifarType(false);
|
||||
|
||||
std::shared_ptr<CifarOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*top = op;
|
||||
std::shared_ptr<CifarOp> cifar_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&cifar_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op));
|
||||
*top = cifar_op;
|
||||
|
||||
// Additionally, add a cache if required.
|
||||
// Note that this cache op is only acting as a place holder for the caching position
|
||||
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
|
||||
// caching logic in the tree.
|
||||
if (cache_client) {
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = cifar_op;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1609,6 +1698,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
int num_workers = 0;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<MnistOp::Builder> builder = std::make_shared<MnistOp::Builder>();
|
||||
(void)builder->SetDir(ToString(args["dataset_dir"]));
|
||||
|
||||
|
@ -1618,19 +1709,35 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (key == "usage") {
|
||||
(void)builder->SetUsage(ToString(value));
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<MnistOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*top = op;
|
||||
std::shared_ptr<MnistOp> mnist_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&mnist_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(mnist_op));
|
||||
*top = mnist_op;
|
||||
|
||||
// Additionally, add a cache if required.
|
||||
// Note that this cache op is only acting as a place holder for the caching position
|
||||
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
|
||||
// caching logic in the tree.
|
||||
if (cache_client) {
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, mnist_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = mnist_op;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1642,6 +1749,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
|
|||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
int num_workers = 0;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<CelebAOp::Builder> builder = std::make_shared<CelebAOp::Builder>();
|
||||
if (builder == nullptr) {
|
||||
std::string err_msg = "Create celebaop builder failed";
|
||||
|
@ -1653,7 +1762,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
|
@ -1664,13 +1774,28 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
|
|||
(void)builder->SetExtensions(ToStringSet(value));
|
||||
} else if (key == "usage") {
|
||||
(void)builder->SetUsage(ToString(value));
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<CelebAOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*top = op;
|
||||
std::shared_ptr<CelebAOp> celeba_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&celeba_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(celeba_op));
|
||||
*top = celeba_op;
|
||||
|
||||
// Additionally, add a cache if required.
|
||||
// Note that this cache op is only acting as a place holder for the caching position
|
||||
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
|
||||
// caching logic in the tree.
|
||||
if (cache_client) {
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, celeba_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = celeba_op;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1678,6 +1803,9 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
std::shared_ptr<DatasetOp> *bottom) {
|
||||
// Required arguments
|
||||
std::vector<std::string> files_list;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<Sampler> sampler = nullptr;
|
||||
int num_workers = 0;
|
||||
std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
|
||||
if (!args["dataset_files"].is_none()) {
|
||||
files_list = ToStringVector(args["dataset_files"]);
|
||||
|
@ -1693,7 +1821,8 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "shuffle_global") {
|
||||
|
@ -1705,16 +1834,35 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
(void)builder->SetNumDevices(num_devices);
|
||||
} else if (key == "shard_id") {
|
||||
(void)builder->SetDeviceId(ToInt(value));
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
|
||||
// because TextFileOp is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
if (sampler) {
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (cache_client) {
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
}
|
||||
|
||||
std::shared_ptr<TextFileOp> txt_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&txt_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op));
|
||||
*top = txt_op;
|
||||
|
||||
if (shuffle_required) {
|
||||
if (!cache_client && shuffle_required) {
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t shuffle_size = 0;
|
||||
int64_t num_rows = 0;
|
||||
|
@ -1729,6 +1877,15 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
*bottom = txt_op;
|
||||
}
|
||||
|
||||
// Add a cache op over this op if required and update the output subtree (top/bottom)
|
||||
if (cache_client) {
|
||||
// Note, it is not allowed to have both shuffle and cache
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, txt_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = txt_op;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
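To make the non-mappable case above concrete, a hedged Python-level sketch (file paths are placeholders, and it assumes the Python TextFileDataset exposes the same cache argument that this parser reads from args["cache"]; as the comments note, when a cache is present the leaf is given a default SequentialSampler and row selection is driven by the cache layer, while combining a global shuffle with a cache over the same leaf is not allowed):

import mindspore.dataset as ds

test_cache = ds.DatasetCache(session_id=1, size=0, spilling=False)  # assumed existing cache session

# Without a cache, the TextFile leaf does its own shuffle/shard handling and takes no sampler.
plain = ds.TextFileDataset(dataset_files=["/path/a.txt", "/path/b.txt"], shuffle=ds.Shuffle.GLOBAL)

# With a cache, files are read once in order and subsequent epochs are sampled from the cache.
cached = ds.TextFileDataset(dataset_files=["/path/a.txt", "/path/b.txt"], cache=test_cache)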
@ -1829,6 +1986,10 @@ Status DEPipeline::ParseBuildSentencePieceVocabOp(const py::dict &args, std::sha
|
|||
Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom) {
|
||||
std::vector<std::string> files_list;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<Sampler> sampler = nullptr;
|
||||
int num_workers = 0;
|
||||
|
||||
std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
|
||||
if (!args["dataset_files"].is_none()) {
|
||||
files_list = ToStringVector(args["dataset_files"]);
|
||||
|
@ -1844,7 +2005,8 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "shuffle_global") {
|
||||
|
@ -1866,16 +2028,35 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
}
|
||||
}
|
||||
(void)builder->SetColsKeyMap(map_dict);
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
|
||||
// because ClueOp is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
if (sampler) {
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (cache_client) {
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
}
|
||||
|
||||
std::shared_ptr<ClueOp> clue_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&clue_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op));
|
||||
*top = clue_op;
|
||||
|
||||
if (shuffle_required) {
|
||||
if (!cache_client && shuffle_required) {
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t shuffle_size = 0;
|
||||
int64_t num_rows = 0;
|
||||
|
@ -1890,6 +2071,15 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
*bottom = clue_op;
|
||||
}
|
||||
|
||||
// Add a cache op over this op if required and update the output subtree (top/bottom)
|
||||
if (cache_client) {
|
||||
// Note, it is not allowed to have both shuffle and cache
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, clue_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = clue_op;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1921,6 +2111,9 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num
|
|||
Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom) {
|
||||
std::vector<std::string> files_list;
|
||||
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||
std::shared_ptr<Sampler> sampler = nullptr;
|
||||
int num_workers = 0;
|
||||
std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
|
||||
if (!args["dataset_files"].is_none()) {
|
||||
files_list = ToStringVector(args["dataset_files"]);
|
||||
|
@ -1938,7 +2131,8 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
num_workers = ToInt(value);
|
||||
(void)builder->SetNumWorkers(num_workers);
|
||||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "shuffle_global") {
|
||||
|
@ -1971,16 +2165,35 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
} else if (key == "column_names") {
|
||||
col_names = ToStringVector(value);
|
||||
(void)builder->SetColumName(col_names);
|
||||
} else if (key == "cache") {
|
||||
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
|
||||
// because CsvOp is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
if (sampler) {
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (cache_client) {
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
}
|
||||
|
||||
std::shared_ptr<CsvOp> csv_op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&csv_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op));
|
||||
*top = csv_op;
|
||||
|
||||
if (shuffle_required) {
|
||||
if (!cache_client && shuffle_required) {
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t shuffle_size = 0;
|
||||
int64_t num_rows = 0;
|
||||
|
@ -1995,6 +2208,15 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
*bottom = csv_op;
|
||||
}
|
||||
|
||||
// Add a cache op over this op if required and update the output subtree (top/bottom)
|
||||
if (cache_client) {
|
||||
// Note, it is not allowed to have both shuffle and cache
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, csv_op, &cache_op));
|
||||
*top = cache_op;
|
||||
*bottom = csv_op;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@@ -70,13 +70,12 @@ Status AlbumOp::Builder::SanityCheck() {
AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
: ParallelOp(num_wkrs, queue_size),
: ParallelOp(num_wkrs, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir),
decode_(do_decode),
extensions_(exts),
data_schema_(std::move(data_schema)),
sampler_(std::move(sampler)),
row_cnt_(0),
buf_cnt_(0),
sampler_ind_(0),

@@ -284,7 +284,6 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
std::set<std::string> extensions_;  // extensions allowed
std::unordered_map<std::string, int32_t> col_name_map_;
std::unique_ptr<DataSchema> data_schema_;
std::shared_ptr<Sampler> sampler_;
int64_t row_cnt_;
int64_t buf_cnt_;
int64_t sampler_ind_;
@@ -25,13 +25,18 @@
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/util/random.h"

namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
: builder_device_id_(0),
builder_num_devices_(1),
builder_num_samples_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();

@@ -68,7 +73,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
builder_device_id_);
builder_device_id_, std::move(builder_sampler_));
RETURN_IF_NOT_OK(clue_op->Init());
*op = std::move(clue_op);

@@ -88,8 +93,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim

ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),

@@ -539,5 +544,21 @@ Status ClueOp::ComputeColMap() {
}
return Status::OK();
}

// Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
// that this clue op will produce the full set of data into the cache.
void ClueOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
shuffle_files_ = false;
num_samples_ = 0;
}

// Visitor accept method for NodePass
Status ClueOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<ClueOp>(), modified);
}
}  // namespace dataset
}  // namespace mindspore
@@ -20,6 +20,7 @@
#include <map>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>

@@ -122,6 +123,14 @@ class ClueOp : public ParallelOp {
// @return - the string vector
std::vector<std::string> split(const std::string &s, char delim);

// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}

private:
int32_t builder_device_id_;
int32_t builder_num_devices_;

@@ -133,12 +142,13 @@ class ClueOp : public ParallelOp {
std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_;
std::map<std::string, std::string> builder_cols_to_keyword_;
std::shared_ptr<Sampler> builder_sampler_;
};

// Constructor of ClueOp
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id);
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);

// Default destructor
~ClueOp() = default;

@@ -173,6 +183,17 @@ class ClueOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return clue_files_list_; }

/// \brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
/// that this clue op will produce the full set of data into the cache.
void MakeSimpleProducer();

// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;

private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
@@ -124,7 +124,7 @@ Status CocoOp::Builder::SanityCheck() {
CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size),
: ParallelOp(num_workers, queue_size, std::move(sampler)),
decode_(decode),
row_cnt_(0),
buf_cnt_(0),

@@ -132,7 +132,6 @@ CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path,
image_folder_path_(image_folder_path),
annotation_path_(annotation_path),
rows_per_buffer_(rows_per_buffer),
sampler_(std::move(sampler)),
data_schema_(std::move(data_schema)) {
io_block_queues_.Init(num_workers_, queue_size);
}

@@ -206,6 +206,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;

// Op name getter
// @return Name of the current Op
std::string Name() const override { return "CocoOp"; }

private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

@@ -324,7 +328,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
std::string annotation_path_;
TaskType task_type_;
int32_t rows_per_buffer_;
std::shared_ptr<Sampler> sampler_;
std::unique_ptr<DataSchema> data_schema_;

WaitPost wp_;
@ -22,12 +22,17 @@
|
|||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CsvOp::Builder::Builder()
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
|
||||
: builder_device_id_(0),
|
||||
builder_num_devices_(1),
|
||||
builder_num_samples_(0),
|
||||
builder_shuffle_files_(false),
|
||||
builder_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||
|
@ -59,7 +64,8 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
|
|||
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
|
||||
builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
|
||||
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
|
||||
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_,
|
||||
std::move(builder_sampler_));
|
||||
RETURN_IF_NOT_OK(csv_op->Init());
|
||||
*op = std::move(csv_op);
|
||||
|
||||
|
@ -70,8 +76,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
|
|||
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
|
||||
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
|
||||
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
|
||||
int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
|
||||
csv_files_list_(std::move(csv_files_list)),
|
||||
field_delim_(field_delim),
|
||||
column_default_list_(column_default),
|
||||
|
@ -889,5 +895,21 @@ Status CsvOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
|
||||
// that this csv op will produce the full set of data into the cache.
|
||||
void CsvOp::MakeSimpleProducer() {
|
||||
device_id_ = 0;
|
||||
num_devices_ = 1;
|
||||
shuffle_files_ = false;
|
||||
num_samples_ = 0;
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status CsvOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<CsvOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -240,6 +240,14 @@ class CsvOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
|
||||
builder_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t builder_device_id_;
|
||||
int32_t builder_num_devices_;
|
||||
|
@ -253,6 +261,7 @@ class CsvOp : public ParallelOp {
|
|||
char builder_field_delim_;
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
|
||||
std::vector<std::string> builder_column_name_list_;
|
||||
std::shared_ptr<Sampler> builder_sampler_;
|
||||
};
|
||||
|
||||
// Constructor of CsvOp
|
||||
|
@ -261,7 +270,8 @@ class CsvOp : public ParallelOp {
|
|||
CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
|
||||
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
|
||||
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
|
||||
std::shared_ptr<Sampler> sampler);
|
||||
|
||||
// Default destructor
|
||||
~CsvOp() = default;
|
||||
|
@ -297,6 +307,17 @@ class CsvOp : public ParallelOp {
|
|||
// @return Vector of the input file names
|
||||
std::vector<std::string> FileNames() { return csv_files_list_; }
|
||||
|
||||
/// \brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
|
||||
/// that this csv op will produce the full set of data into the cache.
|
||||
void MakeSimpleProducer();
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -499,5 +500,21 @@ Status TextFileOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
|
||||
// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
|
||||
// that this text file op will produce the full set of data into the cache.
|
||||
void TextFileOp::MakeSimpleProducer() {
|
||||
device_id_ = 0;
|
||||
num_devices_ = 1;
|
||||
shuffle_files_ = false;
|
||||
total_rows_ = 0;
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status TextFileOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -188,6 +188,17 @@ class TextFileOp : public ParallelOp {
|
|||
// @return Vector of the input file names
|
||||
std::vector<std::string> FileNames() { return text_files_list_; }
|
||||
|
||||
/// \brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
|
||||
/// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
|
||||
/// that this text file op will produce the full set of data into the cache.
|
||||
void MakeSimpleProducer();
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
|
|
|
@ -212,6 +212,7 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten
|
|||
folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension);
|
||||
RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image));
|
||||
RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation));
|
||||
trow->setId(row_id);
|
||||
trow->push_back(std::move(image));
|
||||
trow->insert(trow->end(), annotation.begin(), annotation.end());
|
||||
}
|
||||
|
|
|
@@ -45,6 +45,9 @@
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#ifdef ENABLE_PYTHON

@@ -260,6 +263,21 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}

Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);

@@ -81,6 +81,12 @@ class CacheMergeOp;
class CacheLookupOp;

class BuildSentencePieceVocabOp;

class ClueOp;

class CsvOp;

class TextFileOp;
#endif

#ifdef ENABLE_PYTHON

@@ -211,6 +217,12 @@ class NodePass : public Pass {

virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified);

virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);

virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
@@ -36,6 +36,9 @@

#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif

#ifdef ENABLE_PYTHON

@@ -141,6 +144,36 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
if (is_caching_) {
// If we are a ClueOp in a caching tree, then change our config so that it becomes a basic
// ClueOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
if (is_caching_) {
// If we are a CsvOp in a caching tree, then change our config so that it becomes a basic
// CsvOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}

// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
if (is_caching_) {
// If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic
// TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#endif

// Perform leaf node cache transform identification
@ -163,34 +196,22 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, b
|
|||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -214,18 +235,12 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> nod
|
|||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
@@ -65,6 +65,24 @@ class CacheTransformPass : public TreePass {
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;

/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
#endif

/// \brief Perform leaf node cache transform identifications
@ -2969,6 +2969,8 @@ class MnistDataset(MappableDataset):
|
|||
into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
|
@ -2988,7 +2990,7 @@ class MnistDataset(MappableDataset):
|
|||
|
||||
@check_mnist_cifar_dataset
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
|
@ -2998,6 +3000,7 @@ class MnistDataset(MappableDataset):
|
|||
self.shuffle_level = shuffle
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.cache = cache
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -3008,6 +3011,7 @@ class MnistDataset(MappableDataset):
|
|||
args["sampler"] = self.sampler
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
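The same three-line pattern is applied to each Python dataset class in this file: accept a cache keyword, store it on the instance, and surface cache.cache_client through get_args() so the C++ parser can read the "cache" key. A condensed, hypothetical sketch of that pattern (the class name and the other arguments are illustrative only; it assumes the class lives alongside MappableDataset in datasets.py):

class SomeMappableDataset(MappableDataset):
    def __init__(self, dataset_dir, num_parallel_workers=None, cache=None):
        super().__init__(num_parallel_workers)
        self.dataset_dir = dataset_dir
        self.cache = cache  # keep the DatasetCache handle; may be None

    def get_args(self):
        args = super().get_args()
        args["dataset_dir"] = self.dataset_dir
        # Hand the underlying CacheClient to the C++ layer, which reads args["cache"] in ParseXxxOp.
        args["cache"] = self.cache.cache_client if self.cache is not None else None
        return args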
@ -3872,6 +3876,8 @@ class ManifestDataset(MappableDataset):
|
|||
into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
|
@ -3897,7 +3903,8 @@ class ManifestDataset(MappableDataset):
|
|||
|
||||
@check_manifestdataset
|
||||
def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None):
|
||||
shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None,
|
||||
cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
|
||||
self.dataset_file = dataset_file
|
||||
|
@ -3913,6 +3920,7 @@ class ManifestDataset(MappableDataset):
|
|||
self.shuffle_level = shuffle
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.cache = cache
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -3925,6 +3933,7 @@ class ManifestDataset(MappableDataset):
|
|||
args["decode"] = self.decode
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -4055,6 +4064,8 @@ class Cifar10Dataset(MappableDataset):
|
|||
into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
|
@ -4082,7 +4093,7 @@ class Cifar10Dataset(MappableDataset):
|
|||
|
||||
@check_mnist_cifar_dataset
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
|
@ -4092,6 +4103,7 @@ class Cifar10Dataset(MappableDataset):
|
|||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.shuffle_level = shuffle
|
||||
self.cache = cache
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -4102,6 +4114,7 @@ class Cifar10Dataset(MappableDataset):
|
|||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["shuffle"] = self.shuffle_level
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -4202,6 +4215,8 @@ class Cifar100Dataset(MappableDataset):
|
|||
into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
|
@ -4226,7 +4241,7 @@ class Cifar100Dataset(MappableDataset):
|
|||
|
||||
@check_mnist_cifar_dataset
|
||||
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||
shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
|
||||
self.dataset_dir = dataset_dir
|
||||
|
@ -4236,6 +4251,7 @@ class Cifar100Dataset(MappableDataset):
|
|||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.shuffle_level = shuffle
|
||||
self.cache = cache
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -4246,6 +4262,7 @@ class Cifar100Dataset(MappableDataset):
|
|||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["shuffle"] = self.shuffle_level
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -4630,6 +4647,8 @@ class VOCDataset(MappableDataset):
|
|||
into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If xml of Annotations is an invalid format.
|
||||
|
@ -4667,7 +4686,8 @@ class VOCDataset(MappableDataset):
|
|||
|
||||
@check_vocdataset
|
||||
def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None,
|
||||
num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
|
||||
num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None,
|
||||
cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_dir = dataset_dir
|
||||
self.task = task
|
||||
|
@ -4679,6 +4699,7 @@ class VOCDataset(MappableDataset):
|
|||
self.shuffle_level = shuffle
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.cache = cache
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -4692,6 +4713,7 @@ class VOCDataset(MappableDataset):
|
|||
args["shuffle"] = self.shuffle_level
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -4838,6 +4860,8 @@ class CocoDataset(MappableDataset):
|
|||
into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
|
@ -4873,7 +4897,7 @@ class CocoDataset(MappableDataset):
|
|||
|
||||
@check_cocodataset
|
||||
def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None,
|
||||
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
|
||||
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_dir = dataset_dir
|
||||
self.annotation_file = annotation_file
|
||||
|
@ -4884,6 +4908,7 @@ class CocoDataset(MappableDataset):
|
|||
self.shuffle_level = shuffle
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.cache = cache
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -4896,6 +4921,7 @@ class CocoDataset(MappableDataset):
|
|||
args["shuffle"] = self.shuffle_level
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -4993,6 +5019,8 @@ class CelebADataset(MappableDataset):
|
|||
into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -5003,7 +5031,7 @@ class CelebADataset(MappableDataset):
|
|||
|
||||
@check_celebadataset
|
||||
def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False,
|
||||
extensions=None, num_samples=None, num_shards=None, shard_id=None):
|
||||
extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_dir = dataset_dir
|
||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||
|
@ -5015,6 +5043,7 @@ class CelebADataset(MappableDataset):
|
|||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.shuffle_level = shuffle
|
||||
self.cache = cache
|
||||
|
||||
if usage != "all":
|
||||
dir = os.path.realpath(self.dataset_dir)
|
||||
|
@ -5033,6 +5062,7 @@ class CelebADataset(MappableDataset):
|
|||
args["usage"] = self.usage
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -5142,6 +5172,8 @@ class CLUEDataset(SourceDataset):
|
|||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -5152,7 +5184,7 @@ class CLUEDataset(SourceDataset):
|
|||
|
||||
@check_cluedataset
|
||||
def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
|
||||
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
||||
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_files = self._find_files(dataset_files)
|
||||
self.dataset_files.sort()
|
||||
|
@ -5293,6 +5325,15 @@ class CLUEDataset(SourceDataset):
|
|||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
|
||||
# The clue dataset does not directly support a sampler. It has provided sampling arguments
|
||||
# (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if there is a cache
|
||||
# somewhere above it in the pipeline. If there is no cache above it, then this sampler is not used.
|
||||
sampler_shuffle = self.shuffle_files
|
||||
sampler = None
|
||||
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
|
||||
non_mappable=True)
|
||||
self.cache = cache
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_files"] = self.dataset_files
|
||||
|
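The comment in the constructor above captures the behavioral change for non-mappable leaves: with a cache in the pipeline, sharding is done row by row through the injected sampler rather than file by file in the leaf. A hedged sketch of the observable effect, assuming a 3-row AFQMC train file (this is exactly the scenario exercised by test_cache_nomap_clue1 further down in this change):

import os
import mindspore.dataset as ds

session_id = int(os.environ['SESSION_ID'])
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# Without a cache, num_shards/shard_id would shard by file; with a cache, a distributed
# sampler shards by row, so a 3-row file split into 3 shards yields 1 row for this shard.
ds1 = ds.CLUEDataset('../data/dataset/testCLUE/afqmc/train.json', task='AFQMC', usage='train',
                     num_shards=3, shard_id=1, cache=some_cache)
assert sum(1 for _ in ds1.create_dict_iterator(num_epochs=1, output_numpy=True)) == 1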
@ -5304,6 +5345,8 @@ class CLUEDataset(SourceDataset):
|
|||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["cols_to_keyword"] = self.cols_to_keyword
|
||||
args["sampler"] = self.sampler
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -5359,6 +5402,9 @@ class CSVDataset(SourceDataset):
|
|||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -5369,7 +5415,7 @@ class CSVDataset(SourceDataset):
|
|||
|
||||
@check_csvdataset
|
||||
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
|
||||
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
||||
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_files = self._find_files(dataset_files)
|
||||
self.dataset_files.sort()
|
||||
|
@ -5394,6 +5440,15 @@ class CSVDataset(SourceDataset):
|
|||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
|
||||
self.cache = cache
|
||||
# The CSV dataset does not directly support a sampler. It has provided sampling arguments
|
||||
# (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if there is a cache
|
||||
# somewhere above it in the pipeline. If there is no cache above it, then this sampler is not used.
|
||||
sampler_shuffle = self.shuffle_files
|
||||
sampler = None
|
||||
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
|
||||
non_mappable=True)
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_files"] = self.dataset_files
|
||||
|
@ -5407,6 +5462,8 @@ class CSVDataset(SourceDataset):
|
|||
args["shuffle"] = self.shuffle_level
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["sampler"] = self.sampler
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
@ -5457,6 +5514,9 @@ class TextFileDataset(SourceDataset):
|
|||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
|
||||
The cache feature is under development and is not recommended.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
|
@ -5466,7 +5526,7 @@ class TextFileDataset(SourceDataset):
|
|||
|
||||
@check_textfiledataset
|
||||
def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
||||
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_files = self._find_files(dataset_files)
|
||||
self.dataset_files.sort()
|
||||
|
@ -5488,6 +5548,15 @@ class TextFileDataset(SourceDataset):
|
|||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
|
||||
self.cache = cache
|
||||
# The text file dataset does not directly support a sampler. It has provided sampling arguments
|
||||
# (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if there is a cache
|
||||
# somewhere above it in the pipeline. If there is no cache above it, then this sampler is not used.
|
||||
sampler_shuffle = self.shuffle_files
|
||||
sampler = None
|
||||
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
|
||||
non_mappable=True)
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_files"] = self.dataset_files
|
||||
|
@ -5498,6 +5567,8 @@ class TextFileDataset(SourceDataset):
|
|||
args["shuffle"] = self.shuffle_level
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["sampler"] = self.sampler
|
||||
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
|
|
|
@ -83,6 +83,9 @@ def check_mnist_cifar_dataset(method):
|
|||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -110,6 +113,9 @@ def check_manifestdataset(method):
|
|||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -180,6 +186,9 @@ def check_vocdataset(method):
|
|||
validate_dataset_param_value(nreq_param_dict, param_dict, dict)
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -216,6 +225,9 @@ def check_cocodataset(method):
|
|||
raise ValueError("CocoDataset doesn't support PKSampler")
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -252,6 +264,9 @@ def check_celebadataset(method):
|
|||
if sampler is not None and isinstance(sampler, samplers.PKSampler):
|
||||
raise ValueError("CelebADataset does not support PKSampler.")
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -843,6 +858,9 @@ def check_cluedataset(method):
|
|||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -886,6 +904,9 @@ def check_csvdataset(method):
|
|||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -905,6 +926,9 @@ def check_textfiledataset(method):
|
|||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
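Each validator hunk above adds a call to a check_cache_option helper whose definition is not part of this excerpt; presumably it just rejects a cache argument that is not a DatasetCache. A minimal sketch under that assumption (the real helper may differ in name, location, and error message):

def check_cache_option(cache):
    """Sketch only: validate the optional cache argument (assumed behavior, not the actual helper)."""
    from mindspore.dataset import DatasetCache  # assumed import location
    if cache is not None and not isinstance(cache, DatasetCache):
        raise TypeError("cache should be a DatasetCache object.")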
|
|
@ -103,6 +103,24 @@ HandleRcExit $? 0 0
|
|||
PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_coco" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_mnist" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_celeba" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_manifest" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_cifar" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_voc" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# Run two parallel pipelines (sharing cache)
|
||||
for i in $(seq 1 2)
|
||||
do
|
||||
|
@ -282,6 +300,15 @@ HandleRcExit $? 0 0
|
|||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_clue" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_csv" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
for i in $(seq 1 3)
|
||||
do
|
||||
test_name="test_cache_nomap_multiple_cache${i}"
|
||||
|
|
|
@ -17,6 +17,7 @@ Testing cache operator with mappable datasets
|
|||
"""
|
||||
import os
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
from mindspore import log as logger
|
||||
|
@ -26,7 +27,13 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
|
|||
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
|
||||
COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
|
||||
NO_IMAGE_DIR = "../data/dataset/testRandomData/"
|
||||
|
||||
MNIST_DATA_DIR = "../data/dataset/testMnistData/"
|
||||
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
|
||||
VOC_DATA_DIR = "../data/dataset/testVOC2012/"
|
||||
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
|
||||
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/"
|
||||
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/"
|
||||
MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord"
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
|
||||
|
@ -443,7 +450,7 @@ def test_cache_map_failure5():
|
|||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_failure6():
|
||||
"""
|
||||
Test no-cache-supporting leaf ops with Map under cache (failure)
|
||||
Test no-cache-supporting MindRecord leaf with Map under cache (failure)
|
||||
|
||||
repeat
|
||||
|
|
||||
|
@ -451,7 +458,7 @@ def test_cache_map_failure6():
|
|||
|
|
||||
Map(resize)
|
||||
|
|
||||
Coco
|
||||
MindRecord
|
||||
|
||||
"""
|
||||
logger.info("Test cache failure 6")
|
||||
|
@ -461,22 +468,66 @@ def test_cache_map_failure6():
|
|||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
data = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
|
||||
|
||||
columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
|
||||
num_readers = 1
|
||||
# The dataset has 5 records
|
||||
data = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, num_readers)
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
|
||||
data = data.map(input_columns=["image"], operations=resize_op, cache=some_cache)
|
||||
data = data.map(input_columns=["img_data"], operations=resize_op, cache=some_cache)
|
||||
data = data.repeat(4)
|
||||
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
num_iter = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert "There is currently no support for CocoOp under cache" in str(e.value)
|
||||
assert "There is currently no support for MindRecordOp under cache" in str(e.value)
|
||||
|
||||
assert num_iter == 0
|
||||
logger.info('test_cache_failure6 Ended.\n')
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_failure7():
|
||||
"""
|
||||
Test no-cache-supporting Generator leaf with Map under cache (failure)
|
||||
|
||||
repeat
|
||||
|
|
||||
Cache
|
||||
|
|
||||
Map(lambda x: x)
|
||||
|
|
||||
Generator
|
||||
|
||||
"""
|
||||
def generator_1d():
|
||||
for i in range(64):
|
||||
yield (np.array(i),)
|
||||
|
||||
logger.info("Test cache failure 7")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
data = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data = data.map((lambda x: x), ["data"], cache=some_cache)
|
||||
data = data.repeat(4)
|
||||
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
num_iter = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert "There is currently no support for GeneratorOp under cache" in str(e.value)
|
||||
|
||||
assert num_iter == 0
|
||||
logger.info('test_cache_failure7 Ended.\n')
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_parameter_check():
|
||||
"""
|
||||
|
@ -1236,6 +1287,421 @@ def test_cache_map_epoch_ctrl3():
|
|||
logger.info("test_cache_map_epoch_ctrl3 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_coco1():
|
||||
"""
|
||||
Test mappable coco leaf with cache op right over the leaf
|
||||
|
||||
cache
|
||||
|
|
||||
Coco
|
||||
"""
|
||||
|
||||
logger.info("Test cache map coco1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 6 records
|
||||
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
|
||||
cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 6
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_coco1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_coco2():
|
||||
"""
|
||||
Test mappable coco leaf with the cache op later in the tree above the map(resize)
|
||||
|
||||
cache
|
||||
|
|
||||
Map(resize)
|
||||
|
|
||||
Coco
|
||||
"""
|
||||
|
||||
logger.info("Test cache map coco2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 6 records
|
||||
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 6
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_coco2 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_mnist1():
|
||||
"""
|
||||
Test mappable mnist leaf with cache op right over the leaf
|
||||
|
||||
cache
|
||||
|
|
||||
Mnist
|
||||
"""
|
||||
|
||||
logger.info("Test cache map mnist1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 10
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_mnist1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_mnist2():
|
||||
"""
|
||||
Test mappable mnist leaf with the cache op later in the tree above the map(resize)
|
||||
|
||||
cache
|
||||
|
|
||||
Map(resize)
|
||||
|
|
||||
Mnist
|
||||
"""
|
||||
|
||||
logger.info("Test cache map mnist2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
|
||||
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 10
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_mnist2 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_celeba1():
|
||||
"""
|
||||
Test mappable celeba leaf with cache op right over the leaf
|
||||
|
||||
cache
|
||||
|
|
||||
CelebA
|
||||
"""
|
||||
|
||||
logger.info("Test cache map celeba1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 4 records
|
||||
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 4
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_celeba1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_celeba2():
|
||||
"""
|
||||
Test mappable celeba leaf with the cache op later in the tree above the map(resize)
|
||||
|
||||
cache
|
||||
|
|
||||
Map(resize)
|
||||
|
|
||||
CelebA
|
||||
"""
|
||||
|
||||
logger.info("Test cache map celeba2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 4 records
|
||||
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 4
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_celeba2 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_manifest1():
|
||||
"""
|
||||
Test mappable manifest leaf with cache op right over the leaf
|
||||
|
||||
cache
|
||||
|
|
||||
Manifest
|
||||
"""
|
||||
|
||||
logger.info("Test cache map manifest1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 4 records
|
||||
ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 4
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_manifest1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_manifest2():
|
||||
"""
|
||||
Test mappable manifest leaf with the cache op later in the tree above the map(resize)
|
||||
|
||||
cache
|
||||
|
|
||||
Map(resize)
|
||||
|
|
||||
Manifest
|
||||
"""
|
||||
|
||||
logger.info("Test cache map manifest2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 4 records
|
||||
ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 4
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_manifest2 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_cifar1():
|
||||
"""
|
||||
Test mappable cifar10 leaf with cache op right over the leaf
|
||||
|
||||
cache
|
||||
|
|
||||
Cifar10
|
||||
"""
|
||||
|
||||
logger.info("Test cache map cifar1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 10
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_cifar1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_cifar2():
|
||||
"""
|
||||
Test mappable cifar100 leaf with the cache op later in the tree above the map(resize)
|
||||
|
||||
cache
|
||||
|
|
||||
Map(resize)
|
||||
|
|
||||
Cifar100
|
||||
"""
|
||||
|
||||
logger.info("Test cache map cifar2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 10
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_cifar2 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_voc1():
|
||||
"""
|
||||
Test mappable voc leaf with cache op right over the leaf
|
||||
|
||||
cache
|
||||
|
|
||||
VOC
|
||||
"""
|
||||
|
||||
logger.info("Test cache map voc1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 9 records
|
||||
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 9
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_voc1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_voc2():
|
||||
"""
|
||||
Test mappable voc leaf with the cache op later in the tree above the map(resize)
|
||||
|
||||
cache
|
||||
|
|
||||
Map(resize)
|
||||
|
|
||||
VOC
|
||||
"""
|
||||
|
||||
logger.info("Test cache map voc2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 9 records
|
||||
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 9
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_map_voc2 Ended.\n")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cache_map_basic1()
|
||||
test_cache_map_basic2()
|
||||
|
|
|
@ -20,22 +20,26 @@ import itertools
|
|||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.text as text
|
||||
import mindspore.dataset.vision.c_transforms as c_vision
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
DATA_DIR2 = ["../data/dataset/testTextTFRecord/text.tfrecord"]
|
||||
TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"]
|
||||
SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json"
|
||||
|
||||
DATA_DIR3 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
|
||||
SCHEMA_DIR3 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
|
||||
TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
|
||||
TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
|
||||
|
||||
DATA_DIR4 = "../data/dataset/testImageNetData/train/"
|
||||
IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
|
||||
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
|
||||
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
|
@ -1310,7 +1314,7 @@ def test_cache_nomap_multiple_cache1():
|
|||
eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 12 records in it
|
||||
train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3)
|
||||
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
|
||||
decode_op = c_vision.Decode()
|
||||
train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)
|
||||
|
||||
|
@ -1359,7 +1363,7 @@ def test_cache_nomap_multiple_cache2():
|
|||
image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
text_dataset = ds.TFRecordDataset(DATA_DIR2, SCHEMA_DIR2, cache=text_cache)
|
||||
text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache)
|
||||
|
||||
num_epoch = 5
|
||||
image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)
|
||||
|
@ -1404,7 +1408,7 @@ def test_cache_nomap_multiple_cache3():
|
|||
tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
image_dataset = ds.ImageFolderDataset(dataset_dir=DATA_DIR4)
|
||||
image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR)
|
||||
image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)
|
||||
|
||||
num_epoch = 5
|
||||
|
@ -1443,7 +1447,7 @@ def test_cache_nomap_multiple_cache_train():
|
|||
train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This dataset has 12 records in it
|
||||
train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3)
|
||||
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
|
||||
decode_op = c_vision.Decode()
|
||||
train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)
|
||||
|
||||
|
@ -1497,6 +1501,239 @@ def test_cache_nomap_multiple_cache_eval():
|
|||
logger.info("test_cache_nomap_multiple_cache_eval Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_nomap_clue1():
|
||||
"""
|
||||
A clue dataset (a non mappable dataset) with a cache over it just after the leaf
|
||||
In this one, the clue dataset will be given sharding configuration, however since a cache is
|
||||
used, the tree prepare should undo the sharding configuration and instead, a distributed
|
||||
sampler will be chosen with the same shard config.
|
||||
|
||||
Cache
|
||||
|
|
||||
CLUE
|
||||
"""
|
||||
|
||||
logger.info("Test cache nomap clue 1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# With only 3 records shard into 3, we expect only 1 record returned for this shard
|
||||
# However, the sharding will be done by the sampler, not by the clue leaf node
|
||||
# In this case, it is a row-based sharding, not the file-based sharding that would happen if
|
||||
# there was not any cache.
|
||||
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 1
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_nomap_clue1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_nomap_clue2():
|
||||
"""
|
||||
A clue dataset (a non mappable dataset) with a cache over it after map
|
||||
In this one, a num_samples argument is given
|
||||
|
||||
Cache
|
||||
|
|
||||
map(lambda x: x)
|
||||
|
|
||||
CLUE
|
||||
"""
|
||||
|
||||
logger.info("Test cache nomap clue 2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
|
||||
ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 2
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_nomap_clue2 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_nomap_csv1():
|
||||
"""
|
||||
A csv dataset (a non mappable dataset) with a cache over it just after the leaf
|
||||
In this one, the csv dataset will be given sharding configuration, however since a cache is
|
||||
used, the tree prepare should undo the sharding configuration and instead, a distributed
|
||||
sampler will be chosen with the same shard config.
|
||||
|
||||
Cache
|
||||
|
|
||||
CSV
|
||||
"""
|
||||
|
||||
logger.info("Test cache nomap csv 1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# With only 3 records shard into 3, we expect only 1 record returned for this shard
|
||||
# However, the sharding will be done by the sampler, not by the csv leaf node
|
||||
# In this case, it is a row-based sharding, not the file-based sharding that would happen if
|
||||
# there was not any cache.
|
||||
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
|
||||
column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 1
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_nomap_csv1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_nomap_csv2():
|
||||
"""
|
||||
A csv dataset (a non mappable dataset) with a cache over it after map
|
||||
In this one, a num_samples argument is given
|
||||
|
||||
Cache
|
||||
|
|
||||
map(lambda x: x)
|
||||
|
|
||||
CSV
|
||||
"""
|
||||
|
||||
logger.info("Test cache nomap csv 2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
|
||||
column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
|
||||
ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 2
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_nomap_csv2 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_nomap_textfile1():
|
||||
"""
|
||||
A text file dataset (a non mappable dataset) with a cache over it just after the leaf
|
||||
In this one, the text file dataset will be given sharding configuration, however since a cache is
|
||||
used, the tree prepare should undo the sharding configuration and instead, a distributed
|
||||
sampler will be chosen with the same shard config.
|
||||
|
||||
Cache
|
||||
|
|
||||
TextFile
|
||||
"""
|
||||
|
||||
logger.info("Test cache nomap textfile 1")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# With only 3 records shard into 3, we expect only 1 record returned for this shard
|
||||
# However, the sharding will be done by the sampler, not by the text file leaf node
|
||||
# In this case, it is a row-based sharding, not the file-based sharding that would happen if
|
||||
# there was not any cache.
|
||||
ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 1
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_nomap_textfile1 Ended.\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_nomap_textfile2():
|
||||
"""
|
||||
A text file dataset (a non mappable dataset) with a cache over it after map
|
||||
In this one, a num_samples argument is given
|
||||
|
||||
Cache
|
||||
|
|
||||
Map(tokenizer)
|
||||
|
|
||||
TextFile
|
||||
"""
|
||||
def my_tokenizer(line):
|
||||
words = line.split()
|
||||
if not words:
|
||||
return [""]
|
||||
return words
|
||||
|
||||
logger.info("Test cache nomap textfile 2")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
|
||||
tokenizer = text.PythonTokenizer(my_tokenizer)
|
||||
ds1 = ds1.map(operations=tokenizer, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
|
||||
|
||||
epoch_count = 0
|
||||
for _ in range(num_epoch):
|
||||
assert sum([1 for _ in iter1]) == 2
|
||||
epoch_count += 1
|
||||
assert epoch_count == num_epoch
|
||||
|
||||
logger.info("test_cache_nomap_textfile2 Ended.\n")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cache_nomap_basic1()
|
||||
test_cache_nomap_basic2()
|
||||
|
|
|
@ -40,8 +40,9 @@ def test_textline_dataset_all_file():
|
|||
assert count == 5
|
||||
|
||||
|
||||
def test_textline_dataset_num_samples_zero():
|
||||
data = ds.TextFileDataset(DATA_FILE, num_samples=0)
|
||||
def test_textline_dataset_num_samples_none():
|
||||
# Do not provide a num_samples argument, so it would be None by default
|
||||
data = ds.TextFileDataset(DATA_FILE)
|
||||
count = 0
|
||||
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("{}".format(i["text"]))
|
||||
|
@ -208,7 +209,7 @@ def test_textline_dataset_exceptions():
|
|||
if __name__ == "__main__":
|
||||
test_textline_dataset_one_file()
|
||||
test_textline_dataset_all_file()
|
||||
test_textline_dataset_num_samples_zero()
|
||||
test_textline_dataset_num_samples_none()
|
||||
test_textline_dataset_shuffle_false4()
|
||||
test_textline_dataset_shuffle_false1()
|
||||
test_textline_dataset_shuffle_files4()
|
||||
|
|