forked from mindspore-Ecosystem/mindspore
commit
eadcb341e1
|
@ -47,6 +47,8 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/dataset/include)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
|
||||||
|
|
||||||
|
ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR})
|
||||||
|
|
||||||
################## Include sub-modules ###############################
|
################## Include sub-modules ###############################
|
||||||
add_subdirectory(util)
|
add_subdirectory(util)
|
||||||
add_subdirectory(core)
|
add_subdirectory(core)
|
||||||
|
@ -55,7 +57,7 @@ add_subdirectory(engine)
|
||||||
add_subdirectory(api)
|
add_subdirectory(api)
|
||||||
add_subdirectory(text)
|
add_subdirectory(text)
|
||||||
######################################################################
|
######################################################################
|
||||||
add_dependencies(core utils)
|
add_dependencies(utils core)
|
||||||
add_dependencies(kernels-image core)
|
add_dependencies(kernels-image core)
|
||||||
add_dependencies(kernels-data core)
|
add_dependencies(kernels-data core)
|
||||||
add_dependencies(kernels core)
|
add_dependencies(kernels core)
|
||||||
|
@ -89,6 +91,8 @@ set(submodules
|
||||||
$<TARGET_OBJECTS:engine-perf>
|
$<TARGET_OBJECTS:engine-perf>
|
||||||
$<TARGET_OBJECTS:engine-datasetops>
|
$<TARGET_OBJECTS:engine-datasetops>
|
||||||
$<TARGET_OBJECTS:engine-opt>
|
$<TARGET_OBJECTS:engine-opt>
|
||||||
|
$<TARGET_OBJECTS:engine-cache-client>
|
||||||
|
$<TARGET_OBJECTS:engine-cache-server>
|
||||||
$<TARGET_OBJECTS:engine>
|
$<TARGET_OBJECTS:engine>
|
||||||
$<TARGET_OBJECTS:text>
|
$<TARGET_OBJECTS:text>
|
||||||
$<TARGET_OBJECTS:text-kernels>
|
$<TARGET_OBJECTS:text-kernels>
|
||||||
|
@ -106,6 +110,8 @@ else ()
|
||||||
add_library(_c_dataengine SHARED ${submodules})
|
add_library(_c_dataengine SHARED ${submodules})
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
|
add_dependencies(_c_dataengine generated_engine_files)
|
||||||
|
|
||||||
set_target_properties(_c_dataengine PROPERTIES
|
set_target_properties(_c_dataengine PROPERTIES
|
||||||
PREFIX "${PYTHON_MODULE_PREFIX}"
|
PREFIX "${PYTHON_MODULE_PREFIX}"
|
||||||
SUFFIX "${PYTHON_MODULE_EXTENSION}"
|
SUFFIX "${PYTHON_MODULE_EXTENSION}"
|
||||||
|
|
|
@ -21,8 +21,10 @@
|
||||||
|
|
||||||
#include "common/utils.h"
|
#include "common/utils.h"
|
||||||
#include "dataset/core/tensor.h"
|
#include "dataset/core/tensor.h"
|
||||||
|
#include "dataset/engine/cache/cache_client.h"
|
||||||
#include "dataset/engine/dataset_iterator.h"
|
#include "dataset/engine/dataset_iterator.h"
|
||||||
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
|
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_op.h"
|
||||||
#include "dataset/engine/datasetops/filter_op.h"
|
#include "dataset/engine/datasetops/filter_op.h"
|
||||||
#include "dataset/engine/datasetops/source/celeba_op.h"
|
#include "dataset/engine/datasetops/source/celeba_op.h"
|
||||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||||
|
@ -34,6 +36,7 @@
|
||||||
#include "dataset/engine/datasetops/source/random_data_op.h"
|
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
#include "dataset/kernels/py_func_op.h"
|
#include "dataset/kernels/py_func_op.h"
|
||||||
#include "dataset/util/random.h"
|
#include "dataset/util/random.h"
|
||||||
#include "dataset/util/status.h"
|
#include "dataset/util/status.h"
|
||||||
|
@ -441,6 +444,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
||||||
MapOp::Builder map_builder;
|
MapOp::Builder map_builder;
|
||||||
std::vector<std::shared_ptr<TensorOp>> tensor_op_list;
|
std::vector<std::shared_ptr<TensorOp>> tensor_op_list;
|
||||||
std::vector<std::string> project_columns;
|
std::vector<std::string> project_columns;
|
||||||
|
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||||
|
int num_workers = 0;
|
||||||
|
|
||||||
if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n");
|
if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n");
|
||||||
|
|
||||||
|
@ -456,7 +461,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
||||||
} else if (key == "columns_order") {
|
} else if (key == "columns_order") {
|
||||||
project_columns = ToStringVector(value);
|
project_columns = ToStringVector(value);
|
||||||
} else if (key == "num_parallel_workers") {
|
} else if (key == "num_parallel_workers") {
|
||||||
(void)map_builder.SetNumWorkers(ToInt(value));
|
num_workers = ToInt(value);
|
||||||
|
(void)map_builder.SetNumWorkers(num_workers);
|
||||||
} else if (key == "prefetch_size") {
|
} else if (key == "prefetch_size") {
|
||||||
(void)map_builder.SetOpConnectorSize(ToInt(value));
|
(void)map_builder.SetOpConnectorSize(ToInt(value));
|
||||||
} else if (key == "operations") {
|
} else if (key == "operations") {
|
||||||
|
@ -477,6 +483,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
||||||
}
|
}
|
||||||
if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set.");
|
if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set.");
|
||||||
(void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
|
(void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
|
||||||
|
} else if (key == "cache") {
|
||||||
|
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||||
} else {
|
} else {
|
||||||
RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
|
RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
|
||||||
}
|
}
|
||||||
|
@ -499,6 +507,15 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
||||||
*bottom = map_op;
|
*bottom = map_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Additionally, add a cache if required. This will go over top of the project op if one
|
||||||
|
// was created, otherwise it goes over top of the map op
|
||||||
|
if (cache_client) {
|
||||||
|
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||||
|
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op));
|
||||||
|
*top = cache_op;
|
||||||
|
*bottom = map_op;
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -809,6 +826,9 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
|
||||||
std::shared_ptr<DatasetOp> *bottom) {
|
std::shared_ptr<DatasetOp> *bottom) {
|
||||||
// Required arguments
|
// Required arguments
|
||||||
std::vector<std::string> files_list;
|
std::vector<std::string> files_list;
|
||||||
|
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||||
|
std::shared_ptr<Sampler> sampler = nullptr;
|
||||||
|
int num_workers = 0;
|
||||||
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
|
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
|
||||||
if (!args["dataset_files"].is_none()) {
|
if (!args["dataset_files"].is_none()) {
|
||||||
files_list = ToStringVector(args["dataset_files"]);
|
files_list = ToStringVector(args["dataset_files"]);
|
||||||
|
@ -828,7 +848,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
|
||||||
py::handle value = arg.second;
|
py::handle value = arg.second;
|
||||||
if (!value.is_none()) {
|
if (!value.is_none()) {
|
||||||
if (key == "num_parallel_workers") {
|
if (key == "num_parallel_workers") {
|
||||||
(void)builder->SetNumWorkers(ToInt(value));
|
num_workers = ToInt(value);
|
||||||
|
(void)builder->SetNumWorkers(num_workers);
|
||||||
} else if (key == "columns_list") {
|
} else if (key == "columns_list") {
|
||||||
columns_to_load = ToStringVector(value);
|
columns_to_load = ToStringVector(value);
|
||||||
(void)builder->SetColumnsToLoad(columns_to_load);
|
(void)builder->SetColumnsToLoad(columns_to_load);
|
||||||
|
@ -848,6 +869,11 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
|
||||||
(void)builder->SetDeviceId(ToInt(value));
|
(void)builder->SetDeviceId(ToInt(value));
|
||||||
} else if (key == "shard_equal_rows") {
|
} else if (key == "shard_equal_rows") {
|
||||||
(void)builder->SetShardEqualRows(ToBool(value));
|
(void)builder->SetShardEqualRows(ToBool(value));
|
||||||
|
} else if (key == "cache") {
|
||||||
|
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||||
|
} else if (key == "sampler") {
|
||||||
|
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||||
|
sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -860,12 +886,27 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
|
||||||
}
|
}
|
||||||
(void)builder->SetDataSchema(std::move(schema));
|
(void)builder->SetDataSchema(std::move(schema));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
|
||||||
|
// because TFReaderOp is a non-mappable dataset that does not support sampling.
|
||||||
|
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||||
|
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||||
|
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||||
|
if (sampler) {
|
||||||
|
(void)builder->SetSampler(std::move(sampler));
|
||||||
|
} else if (cache_client) {
|
||||||
|
int64_t num_samples = 0;
|
||||||
|
int64_t start_index = 0;
|
||||||
|
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||||
|
(void)builder->SetSampler(std::move(sampler));
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<TFReaderOp> tf_op;
|
std::shared_ptr<TFReaderOp> tf_op;
|
||||||
RETURN_IF_NOT_OK(builder->Build(&tf_op));
|
RETURN_IF_NOT_OK(builder->Build(&tf_op));
|
||||||
RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
|
RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
|
||||||
*top = tf_op;
|
*top = tf_op;
|
||||||
|
|
||||||
if (shuffle_required) {
|
if (!cache_client && shuffle_required) {
|
||||||
const boolean estimate = true;
|
const boolean estimate = true;
|
||||||
const int64_t workers = 8;
|
const int64_t workers = 8;
|
||||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||||
|
@ -882,6 +923,15 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
|
||||||
*bottom = tf_op;
|
*bottom = tf_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add a cache op over this op if required and update the output subtree (top/bottom)
|
||||||
|
if (cache_client) {
|
||||||
|
// Note, it is not allowed to have both shuffle and cache
|
||||||
|
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||||
|
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op));
|
||||||
|
*top = cache_op;
|
||||||
|
*bottom = tf_op;
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -906,6 +956,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
|
||||||
std::string err_msg = "Error: No dataset path specified";
|
std::string err_msg = "Error: No dataset path specified";
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
}
|
}
|
||||||
|
int num_workers = 0;
|
||||||
|
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||||
std::shared_ptr<ImageFolderOp::Builder> builder = std::make_shared<ImageFolderOp::Builder>();
|
std::shared_ptr<ImageFolderOp::Builder> builder = std::make_shared<ImageFolderOp::Builder>();
|
||||||
(void)builder->SetImageFolderDir(ToString(args["dataset_dir"]));
|
(void)builder->SetImageFolderDir(ToString(args["dataset_dir"]));
|
||||||
|
|
||||||
|
@ -915,7 +967,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
|
||||||
py::handle value = arg.second;
|
py::handle value = arg.second;
|
||||||
if (!value.is_none()) {
|
if (!value.is_none()) {
|
||||||
if (key == "num_parallel_workers") {
|
if (key == "num_parallel_workers") {
|
||||||
(void)builder->SetNumWorkers(ToInt(value));
|
num_workers = ToInt(value);
|
||||||
|
(void)builder->SetNumWorkers(num_workers);
|
||||||
} else if (key == "sampler") {
|
} else if (key == "sampler") {
|
||||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||||
|
@ -926,12 +979,27 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
|
||||||
(void)builder->SetClassIndex(ToStringMap(value));
|
(void)builder->SetClassIndex(ToStringMap(value));
|
||||||
} else if (key == "decode") {
|
} else if (key == "decode") {
|
||||||
(void)builder->SetDecode(ToBool(value));
|
(void)builder->SetDecode(ToBool(value));
|
||||||
|
} else if (key == "cache") {
|
||||||
|
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::shared_ptr<ImageFolderOp> op;
|
std::shared_ptr<ImageFolderOp> if_op;
|
||||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
RETURN_IF_NOT_OK(builder->Build(&if_op));
|
||||||
*top = op;
|
RETURN_IF_NOT_OK(tree_->AssociateNode(if_op));
|
||||||
|
*top = if_op;
|
||||||
|
|
||||||
|
// Additionally, add a cache if required.
|
||||||
|
// Note that this cache op is only acting as a place holder for the caching position
|
||||||
|
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
|
||||||
|
// caching logic in the tree.
|
||||||
|
if (cache_client) {
|
||||||
|
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||||
|
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op));
|
||||||
|
*top = cache_op;
|
||||||
|
*bottom = if_op;
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1130,9 +1198,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
|
||||||
std::shared_ptr<DatasetOp> *bottom) {
|
std::shared_ptr<DatasetOp> *bottom) {
|
||||||
// Required arguments
|
// Required arguments
|
||||||
RandomDataOp::Builder builder;
|
RandomDataOp::Builder builder;
|
||||||
|
std::shared_ptr<CacheClient> cache_client = nullptr;
|
||||||
|
std::shared_ptr<Sampler> sampler = nullptr;
|
||||||
|
int num_workers = 0;
|
||||||
|
|
||||||
if (args["num_samples"].is_none()) {
|
if (args["total_rows"].is_none()) {
|
||||||
std::string err_msg = "Error: num_samples is a required argument";
|
std::string err_msg = "Error: total_rows is a required argument";
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
}
|
}
|
||||||
std::vector<std::string> columns_to_load;
|
std::vector<std::string> columns_to_load;
|
||||||
|
@ -1141,16 +1212,23 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
|
||||||
for (auto arg : args) {
|
for (auto arg : args) {
|
||||||
std::string key = py::str(arg.first);
|
std::string key = py::str(arg.first);
|
||||||
py::handle value = arg.second;
|
py::handle value = arg.second;
|
||||||
if (key == "num_parallel_workers") {
|
if (!value.is_none()) {
|
||||||
(void)builder.SetNumWorkers(ToInt(value));
|
if (key == "num_parallel_workers") {
|
||||||
} else if (key == "schema_file_path" || key == "schema_json_string") {
|
num_workers = ToInt(value);
|
||||||
schema_exists = true;
|
(void)builder.SetNumWorkers(num_workers);
|
||||||
} else if (key == "columns_list") {
|
} else if (key == "schema_file_path" || key == "schema_json_string") {
|
||||||
columns_to_load = ToStringVector(value);
|
schema_exists = true;
|
||||||
} else if (key == "num_samples") {
|
} else if (key == "columns_list") {
|
||||||
// This is not sampling here. The random data op needs to know how much data to
|
columns_to_load = ToStringVector(value);
|
||||||
// generate. It does not currently support sampling.
|
} else if (key == "total_rows") {
|
||||||
(void)builder.SetTotalRows(ToInt(value));
|
// This is not sampling here. The random data op needs to know how much data to generate.
|
||||||
|
(void)builder.SetTotalRows(ToInt(value));
|
||||||
|
} else if (key == "cache") {
|
||||||
|
cache_client = value.cast<std::shared_ptr<CacheClient>>();
|
||||||
|
} else if (key == "sampler") {
|
||||||
|
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||||
|
sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (schema_exists) {
|
if (schema_exists) {
|
||||||
|
@ -1162,9 +1240,34 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
|
||||||
}
|
}
|
||||||
(void)builder.SetDataSchema(std::move(schema));
|
(void)builder.SetDataSchema(std::move(schema));
|
||||||
}
|
}
|
||||||
std::shared_ptr<RandomDataOp> op;
|
|
||||||
RETURN_IF_NOT_OK(builder.Build(&op));
|
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
|
||||||
*top = op;
|
// because RandomDataOp is a non-mappable dataset that does not support sampling.
|
||||||
|
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||||
|
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||||
|
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||||
|
if (sampler) {
|
||||||
|
(void)builder.SetSampler(std::move(sampler));
|
||||||
|
} else if (cache_client) {
|
||||||
|
int64_t num_samples = 0;
|
||||||
|
int64_t start_index = 0;
|
||||||
|
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||||
|
(void)builder.SetSampler(std::move(sampler));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<RandomDataOp> random_op = nullptr;
|
||||||
|
RETURN_IF_NOT_OK(builder.Build(&random_op));
|
||||||
|
RETURN_IF_NOT_OK(tree_->AssociateNode(random_op));
|
||||||
|
*top = random_op;
|
||||||
|
|
||||||
|
// Add a cache op over this op if required and update the output subtree (top/bottom)
|
||||||
|
if (cache_client) {
|
||||||
|
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||||
|
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op));
|
||||||
|
*top = cache_op;
|
||||||
|
*bottom = random_op;
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1425,6 +1528,31 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to inject the cache operator over top of the current operation being built.
|
||||||
|
Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers,
|
||||||
|
std::shared_ptr<DatasetOp> input_op, std::shared_ptr<DatasetOp> *cache_op) {
|
||||||
|
std::shared_ptr<CacheOp> new_cache_op = nullptr;
|
||||||
|
CacheOp::Builder cache_builder;
|
||||||
|
// use the same number of workers as the leaf. We need some optimization here, the user does not
|
||||||
|
// give the cache op number of workers directly.
|
||||||
|
if (num_workers != 0) {
|
||||||
|
(void)cache_builder.SetNumWorkers(num_workers);
|
||||||
|
}
|
||||||
|
(void)cache_builder.SetClient(cache_client);
|
||||||
|
RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op));
|
||||||
|
RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op));
|
||||||
|
RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op));
|
||||||
|
// We have now created:
|
||||||
|
//
|
||||||
|
// CacheOp
|
||||||
|
// |
|
||||||
|
// input_op
|
||||||
|
//
|
||||||
|
*cache_op = new_cache_op;
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to inject a shuffle operator over top of the current operation being built.
|
// Helper function to inject a shuffle operator over top of the current operation being built.
|
||||||
Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
|
Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
|
||||||
std::shared_ptr<DatasetOp> *shuffle_op) {
|
std::shared_ptr<DatasetOp> *shuffle_op) {
|
||||||
|
|
|
@ -35,6 +35,8 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
using DsOpPtr = std::shared_ptr<DatasetOp>;
|
using DsOpPtr = std::shared_ptr<DatasetOp>;
|
||||||
|
|
||||||
|
class CacheClient;
|
||||||
|
|
||||||
// enum for the dataset operator names
|
// enum for the dataset operator names
|
||||||
enum OpName {
|
enum OpName {
|
||||||
kShuffle,
|
kShuffle,
|
||||||
|
@ -181,6 +183,16 @@ class DEPipeline {
|
||||||
|
|
||||||
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
|
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
|
||||||
|
|
||||||
|
/// \brief Helper function to inject a cache operator over top of the current operation being built.
|
||||||
|
/// \param[in] cache_client The client to use for caching
|
||||||
|
/// \param[in] num_workers The number of workers to use in the cache op
|
||||||
|
/// \param[in] input_op The operator to build the cache on top of
|
||||||
|
/// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
|
||||||
|
/// the cache operator
|
||||||
|
/// \return Status return code
|
||||||
|
Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
|
||||||
|
std::shared_ptr<DatasetOp> *cache_op);
|
||||||
|
|
||||||
/// \brief Helper function to inject a shuffle operator over top of the current operation being built.
|
/// \brief Helper function to inject a shuffle operator over top of the current operation being built.
|
||||||
/// \param[in] shuffle_size The size to use in the shuffle buffer
|
/// \param[in] shuffle_size The size to use in the shuffle buffer
|
||||||
/// \param[in] input_op The operator to build shuffle on top of
|
/// \param[in] input_op The operator to build shuffle on top of
|
||||||
|
|
|
@ -35,6 +35,7 @@
|
||||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||||
|
#include "dataset/engine/cache/cache_client.h"
|
||||||
#include "dataset/engine/gnn/graph.h"
|
#include "dataset/engine/gnn/graph.h"
|
||||||
#include "dataset/engine/jagged_connector.h"
|
#include "dataset/engine/jagged_connector.h"
|
||||||
#include "dataset/kernels/data/concatenate_op.h"
|
#include "dataset/kernels/data/concatenate_op.h"
|
||||||
|
@ -768,6 +769,11 @@ void bindInfoObjects(py::module *m) {
|
||||||
.def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
|
.def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expose CacheClient to Python as "CacheClient", held by std::shared_ptr, with a single
// constructor taking (uint32_t, uint64_t, bool).
void bindCacheClient(py::module *m) {
  py::class_<CacheClient, std::shared_ptr<CacheClient>> cache_client_binding(*m, "CacheClient");
  (void)cache_client_binding.def(py::init<uint32_t, uint64_t, bool>());
}
|
||||||
|
|
||||||
void bindVocabObjects(py::module *m) {
|
void bindVocabObjects(py::module *m) {
|
||||||
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
|
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
|
@ -939,6 +945,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
||||||
bindSamplerOps(&m);
|
bindSamplerOps(&m);
|
||||||
bindDatasetOps(&m);
|
bindDatasetOps(&m);
|
||||||
bindInfoObjects(&m);
|
bindInfoObjects(&m);
|
||||||
|
bindCacheClient(&m);
|
||||||
bindVocabObjects(&m);
|
bindVocabObjects(&m);
|
||||||
bindGraphData(&m);
|
bindGraphData(&m);
|
||||||
bindDependIcuTokenizerOps(&m);
|
bindDependIcuTokenizerOps(&m);
|
||||||
|
|
|
@ -2,6 +2,7 @@ add_subdirectory(datasetops)
|
||||||
add_subdirectory(opt)
|
add_subdirectory(opt)
|
||||||
add_subdirectory(gnn)
|
add_subdirectory(gnn)
|
||||||
add_subdirectory(perf)
|
add_subdirectory(perf)
|
||||||
|
add_subdirectory(cache)
|
||||||
if (ENABLE_TDTQUE)
|
if (ENABLE_TDTQUE)
|
||||||
add_subdirectory(tdt)
|
add_subdirectory(tdt)
|
||||||
endif ()
|
endif ()
|
||||||
|
@ -17,7 +18,9 @@ add_library(engine OBJECT
|
||||||
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
|
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||||
|
|
||||||
if (ENABLE_TDTQUE)
|
if (ENABLE_TDTQUE)
|
||||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf)
|
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
|
||||||
else()
|
engine-cache-client engine-cache-server)
|
||||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf)
|
else ()
|
||||||
|
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
|
||||||
|
engine-cache-client engine-cache-server)
|
||||||
endif ()
|
endif ()
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
# Gather every .cc file below this directory (paths relative to it) and define
# SUBMODULE_ID=mindspore::SubModuleId::SM_MD when compiling each of them.
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_source_files_properties(${_CURRENT_SRC_FILES}
    PROPERTIES COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

# Object library holding the cache client sources.
add_library(engine-cache-client OBJECT
    cache_client.cc
    cache_request.cc)

# Object library holding the cache server sources.
add_library(engine-cache-server OBJECT
    cache_service.cc
    cache_server.cc)
|
|
@ -0,0 +1,208 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <iomanip>
|
||||||
|
#include "dataset/engine/cache/cache_client.h"
|
||||||
|
#include "dataset/engine/cache/cache_request.h"
|
||||||
|
#include "dataset/util/bit.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// Constructor
|
||||||
|
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
|
||||||
|
: server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
|
||||||
|
|
||||||
|
// print method for display cache details
|
||||||
|
void CacheClient::Print(std::ostream &out) const {
|
||||||
|
out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_
|
||||||
|
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_
|
||||||
|
<< "\n Spilling: " << std::boolalpha << spill_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a single row to the cache server and wait for the request to complete.
// If row_id_from_server is not null, it receives the row id reported by the request
// once the row has been cached.
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
  CacheRowRequest request(server_connection_id_, cookie());
  RETURN_IF_NOT_OK(request.SerializeCacheRowRequest(row));

  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&request));
  RETURN_IF_NOT_OK(request.Wait());

  // The caller may not care about the row id, in which case there is nothing more to do.
  if (row_id_from_server == nullptr) {
    return Status::OK();
  }
  *row_id_from_server = request.GetRowIdAfterCache();
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Send every row of the given DataBuffer to the cache server.
// The buffer is taken over by this function. It is broken down into TensorRows, one cache
// request is pushed per row (asynchronously), and then all pushed requests are waited on.
// Returns the first error seen while popping, serializing, pushing or waiting, and an
// error if the buffer pointer is null. A buffer without rows is a no-op.
Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
  std::unique_ptr<DataBuffer> db_ptr = std::move(in);
  // Reject a null buffer rather than dereferencing it below.
  if (db_ptr == nullptr) {
    RETURN_STATUS_UNEXPECTED("Attempt to write a null DataBuffer to the cache.");
  }
  auto num_rows = db_ptr->NumRows();
  if (num_rows <= 0) {
    return Status::OK();
  }

  // The rows are parked here because a row must not go out of scope while its request is in
  // flight; otherwise the tensor memory would be freed. They are released when this function
  // returns. Declared before rq_arr so the requests are destroyed first.
  std::vector<TensorRow> all_rows;
  all_rows.reserve(num_rows);

  MemGuard<CacheRowRequest> rq_arr;
  RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
  CacheServer &cs = CacheServer::GetInstance();

  // Push phase. Stop at the first failure, but do NOT return yet: requests that were already
  // pushed may still be referencing rq_arr and all_rows, so they must be waited on first.
  Status rc = Status::OK();
  decltype(num_rows) num_pushed = 0;
  for (decltype(num_rows) i = 0; i < num_rows; ++i) {
    TensorRow row;
    auto rq = rq_arr[i];
    rc = db_ptr->PopRow(&row);
    if (rc.IsOk()) {
      rc = rq->SerializeCacheRowRequest(row);
    }
    if (rc.IsOk()) {
      rc = cs.PushRequest(rq);
    }
    if (rc.IsError()) {
      break;
    }
    all_rows.push_back(std::move(row));
    ++num_pushed;
  }

  // Wait phase. Wait for every request that was pushed, even after an error, and keep the
  // first error encountered.
  for (decltype(num_rows) i = 0; i < num_pushed; ++i) {
    Status wait_rc = rq_arr[i]->Wait();
    if (rc.IsOk() && wait_rc.IsError()) {
      rc = wait_rc;
    }
  }
  return rc;
}
|
||||||
|
|
||||||
|
// Fetch the rows with the given ids from the cache server in one batch request and
// restore them into *out. Returns an error if out is null.
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
  RETURN_UNEXPECTED_IF_NULL(out);

  BatchFetchRequest fetch_request(server_connection_id_, row_id);
  CacheServer &server = CacheServer::GetInstance();
  RETURN_IF_NOT_OK(server.PushRequest(&fetch_request));
  RETURN_IF_NOT_OK(fetch_request.Wait());

  // The request has completed; turn its result back into tensor rows for the caller.
  RETURN_IF_NOT_OK(fetch_request.RestoreRows(out));
  return Status::OK();
}
|
||||||
|
|
||||||
|
Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
|
||||||
|
UniqueLock lck(&mux_);
|
||||||
|
// To create a cache, we identify ourselves at the client by:
|
||||||
|
// - the shared session id
|
||||||
|
// - a crc for the tree nodes from the cache downward
|
||||||
|
// Pack these 2 into a single 64 bit request id
|
||||||
|
//
|
||||||
|
// Consider this example:
|
||||||
|
// tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
|
||||||
|
// tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch
|
||||||
|
// These are different trees in a single session, but the user wants to share the cache.
|
||||||
|
// This is not allowed because the data of these caches are different.
|
||||||
|
//
|
||||||
|
// Consider this example:
|
||||||
|
// tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
|
||||||
|
// tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch
|
||||||
|
// These are different trees in the same session, but the cached data is the same, so it is okay
|
||||||
|
// to allow the sharing of this cache between these pipelines.
|
||||||
|
|
||||||
|
// The CRC is computed by the tree prepare phase and passed to this function when creating the cache.
|
||||||
|
// If we already have a server_connection_id_, then it means this same cache client has already been used
|
||||||
|
// to create a cache and some other tree is trying to use the same cache.
|
||||||
|
// That is allowed, however the crc better match!
|
||||||
|
if (server_connection_id_) {
|
||||||
|
if (cache_crc_ != tree_crc) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
|
||||||
|
}
|
||||||
|
// Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
|
||||||
|
// skip the build phase.
|
||||||
|
lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock.
|
||||||
|
CacheClient::ServiceStat stat{};
|
||||||
|
RETURN_IF_NOT_OK(GetStat(&stat));
|
||||||
|
if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
|
||||||
|
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client
|
||||||
|
// Combine the session and crc. This will form our client cache identifier.
|
||||||
|
connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
|
||||||
|
// Now execute the cache create request using this identifier and other configs
|
||||||
|
BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
|
||||||
|
if (spill_) {
|
||||||
|
createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
|
||||||
|
}
|
||||||
|
if (generate_id) {
|
||||||
|
createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
|
||||||
|
}
|
||||||
|
CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
|
||||||
|
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
|
||||||
|
Status rc = rq.Wait();
|
||||||
|
if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
|
||||||
|
server_connection_id_ = rq.GetServerConnectionId();
|
||||||
|
if (rc.IsOk()) {
|
||||||
|
// The 1st guy creating the cache will get a cookie back.
|
||||||
|
// But this object may be shared among pipelines and we don't want
|
||||||
|
// overwrite it.
|
||||||
|
cookie_ = rq.cookie();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
|
||||||
|
// CacheOp to bypass the build phase.
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a purge request for the cache this client is connected to.
// Takes a UniqueLock on mux_ for the duration of the call.
// \return Error from queuing the request, otherwise the status returned by waiting on it.
Status CacheClient::PurgeCache() {
  UniqueLock lck(&mux_);
  PurgeCacheRequest purge_rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&purge_rq));
  Status rc = purge_rq.Wait();
  return rc;
}
|
||||||
|
|
||||||
|
// Send a destroy request for the cache this client is connected to.
// Takes a UniqueLock on mux_ for the duration of the call.
// \return Error from queuing the request, otherwise the status returned by waiting on it.
Status CacheClient::DestroyCache() {
  UniqueLock lck(&mux_);
  DestroyCacheRequest destroy_rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&destroy_rq));
  Status rc = destroy_rq.Wait();
  return rc;
}
|
||||||
|
|
||||||
|
// Query the cache server for the statistics of the current connection.
// \param[out] stat Pre-allocated structure that receives the values; must not be null
// \return Status object
Status CacheClient::GetStat(ServiceStat *stat) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(stat);
  GetStatRequest stat_rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&stat_rq));
  RETURN_IF_NOT_OK(stat_rq.Wait());
  // Copy every field of the reply into the caller's structure.
  stat->num_mem_cached = stat_rq.GetNumMemCached();
  stat->num_disk_cached = stat_rq.GetNumDiskCached();
  stat->min_row_id = stat_rq.GetMinRowId();
  stat->max_row_id = stat_rq.GetMaxRowId();
  stat->cache_service_state = stat_rq.GetState();
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Serialize a column-name-to-id map and send it to the cache server.
// \param map Schema to cache (column name -> column id)
// \return Status object
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
  SharedLock lck(&mux_);
  CacheSchemaRequest schema_rq(server_connection_id_);
  // The request owns the flat buffer that holds the serialized schema.
  RETURN_IF_NOT_OK(schema_rq.SerializeCacheSchemaRequest(map));
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&schema_rq));
  return schema_rq.Wait();
}
|
||||||
|
|
||||||
|
// Fetch the schema stored at the cache server.
// \param[out] map Pre-allocated map that is overwritten with the fetched schema; must not be null
// \return Status object
Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(map);
  FetchSchemaRequest fetch_rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&fetch_rq));
  RETURN_IF_NOT_OK(fetch_rq.Wait());
  // The request decodes the reply into a name -> id map.
  *map = fetch_rq.GetColumnMap();
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Tell the cache server that the build phase is finished. The request carries this client's cookie.
// \return Status object
Status CacheClient::BuildPhaseDone() const {
  SharedLock lck(&mux_);
  BuildPhaseDoneRequest done_rq(server_connection_id_, cookie());
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&done_rq));
  return done_rq.Wait();
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,141 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_ENGINE_CACHE_CLIENT_H_
|
||||||
|
#define DATASET_ENGINE_CACHE_CLIENT_H_
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "./de_tensor_generated.h"
|
||||||
|
#include "dataset/engine/data_buffer.h"
|
||||||
|
#include "dataset/engine/cache/cache_server.h"
|
||||||
|
#include "dataset/util/lock.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
|
||||||
|
/// a CacheClient. Typical tasks including like creating a cache service, cache a data buffer, restore a previously
|
||||||
|
/// rows, etc.
|
||||||
|
class CacheClient {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param session_id A user assigned session id for the current pipeline
|
||||||
|
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
|
||||||
|
/// \param spill Spill to disk if out of memory
|
||||||
|
CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
|
||||||
|
|
||||||
|
/// \brief Destructor
|
||||||
|
~CacheClient() = default;
|
||||||
|
|
||||||
|
/// \brief Getter function for returning the current session id
|
||||||
|
/// \return session id
|
||||||
|
uint64_t session_id() const { return session_id_; }
|
||||||
|
|
||||||
|
/// \brief Send a TensorRow to the cache server
|
||||||
|
/// \param[in] row
|
||||||
|
/// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
|
||||||
|
/// \return return code
|
||||||
|
Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;
|
||||||
|
|
||||||
|
/// \brief Send a DataBuffer to the cache server
|
||||||
|
/// \param in Unique pointer of the DataBuffer to be cached
|
||||||
|
/// \return return code
|
||||||
|
Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const;
|
||||||
|
|
||||||
|
/// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
|
||||||
|
/// any cache miss
|
||||||
|
/// \param row_id A vector of row id's
|
||||||
|
/// \param out A TensorTable of TensorRows.
|
||||||
|
/// \return return code
|
||||||
|
Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;
|
||||||
|
|
||||||
|
/// \brief Create a cache.
|
||||||
|
/// \param tree_crc A crc that was generated during tree prepare phase
|
||||||
|
/// \param generate_id Let the cache service generate row id
|
||||||
|
/// \return Status object
|
||||||
|
Status CreateCache(uint32_t tree_crc, bool generate_id);
|
||||||
|
|
||||||
|
/// \brief Purge a cache. Cache can be reused after reset.
|
||||||
|
/// \return Status object
|
||||||
|
Status PurgeCache();
|
||||||
|
|
||||||
|
/// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
|
||||||
|
/// \return Status object
|
||||||
|
Status DestroyCache();
|
||||||
|
|
||||||
|
/// \brief Get the statistics from a cache.
|
||||||
|
/// \param[in/out] Pointer to a pre-allocated ServiceStat object
|
||||||
|
/// \return Status object
|
||||||
|
struct ServiceStat {
|
||||||
|
int64_t num_mem_cached;
|
||||||
|
int64_t num_disk_cached;
|
||||||
|
row_id_type min_row_id;
|
||||||
|
row_id_type max_row_id;
|
||||||
|
int8_t cache_service_state;
|
||||||
|
};
|
||||||
|
Status GetStat(ServiceStat *);
|
||||||
|
|
||||||
|
/// \brief Cache the schema at the cache server
|
||||||
|
/// \param map The unordered map of the schema
|
||||||
|
/// \return Status object
|
||||||
|
Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);
|
||||||
|
|
||||||
|
/// \brief Fetch the schema from the cache server
|
||||||
|
/// \param map Pointer to pre-allocated map object
|
||||||
|
/// \return Status object.
|
||||||
|
Status FetchSchema(std::unordered_map<std::string, int32_t> *map);
|
||||||
|
|
||||||
|
/// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
|
||||||
|
/// client that holds cookie can be allowed to make this request
|
||||||
|
/// \return Status object
|
||||||
|
Status BuildPhaseDone() const;
|
||||||
|
|
||||||
|
/// \brief A print method typically used for debugging
|
||||||
|
/// \param out The output stream to write output to
|
||||||
|
void Print(std::ostream &out) const;
|
||||||
|
|
||||||
|
/// \brief Stream output operator overload
|
||||||
|
/// \return the output stream must be returned
|
||||||
|
friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
|
||||||
|
cc.Print(out);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it.
|
||||||
|
/// \return Cookie
|
||||||
|
std::string cookie() const { return cookie_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutable RWLock mux_;
|
||||||
|
uint64_t cache_mem_sz_;
|
||||||
|
bool spill_;
|
||||||
|
// The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
|
||||||
|
// sharing of the cache.
|
||||||
|
uint32_t session_id_;
|
||||||
|
uint32_t cache_crc_;
|
||||||
|
// The server_connection_id_ is the actual id we use for operations after the cache is built
|
||||||
|
connection_id_type server_connection_id_;
|
||||||
|
// Some magic cookie returned from the cache server.
|
||||||
|
std::string cookie_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_CACHE_CLIENT_H_
|
|
@ -0,0 +1,223 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "dataset/engine/cache/cache_request.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// Prepare a TensorRow for streaming: build the flat buffer header, then record the address of the
// header followed by the address of each tensor's buffer in buffers_.
// Only pointers are stored here; no tensor data is copied by this function.
// \param row TensorRow to serialize
// \return Status object
Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
  // One slot for the header plus one per column.
  buffers_.reserve(row.size() + 1);
  RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
  buffers_.push_back(fbb_->GetBufferPointer());
  for (const auto &tensor : row) {
    buffers_.push_back(tensor->GetBuffer());
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Build the TensorRowHeaderMsg flat buffer for a row: per-column meta data, per-column byte sizes,
// the row id, and the size of the header itself.
// \param row TensorRow to describe
// \return Status object; kOutOfMemory if an allocation throws std::bad_alloc
Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
  try {
    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
    const auto num_columns = row.size();
    std::vector<flatbuffers::Offset<TensorMetaMsg>> column_meta;
    std::vector<int64_t> column_bytes;
    column_meta.reserve(num_columns);
    column_bytes.reserve(num_columns);
    // Visit every column of the row and collect its meta data and size.
    for (const std::shared_ptr<Tensor> &tensor : row) {
      flatbuffers::Offset<TensorMetaMsg> meta_off;
      RETURN_IF_NOT_OK(SerializeOneTensorMeta(tensor, &meta_off));
      column_meta.push_back(meta_off);
      column_bytes.push_back(tensor->SizeInBytes());
    }
    auto column_off = fbb_->CreateVector(column_meta);
    auto data_sz_off = fbb_->CreateVector(column_bytes);
    TensorRowHeaderMsgBuilder header(*fbb_);
    header.add_column(column_off);
    header.add_data_sz(data_sz_off);
    // The row id is passed along even if it may not be known.
    header.add_row_id(row.getId());
    // Placeholder; the real size is only known after Finish and is patched in below.
    header.add_size_of_this(-1);
    fbb_->Finish(header.Finish());
    // Patch size_of_this in place inside the finished buffer.
    auto *msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
    if (!msg->mutate_size_of_this(fbb_->GetSize())) {
      RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
    }
    return Status::OK();
  } catch (const std::bad_alloc &) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  }
}
|
||||||
|
|
||||||
|
// Serialize the meta data (shape and element type) of one tensor into the flat buffer builder fbb_.
// \param ts_ptr  Tensor to describe; must be non-null and must have a data buffer
// \param out_off Receives the flat buffer offset of the new TensorMetaMsg; must not be null
// \return Status object; an unexpected error for a null tensor, a null buffer or an unknown type
Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
                                               flatbuffers::Offset<TensorMetaMsg> *out_off) {
  RETURN_UNEXPECTED_IF_NULL(out_off);
  // Reject a null tensor up front instead of dereferencing it below.
  if (ts_ptr == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor is null");
  }
  const Tensor *ts = ts_ptr.get();
  if (ts->GetBuffer() == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
  }
  auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
  auto src = ts->type().value();
  TensorType dest;
#define CASE(t)                        \
  case DataType::t:                    \
    dest = TensorType::TensorType_##t; \
    break
  // Map the type to fill in the flat buffer.
  switch (src) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
    default:
      MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
      RETURN_STATUS_UNEXPECTED("Unknown type");
  }
#undef CASE

  // The shape vector must be created before the table builder is started (done above).
  TensorMetaMsgBuilder ts_builder(*fbb_);
  ts_builder.add_dims(shape_off);
  ts_builder.add_type(dest);
  *out_off = ts_builder.Finish();
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Rebuild one Tensor from its serialized meta data and the slice that holds its raw bytes.
// \param col_ts Meta data (shape and type) of the tensor; must not be null
// \param data   Slice covering the tensor's bytes
// \param out    Receives the restored tensor; must not be null
// \return Status object; an unexpected error on null arguments, missing dims, or when the size the
//         tensor reports differs from the slice size
Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
                                           std::shared_ptr<Tensor> *out) {
  RETURN_UNEXPECTED_IF_NULL(col_ts);
  RETURN_UNEXPECTED_IF_NULL(out);
  auto shape_in = col_ts->dims();
  // Flat buffer accessors return null for an absent field; fail cleanly instead of crashing.
  if (shape_in == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor meta data has no dims");
  }
  auto type_in = col_ts->type();
  std::vector<dsize_t> v(shape_in->begin(), shape_in->end());
  TensorShape shape(v);
  DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t)               \
  case TensorType_##t:        \
    dest = DataType::Type::t; \
    break

  // Map the flat buffer type back to the DataType. Anything unlisted stays DE_UNKNOWN.
  switch (type_in) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
    default:
      break;
  }
#undef CASE

  DataType type(dest);
  std::shared_ptr<Tensor> ts =
    std::make_shared<Tensor>(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize());
  // Sanity check: the size the tensor reports must equal the number of bytes in the slice.
  if (ts->SizeInBytes() != data.GetSize()) {
    MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
                  << "Dumping tensor\n"
                  << *ts << "\n";
    RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
  }
  *out = std::move(ts);
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Rebuild the requested rows from the memory block attached to this request.
// The block is read as an array of int64_t offsets (one more than the number of rows); entries i and
// i + 1 delimit row i. A row whose length is not positive is returned as an empty TensorRow that only
// carries its row id.
// \param out Receives the restored table; must not be null
// \return Status object
Status BatchFetchRequest::RestoreRows(TensorTable *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  const auto num_rows = row_id_.size();
  auto *offsets = reinterpret_cast<const int64_t *>(mem_.GetPointer());
  TensorTable restored;
  restored.reserve(num_rows);
  ReadableSlice whole(mem_.GetPointer(), mem_.GetSizeInBytes());
  for (size_t i = 0; i < num_rows; ++i) {
    auto row_len = offsets[i + 1] - offsets[i];
    TensorRow row;
    row.setId(row_id_.at(i));
    if (row_len > 0) {
      ReadableSlice row_data(whole, offsets[i], row_len);
      // De-serialize the flat buffer header to get back the description of each column.
      auto header = GetTensorRowHeaderMsg(row_data.GetPointer());
      auto columns = header->column();
      auto sizes = header->data_sz();
      // The tensor data starts right after the header.
      auto ts_offset = header->size_of_this();
      row.reserve(columns->size());
      for (flatbuffers::uoffset_t k = 0; k < columns->size(); ++k) {
        std::shared_ptr<Tensor> tensor;
        ReadableSlice tensor_data(row_data, ts_offset, sizes->Get(k));
        RETURN_IF_NOT_OK(RestoreOneTensor(columns->Get(k), tensor_data, &tensor));
        row.push_back(tensor);
        ts_offset += tensor_data.GetSize();
      }
    }
    restored.push_back(std::move(row));
  }
  *out = std::move(restored);
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Serialize a column-name-to-id map into a SchemaMsg flat buffer owned by this request.
// On success buf_ points at the finished buffer and len_of_buf_ holds its size.
// \param map Schema to serialize (column name -> column id)
// \return Status object; kOutOfMemory if an allocation throws std::bad_alloc
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
  try {
    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
    std::vector<flatbuffers::Offset<ColumnNameMsg>> columns;
    columns.reserve(map.size());
    // One ColumnNameMsg per (name, id) pair.
    for (const auto &name_and_id : map) {
      columns.push_back(CreateColumnNameMsg(*fbb_, fbb_->CreateString(name_and_id.first), name_and_id.second));
    }
    auto columns_off = fbb_->CreateVector(columns);
    fbb_->Finish(CreateSchemaMsg(*fbb_, columns_off));
    buf_ = fbb_->GetBufferPointer();
    len_of_buf_ = fbb_->GetSize();
    return Status::OK();
  } catch (const std::bad_alloc &) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  }
}
|
||||||
|
|
||||||
|
std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
|
||||||
|
if (column_name_id_map_.empty()) {
|
||||||
|
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
|
||||||
|
auto v = map_msg->column();
|
||||||
|
for (auto i = 0; i < v->size(); ++i) {
|
||||||
|
auto col = map_msg->column()->Get(i);
|
||||||
|
column_name_id_map_.emplace(col->name()->str(), col->id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return column_name_id_map_;
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,225 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_ENGINE_CACHE_REQ_H_
|
||||||
|
#define DATASET_ENGINE_CACHE_REQ_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "./de_tensor_generated.h"
|
||||||
|
#include "dataset/core/tensor_row.h"
|
||||||
|
#include "dataset/util/slice.h"
|
||||||
|
#include "dataset/util/wait_post.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// \brief CacheClient communicates with CacheServer using Requests.
|
||||||
|
class BaseRequest {
|
||||||
|
public:
|
||||||
|
// Request types
|
||||||
|
enum class RequestType : int16_t {
|
||||||
|
kCacheRow = 0,
|
||||||
|
kBatchFetchRows = 1,
|
||||||
|
kCreateCache = 2,
|
||||||
|
kPurgeCache = 3,
|
||||||
|
kDestroyCache = 4,
|
||||||
|
kGetStat = 5,
|
||||||
|
kCacheSchema = 6,
|
||||||
|
kFetchSchema = 7,
|
||||||
|
kBuildPhaseDone = 8,
|
||||||
|
// Add new request before it.
|
||||||
|
kRequestUnknown = 32767
|
||||||
|
};
|
||||||
|
// For kCreateCache
|
||||||
|
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };
|
||||||
|
friend class CacheServer;
|
||||||
|
/// \brief Base class of a cache server request
|
||||||
|
/// \param connection_id A combination of session id and crc that uniquely identifies a connection.
|
||||||
|
/// \param type Type of the request
|
||||||
|
explicit BaseRequest(connection_id_type connection_id, RequestType type)
|
||||||
|
: type_(type), connection_id_(connection_id) {}
|
||||||
|
virtual ~BaseRequest() = default;
|
||||||
|
/// \brief Wait for the completion of a request
|
||||||
|
/// \return Status returned from the cache server
|
||||||
|
Status Wait() {
|
||||||
|
RETURN_IF_NOT_OK(wp_.Wait());
|
||||||
|
return rc_;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Getter function of the current connection id
|
||||||
|
/// \return Connection id
|
||||||
|
connection_id_type GetServerConnectionId() const { return connection_id_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
RequestType type_;
|
||||||
|
connection_id_type connection_id_;
|
||||||
|
Status rc_;
|
||||||
|
WaitPost wp_;
|
||||||
|
};
|
||||||
|
/// \brief Request to cache a single TensorRow
|
||||||
|
class CacheRowRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie)
|
||||||
|
: BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {}
|
||||||
|
~CacheRowRequest() = default;
|
||||||
|
|
||||||
|
/// \brief Serialize a TensorRow for streaming to the cache server
|
||||||
|
/// \param row TensorRow
|
||||||
|
/// \return Status object
|
||||||
|
Status SerializeCacheRowRequest(const TensorRow &row);
|
||||||
|
/// \brief Return the row id assigned to this row for non-mappable dataset
|
||||||
|
/// \return row id of the cached row
|
||||||
|
row_id_type GetRowIdAfterCache() { return row_id_from_server_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
|
||||||
|
row_id_type row_id_from_server_;
|
||||||
|
std::vector<const void *> buffers_;
|
||||||
|
std::string cookie_;
|
||||||
|
|
||||||
|
/// \brief Private function to serialize one TensorRow
|
||||||
|
/// \param row TensorRow
|
||||||
|
/// \return Status object
|
||||||
|
Status SerializeTensorRowHeader(const TensorRow &row);
|
||||||
|
/// \brief Private function to serialize one Tensor
|
||||||
|
/// \param ts_ptr Tensor
|
||||||
|
/// \return Status object
|
||||||
|
Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
|
||||||
|
};
|
||||||
|
/// \brief Request to fetch rows in batch
|
||||||
|
class BatchFetchRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
friend class CacheService;
|
||||||
|
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id)
|
||||||
|
: BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {}
|
||||||
|
Status RestoreRows(TensorTable *out);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<row_id_type> row_id_;
|
||||||
|
MemGuard<uint8_t> mem_;
|
||||||
|
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
|
||||||
|
};
|
||||||
|
/// \brief Request to create a cache for the current connection
|
||||||
|
class CreationCacheRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param connection_id
|
||||||
|
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
|
||||||
|
/// \param flag Attributes of the cache.
|
||||||
|
explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz,
|
||||||
|
CreateCacheFlag flag = CreateCacheFlag::kNone)
|
||||||
|
: BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {}
|
||||||
|
|
||||||
|
std::string cookie() const { return cookie_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
uint64_t cache_mem_sz;
|
||||||
|
CreateCacheFlag flag_;
|
||||||
|
std::string cookie_;
|
||||||
|
};
|
||||||
|
/// \brief Request to purge a cache.
|
||||||
|
class PurgeCacheRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {}
|
||||||
|
};
|
||||||
|
/// \brief Request to destroy a cache
|
||||||
|
class DestroyCacheRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
explicit DestroyCacheRequest(connection_id_type connection_id)
|
||||||
|
: BaseRequest(connection_id, RequestType::kDestroyCache) {}
|
||||||
|
};
|
||||||
|
/// \brief Obtain the statistics of the current connection
|
||||||
|
class GetStatRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
friend class CacheService;
|
||||||
|
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {}
|
||||||
|
row_id_type GetMinRowId() const {
|
||||||
|
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||||
|
return msg->min_row_id();
|
||||||
|
}
|
||||||
|
row_id_type GetMaxRowId() const {
|
||||||
|
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||||
|
return msg->max_row_id();
|
||||||
|
}
|
||||||
|
int64_t GetNumMemCached() const {
|
||||||
|
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||||
|
return msg->num_mem_cached();
|
||||||
|
}
|
||||||
|
int64_t GetNumDiskCached() const {
|
||||||
|
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||||
|
return msg->num_disk_cached();
|
||||||
|
}
|
||||||
|
uint8_t GetState() const {
|
||||||
|
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
|
||||||
|
return msg->state();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
MemGuard<uint8_t> mem_;
|
||||||
|
};
|
||||||
|
/// \brief Request to cache a schema
|
||||||
|
class CacheSchemaRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
explicit CacheSchemaRequest(connection_id_type connection_id)
|
||||||
|
: BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {}
|
||||||
|
~CacheSchemaRequest() = default;
|
||||||
|
|
||||||
|
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
|
||||||
|
const void *GetBuffer() const { return buf_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
|
||||||
|
const void *buf_;
|
||||||
|
int64_t len_of_buf_;
|
||||||
|
};
|
||||||
|
/// \brief Request to fetch a schema
|
||||||
|
class FetchSchemaRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
explicit FetchSchemaRequest(connection_id_type connection_id)
|
||||||
|
: BaseRequest(connection_id, RequestType::kFetchSchema) {}
|
||||||
|
~FetchSchemaRequest() = default;
|
||||||
|
|
||||||
|
std::unordered_map<std::string, int32_t> GetColumnMap();
|
||||||
|
|
||||||
|
private:
|
||||||
|
MemGuard<uint8_t> mem_;
|
||||||
|
std::unordered_map<std::string, int32_t> column_name_id_map_;
|
||||||
|
};
|
||||||
|
/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
|
||||||
|
class BuildPhaseDoneRequest : public BaseRequest {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
|
||||||
|
: BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string cookie_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // DATASET_ENGINE_CACHE_REQ_H_
|
|
@ -0,0 +1,252 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "dataset/engine/cache/cache_server.h"
|
||||||
|
#include "dataset/engine/cache/cache_service.h"
|
||||||
|
#include "dataset/engine/cache/cache_request.h"
|
||||||
|
#include "dataset/util/bit.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// Bring the server up: create the spill directory when one is configured, start the
/// task group, create the request queue and launch num_workers_ worker threads.
Status CacheServer::DoServiceStart() {
  const bool has_spill_dir = !top_.empty();
  if (has_spill_dir) {
    Path spill(top_);
    RETURN_IF_NOT_OK(spill.CreateDirectories());
    MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
  }
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024);
  RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
  // Every worker runs the same request loop.
  auto worker = [this]() -> Status { return ServerRequest(); };
  // Spawn a few threads to serve the requests.
  for (auto remaining = num_workers_; remaining > 0; --remaining) {
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", worker));
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
/// Stop the worker threads, then stop and release every cache service.
/// \return The status of the last cache service that failed to stop, OK otherwise.
///         If the task group itself fails to stop, that error is returned immediately.
Status CacheServer::DoServiceStop() {
  // First stop all the threads.
  RETURN_IF_NOT_OK(vg_.ServiceStop());
  // Clean up all the caches if any.
  Status rc;
  UniqueLock lck(&rwLock_);
  for (auto &entry : all_caches_) {
    // Take ownership so the service is destroyed at the end of this iteration.
    auto cs = std::move(entry.second);
    if (cs == nullptr) {
      continue;
    }
    Status rc2 = cs->ServiceStop();
    if (rc2.IsError()) {
      rc = rc2;
    }
  }
  // Drop the (now empty) entries. Leaving them behind would make a later CreateService
  // with the same connection id report a duplicate for a cache that no longer exists.
  all_caches_.clear();
  return rc;
}
|
||||||
|
|
||||||
|
/// Look up a cache service by connection id while holding the shared lock.
/// \return Raw pointer owned by all_caches_, or nullptr when the id is unknown.
CacheService *CacheServer::GetService(connection_id_type id) const {
  SharedLock lck(&rwLock_);
  auto found = all_caches_.find(id);
  return (found == all_caches_.end()) ? nullptr : found->second.get();
}
|
||||||
|
|
||||||
|
/// Create (and start) the cache service for a connection id.
/// \param[in] connection_id Key under which the service is registered
/// \param[in] cache_mem_sz Memory size forwarded to the CacheService constructor
/// \param[in] flag Bit flags; kSpillToDisk and kGenerateRowId are inspected here
/// \param[out] out_cookie Set to the cookie of the new service; left empty otherwise
/// \return OK on creation, kDuplicateKey if the id already exists, kOutOfMemory on bad_alloc
Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz,
                                  BaseRequest::CreateCacheFlag flag, std::string *out_cookie) {
  auto is_set = [flag](BaseRequest::CreateCacheFlag bit) { return (flag & bit) == bit; };
  const bool spill = is_set(BaseRequest::CreateCacheFlag::kSpillToDisk);
  const bool generate_id = is_set(BaseRequest::CreateCacheFlag::kGenerateRowId);
  // We can't do spilling unless this server is setup with a spill path in the first place
  if (spill && top_.empty()) {
    RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
  }
  RETURN_UNEXPECTED_IF_NULL(out_cookie);
  *out_cookie = "";
  // Two CreateService calls with the same connection_id are serialized by the exclusive lock.
  // Only the first one creates the service and receives the cookie.
  UniqueLock lck(&rwLock_);
  if (all_caches_.find(connection_id) != all_caches_.end()) {
    MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
    // Report a duplicate key so the caller can decide to either ignore it or treat it as OK.
    return Status(StatusCode::kDuplicateKey);
  }
  std::unique_ptr<CacheService> cs;
  try {
    cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
    RETURN_IF_NOT_OK(cs->ServiceStart());
    *out_cookie = cs->cookie();
    all_caches_.emplace(connection_id, std::move(cs));
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kOutOfMemory);
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
/// This is the main loop the cache server thread(s) are running.
|
||||||
|
/// Each thread will pop a request and save the result in the same request.
|
||||||
|
/// The sender will wait on the wait post in the request. Once the request
|
||||||
|
/// is fulfilled, the server thread will do a post signalling the request is
|
||||||
|
/// is processed.
|
||||||
|
/// \return
|
||||||
|
Status CacheServer::ServerRequest() {
|
||||||
|
TaskManager::FindMe()->Post();
|
||||||
|
// Loop forever until we are interrupted.
|
||||||
|
while (true) {
|
||||||
|
BaseRequest *base_rq = nullptr;
|
||||||
|
RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq));
|
||||||
|
auto cs = GetService(base_rq->connection_id_);
|
||||||
|
// Except for creating a new session, we expect cs is not null.
|
||||||
|
switch (base_rq->type_) {
|
||||||
|
case BaseRequest::RequestType::kCacheRow: {
|
||||||
|
if (cs == nullptr) {
|
||||||
|
std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
|
||||||
|
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||||
|
} else {
|
||||||
|
auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq);
|
||||||
|
// Only if the cookie matches, we can accept insert into this cache that has a build phase
|
||||||
|
if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) {
|
||||||
|
rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_);
|
||||||
|
} else {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BaseRequest::RequestType::kBatchFetchRows: {
|
||||||
|
if (cs == nullptr) {
|
||||||
|
std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
|
||||||
|
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||||
|
} else {
|
||||||
|
auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq);
|
||||||
|
rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BaseRequest::RequestType::kCreateCache: {
|
||||||
|
// If the cache is already created we still need to run the creation so that we do sanity checks on the
|
||||||
|
// client id and return the cache id back to the user.
|
||||||
|
auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq);
|
||||||
|
rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BaseRequest::RequestType::kPurgeCache: {
|
||||||
|
if (cs != nullptr) {
|
||||||
|
base_rq->rc_ = cs->Purge();
|
||||||
|
} else {
|
||||||
|
// it is already purged. Ignore it.
|
||||||
|
base_rq->rc_ = Status::OK();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BaseRequest::RequestType::kDestroyCache: {
|
||||||
|
if (cs != nullptr) {
|
||||||
|
// We need a strong lock to protect the map.
|
||||||
|
connection_id_type id = base_rq->connection_id_;
|
||||||
|
UniqueLock lck(&rwLock_);
|
||||||
|
// std::map will invoke the constructor of CacheService. So we don't need to do anything here.
|
||||||
|
auto n = all_caches_.erase(id);
|
||||||
|
if (n == 0) {
|
||||||
|
// It has been destroyed by another duplicate request.
|
||||||
|
MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
|
||||||
|
}
|
||||||
|
base_rq->rc_ = Status::OK();
|
||||||
|
} else {
|
||||||
|
// it is already destroyed. Ignore it.
|
||||||
|
base_rq->rc_ = Status::OK();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BaseRequest::RequestType::kGetStat: {
|
||||||
|
if (cs == nullptr) {
|
||||||
|
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
|
||||||
|
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||||
|
} else {
|
||||||
|
auto *rq = reinterpret_cast<GetStatRequest *>(base_rq);
|
||||||
|
CacheService::ServiceStat svc_stat;
|
||||||
|
rq->rc_ = cs->GetStat(&svc_stat);
|
||||||
|
if (rq->rc_.IsOk()) {
|
||||||
|
flatbuffers::FlatBufferBuilder fbb;
|
||||||
|
ServiceStatMsgBuilder bld(fbb);
|
||||||
|
bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
|
||||||
|
bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
|
||||||
|
bld.add_max_row_id(svc_stat.max_);
|
||||||
|
bld.add_min_row_id(svc_stat.min_);
|
||||||
|
bld.add_state(svc_stat.state_);
|
||||||
|
auto offset = bld.Finish();
|
||||||
|
fbb.Finish(offset);
|
||||||
|
rq->rc_ = rq->mem_.allocate(fbb.GetSize());
|
||||||
|
if (rq->rc_.IsOk()) {
|
||||||
|
WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize());
|
||||||
|
ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize());
|
||||||
|
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BaseRequest::RequestType::kCacheSchema: {
|
||||||
|
if (cs == nullptr) {
|
||||||
|
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
|
||||||
|
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||||
|
} else {
|
||||||
|
auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq);
|
||||||
|
rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BaseRequest::RequestType::kFetchSchema: {
|
||||||
|
if (cs == nullptr) {
|
||||||
|
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
|
||||||
|
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||||
|
} else {
|
||||||
|
auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq);
|
||||||
|
rq->rc_ = cs->FetchSchema(&rq->mem_);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BaseRequest::RequestType::kBuildPhaseDone: {
|
||||||
|
if (cs == nullptr) {
|
||||||
|
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
|
||||||
|
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||||
|
} else {
|
||||||
|
auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq);
|
||||||
|
// We can only allow to switch phase is the cookie match.
|
||||||
|
if (rq->cookie_ == cs->cookie()) {
|
||||||
|
rq->rc_ = cs->BuildPhaseDone();
|
||||||
|
} else {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type");
|
||||||
|
}
|
||||||
|
// Notify it is done, and move on to the next request.
|
||||||
|
base_rq->wp_.Set();
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
/// Remember the spill directory and the number of worker threads. Nothing is started
/// here; the work happens in DoServiceStart.
CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers)
    : top_(spill_path),
      num_workers_(num_workers) {}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,98 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_CACHE_SERVER_H_
|
||||||
|
#define DATASET_ENGINE_CACHE_SERVER_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <atomic>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include "dataset/engine/cache/cache_service.h"
|
||||||
|
#include "dataset/core/tensor.h"
|
||||||
|
#include "dataset/util/arena.h"
|
||||||
|
#include "dataset/util/cache_pool.h"
|
||||||
|
#include "dataset/util/lock.h"
|
||||||
|
#include "dataset/util/service.h"
|
||||||
|
#include "dataset/util/services.h"
|
||||||
|
#include "dataset/util/system_pool.h"
|
||||||
|
#include "dataset/util/queue.h"
|
||||||
|
#include "dataset/util/task_manager.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class BaseRequest;
|
||||||
|
/// \brief A server which provides CacheService services.
|
||||||
|
class CacheServer : public Service {
|
||||||
|
public:
|
||||||
|
friend class Services;
|
||||||
|
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
|
||||||
|
|
||||||
|
CacheServer(const CacheServer &) = delete;
|
||||||
|
CacheServer &operator=(const CacheServer &) = delete;
|
||||||
|
CacheServer(CacheServer &&) = delete;
|
||||||
|
CacheServer &operator=(CacheServer &) = delete;
|
||||||
|
static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
|
||||||
|
Status DoServiceStart() override;
|
||||||
|
Status DoServiceStop() override;
|
||||||
|
~CacheServer() { (void)ServiceStop(); }
|
||||||
|
|
||||||
|
/// \brief For the current demonstration, a cache client contacts cache server using a Queue.
|
||||||
|
/// \param rq
|
||||||
|
/// \return Status object
|
||||||
|
Status PushRequest(BaseRequest *rq) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(rq);
|
||||||
|
RETURN_IF_NOT_OK(cache_q_->Add(rq));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutable RWLock rwLock_;
|
||||||
|
std::string top_;
|
||||||
|
cache_index all_caches_;
|
||||||
|
std::shared_ptr<Queue<BaseRequest *>> cache_q_;
|
||||||
|
TaskGroup vg_;
|
||||||
|
int32_t num_workers_;
|
||||||
|
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param spill_path Top directory for spilling buffers to.
|
||||||
|
/// \param num_workers Number of threads for handling requests.
|
||||||
|
explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3);
|
||||||
|
|
||||||
|
/// \brief Locate a cache service from connection id.
|
||||||
|
/// \return Pointer to cache service. Null if not found
|
||||||
|
CacheService *GetService(connection_id_type id) const;
|
||||||
|
|
||||||
|
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
|
||||||
|
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
|
||||||
|
/// a special unique cookie.
|
||||||
|
/// \param[in] connection_id This is from a Cache client.
|
||||||
|
/// \param[in] cache_mem_sz
|
||||||
|
/// \param[in] flag
|
||||||
|
/// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator
|
||||||
|
/// \return Status object
|
||||||
|
Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag,
|
||||||
|
std::string *out_cookie);
|
||||||
|
|
||||||
|
/// \brief Entry point for all server threads.
|
||||||
|
Status ServerRequest();
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // DATASET_ENGINE_CACHE_SERVER_H_
|
|
@ -0,0 +1,265 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "dataset/engine/cache/cache_service.h"
|
||||||
|
#include "dataset/util/slice.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// Record the configuration only. The cache pool and the row map are created later in
/// DoServiceStart. A service that generates row ids starts in the build phase; any
/// other service starts with no phase at all.
CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
    : root_(root),
      cache_mem_sz_(mem_sz),
      cp_(nullptr),
      map_(nullptr),
      next_id_(0),
      generate_id_(generate_id),
      schema_key_(-1),  // negative means no schema has been cached yet
      st_(generate_id ? State::kBuildPhase : State::kNone) {}
|
||||||
|
/// Stop the service on destruction. The returned status is deliberately discarded.
CacheService::~CacheService() {
  static_cast<void>(ServiceStop());
}
|
||||||
|
bool CacheService::UseArena() {
|
||||||
|
// If fixed size, use Arena instead of the pool from global context.
|
||||||
|
return (cache_mem_sz_ > 0);
|
||||||
|
}
|
||||||
|
/// Pick the backing memory pool, start the cache pool on top of it, create the row
/// index and adopt the cache pool's name as this service's cookie.
Status CacheService::DoServiceStart() {
  std::shared_ptr<MemoryPool> pool;
  if (!UseArena()) {
    // Unlimited size. Simply use a system pool. Another choice is CircularPool.
    pool = std::make_shared<SystemPool>();
  } else {
    // Fixed size: create an arena of the requested size.
    std::shared_ptr<Arena> arena;
    RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_));
    pool = std::move(arena);
  }
  // Put together a CachePool for backing up the Tensor
  cp_ = std::make_shared<CachePool>(CachePool::value_allocator(pool), root_);
  RETURN_IF_NOT_OK(cp_->ServiceStart());
  // The row index (B+ tree) is default constructed and does not receive the pool above.
  map_ = std::make_shared<row_map>();
  // The cookie identifies the creator of this cache. We simply reuse the CachePool's name.
  cookie_ = cp_->MyName();
  return Status::OK();
}
|
||||||
|
/// Stop the underlying cache pool, if one was ever created.
Status CacheService::DoServiceStop() {
  if (cp_ == nullptr) {
    return Status::OK();
  }
  RETURN_IF_NOT_OK(cp_->ServiceStop());
  return Status::OK();
}
|
||||||
|
/// Cache one row given as a series of buffers.
/// buf[0] is a TensorRowHeaderMsg flatbuffer describing the row; buf[1..] hold one buffer
/// per column, whose sizes are taken from the header's data_sz vector.
/// \param[in] buf Header buffer followed by the column buffers
/// \param[out] row_id_generated Row id used for this row: generated by the service when
///             generate_id_ is set, otherwise the row id found in the header
/// \return Status object. A duplicate row id is logged and treated as success.
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (st_ == State::kFetchPhase) {
    // For this kind of cache service, once we are done with the build phase into fetch phase, we can't
    // allow other to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
  }
  // Calling front() on an empty vector is undefined behaviour, so reject it up front.
  if (buf.empty()) {
    RETURN_STATUS_UNEXPECTED("Expect at least a row header buffer but get none");
  }
  try {
    // The first buffer is a flatbuffer which describes the rest of the buffers follow
    auto fb = buf.front();
    RETURN_UNEXPECTED_IF_NULL(fb);
    auto msg = GetTensorRowHeaderMsg(fb);
    // If the server side is designed to ignore incoming row id, we generate row id.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      if ((*row_id_generated) % 1000 == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated;
      }
    } else {
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    auto size_of_this = msg->size_of_this();
    auto column_hdr = msg->column();
    auto data_sz = msg->data_sz();
    // Both vectors are dereferenced below; a header without them cannot be processed.
    if (column_hdr == nullptr || data_sz == nullptr) {
      RETURN_STATUS_UNEXPECTED("Row header is missing column or data size information");
    }
    const auto num_columns = column_hdr->size();
    // Number of tensor buffer should match the number of columns plus one.
    if (buf.size() != num_columns + 1) {
      std::string errMsg = "Column count does not match. Expect " + std::to_string(num_columns + 1) +
                           " but get " + std::to_string(buf.size());
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    if (data_sz->size() < num_columns) {
      std::string errMsg = "Data size count does not match. Expect " + std::to_string(num_columns) +
                           " but get " + std::to_string(data_sz->size());
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    // Next we store in either memory or on disk. Low level code will consolidate everything in one piece.
    std::vector<ReadableSlice> all_data;
    all_data.reserve(num_columns + 1);
    all_data.emplace_back(fb, size_of_this);
    for (decltype(column_hdr->size()) i = 0; i < num_columns; ++i) {
      all_data.emplace_back(buf.at(i + 1), data_sz->Get(i));
    }
    // Now we cache the flat buffer.
    CachePool::key_type key;
    RETURN_IF_NOT_OK(cp_->Insert(all_data, &key));
    Status rc = map_->DoInsert(*row_id_generated, key);
    if (rc == Status(StatusCode::kDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key";
    } else {
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}
|
||||||
|
std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
|
||||||
|
// Then show any custom derived-internal stuff
|
||||||
|
out << "\nCache memory size: " << cs.cache_mem_sz_;
|
||||||
|
out << "\nSpill path: ";
|
||||||
|
if (cs.root_.empty()) {
|
||||||
|
out << "None";
|
||||||
|
} else {
|
||||||
|
out << cs.GetSpillPath();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
/// Forward to the cache pool. cp_ is dereferenced without a null check.
Path CacheService::GetSpillPath() const {
  return cp_->GetSpillPath();
}
|
||||||
|
/// Throw away everything that has been cached: restart the cache pool, replace the row
/// index with an empty one and restart row id generation from zero.
Status CacheService::Purge() {
  // Exclusive access: nobody may cache or fetch while we purge.
  UniqueLock rw(&rw_lock_);
  RETURN_IF_NOT_OK(cp_->ServiceStop());
  // The fresh index is built first; assigning it releases the old one.
  map_ = std::make_shared<row_map>();
  next_id_ = 0;
  RETURN_IF_NOT_OK(cp_->ServiceStart());
  return Status::OK();
}
|
||||||
|
/// Fill in statistics for this service.
/// The state is always reported. Pool statistics and the min/max row id are only
/// reported when the state is kNone or kFetchPhase; min/max are left untouched when
/// the row index is empty.
/// \param[out] out Pre-allocated structure to fill in
/// \return Status object
Status CacheService::GetStat(CacheService::ServiceStat *out) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(out);
  const bool stats_available = (st_ == State::kNone) || (st_ == State::kFetchPhase);
  if (!stats_available) {
    out->state_ = static_cast<ServiceStat::state_type>(st_);
    return Status::OK();
  }
  out->stat_ = cp_->GetStat();
  out->state_ = static_cast<ServiceStat::state_type>(st_);
  auto first = map_->begin();
  if (first == map_->end()) {
    return Status::OK();
  }
  out->min_ = first.key();
  // The last element sits just before end().
  auto last = map_->end();
  --last;
  out->max_ = last.key();
  return Status::OK();
}
|
||||||
|
/// Fetch a batch of rows into one contiguous buffer.
/// Layout of the output: (v.size() + 1) int64_t offsets followed by the row data.
/// Row i occupies the bytes [offset[i], offset[i + 1]). A row id that is not cached is
/// not an error; it shows up as an empty range (offset[i] == offset[i + 1]).
/// \param[in] v Row ids to fetch
/// \param[out] out Receives the buffer described above
/// \return Status object
Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const {
  RETURN_UNEXPECTED_IF_NULL(out);
  SharedLock rw(&rw_lock_);
  if (st_ == State::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase");
  }
  const auto num_elements = v.size();
  int64_t mem_sz = (num_elements + 1) * sizeof(int64_t);
  // Row data starts right after the offset array.
  int64_t data_offset = mem_sz;
  std::vector<int64_t> sz_v;
  std::vector<CachePool::key_type> keys;
  sz_v.reserve(num_elements);
  keys.reserve(num_elements);
  // First pass: resolve every row id to a pool key and add up the sizes.
  for (auto row_id : v) {
    auto r = map_->Search(row_id);
    if (r.second) {
      auto &it = r.first;
      CachePool::key_type key = it.value();
      auto sz = cp_->GetSize(key);
      if (sz == 0) {
        std::string errMsg = "Key not found: ";
        errMsg += std::to_string(key);
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      keys.push_back(key);
      sz_v.push_back(sz);
      mem_sz += sz;
    } else {
      // Cache miss: recorded as a zero sized row.
      keys.push_back(-1);
      sz_v.push_back(0);
    }
  }
  MemGuard<uint8_t> mem;
  RETURN_IF_NOT_OK(mem.allocate(mem_sz));
  auto *offset_array = reinterpret_cast<int64_t *>(mem.GetMutablePointer());
  offset_array[0] = data_offset;
  WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes());
  // Second pass: fill in the offsets and read each row into its slot.
  for (size_t i = 0; i < num_elements; ++i) {
    auto sz = sz_v.at(i);
    offset_array[i + 1] = offset_array[i] + sz;
    if (sz > 0) {
      WritableSlice row_data(all, offset_array[i], sz);
      auto key = keys.at(i);
      size_t bytesRead = 0;
      RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
      if (bytesRead != static_cast<size_t>(sz)) {
        MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "."
                      << " Internal key: " << key << "\n";
        RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
      }
    }
  }
  *out = std::move(mem);
  return Status::OK();
}
|
||||||
|
/// Store a serialized schema in the cache pool and remember its key in schema_key_.
/// When several threads race, only the first successful compare-exchange on
/// schema_key_ is kept; once a schema is recorded, later calls do nothing.
/// \param buf Serialized schema
/// \param len Size of buf in bytes
/// \return Status object
Status CacheService::CacheSchema(const void *buf, int64_t len) {
  SharedLock rw(&rw_lock_);
  if (st_ == State::kFetchPhase) {
    // Once the build phase is over and we are in fetch phase, nothing more can be cached.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
  }
  // A non-negative key means a schema has been recorded already.
  CachePool::key_type expected = schema_key_;
  if (expected >= 0) {
    MS_LOG(DEBUG) << "Caching Schema already done";
    return Status::OK();
  }
  CachePool::key_type new_key;
  RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &new_key));
  auto result = std::atomic_compare_exchange_strong(&schema_key_, &expected, new_key);
  MS_LOG(DEBUG) << "Caching Schema. Result = " << result;
  return Status::OK();
}
|
||||||
|
/// Read the cached schema back into a freshly allocated buffer.
/// \param[out] out Receives the serialized schema
/// \return kFileNotExist if no schema has been cached, an unexpected error while the
///         service is still in its build phase, otherwise the status of the read
Status CacheService::FetchSchema(MemGuard<uint8_t> *out) const {
  // Validate the argument before taking the lock.
  RETURN_UNEXPECTED_IF_NULL(out);
  SharedLock rw(&rw_lock_);
  if (st_ == State::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase");
  }
  if (schema_key_ < 0) {
    return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached");
  }
  MemGuard<uint8_t> mem;
  auto len = cp_->GetSize(schema_key_);
  RETURN_IF_NOT_OK(mem.allocate(len));
  auto slice = WritableSlice(mem.GetMutablePointer(), len);
  RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
  *out = std::move(mem);
  return Status::OK();
}
|
||||||
|
/// Switch from build phase to fetch phase. Fails for a service that has no build phase.
Status CacheService::BuildPhaseDone() {
  if (!HasBuildPhase()) {
    RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
  }
  // The phase switch is done under the exclusive lock.
  UniqueLock rw(&rw_lock_);
  st_ = State::kFetchPhase;
  return Status::OK();
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,143 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_CACHE_SERVICE_H_
|
||||||
|
#define DATASET_ENGINE_CACHE_SERVICE_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <atomic>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "./de_tensor_generated.h"
|
||||||
|
#include "dataset/core/global_context.h"
|
||||||
|
#include "dataset/core/tensor.h"
|
||||||
|
#include "dataset/engine/cache/cache_request.h"
|
||||||
|
#include "dataset/util/arena.h"
|
||||||
|
#include "dataset/util/btree.h"
|
||||||
|
#include "dataset/util/cache_pool.h"
|
||||||
|
#include "dataset/util/service.h"
|
||||||
|
#include "dataset/util/services.h"
|
||||||
|
#include "dataset/util/system_pool.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
struct CacheStat;
|
||||||
|
/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is
|
||||||
|
/// created to support spilling
|
||||||
|
class CacheService : public Service {
|
||||||
|
public:
|
||||||
|
friend class CacheServer;
|
||||||
|
using row_map = BPlusTree<row_id_type, CachePool::key_type>;
|
||||||
|
|
||||||
|
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };
|
||||||
|
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
|
||||||
|
/// \param root Spill path. Empty string means no spilling
|
||||||
|
/// \param generate_id If the cache service should generate row id for buffer that is cached.
|
||||||
|
/// For non-mappable dataset, this should be set to true.
|
||||||
|
CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
|
||||||
|
~CacheService();
|
||||||
|
|
||||||
|
/// \brief For fixed size memory, we will create an Arena.
|
||||||
|
/// \return false if unlimited memory.
|
||||||
|
bool UseArena();
|
||||||
|
|
||||||
|
Status DoServiceStart() override;
|
||||||
|
Status DoServiceStop() override;
|
||||||
|
|
||||||
|
/// \brief Main function to cache a row which is in form a series of buffers.
|
||||||
|
/// The first buffer is a Google flatbuffer which describes the rest of the buffers followed.
|
||||||
|
/// \param[in] buf Vector of buffer
|
||||||
|
/// \param[out] row_id_generated The row id assigned to this row if any
|
||||||
|
/// \return Status object
|
||||||
|
Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);
|
||||||
|
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
|
||||||
|
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
|
||||||
|
/// \param[in] v A vector of row id.
|
||||||
|
/// \param[out] out A contiguous memory buffer that holds the requested rows.
|
||||||
|
/// \return Status object
|
||||||
|
Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const;
|
||||||
|
|
||||||
|
/// \brief Getter function
|
||||||
|
/// \return Spilling path
|
||||||
|
Path GetSpillPath() const;
|
||||||
|
/// \brief A structure returned from the cache server for statistics request.
|
||||||
|
class ServiceStat {
|
||||||
|
public:
|
||||||
|
using state_type = std::underlying_type<State>::type;
|
||||||
|
ServiceStat() : min_(0), max_(0), state_(0) {}
|
||||||
|
CachePool::CacheStat stat_{};
|
||||||
|
row_id_type min_;
|
||||||
|
row_id_type max_;
|
||||||
|
state_type state_;
|
||||||
|
};
|
||||||
|
/// \brief Statistics for the current service
|
||||||
|
/// \param[in/out] A pointer to a pre-allocated ServiceStat structure
|
||||||
|
/// \return Status Object
|
||||||
|
Status GetStat(ServiceStat *);
|
||||||
|
/// \brief Cache schema
|
||||||
|
/// \param buf A Google Flatbuffer that contains the schema
|
||||||
|
/// \param len size of the buffer
|
||||||
|
/// \return Status object
|
||||||
|
Status CacheSchema(const void *buf, int64_t len);
|
||||||
|
/// \brief Fetch schema
|
||||||
|
/// \param out A contiguous memory that contains the serialized form of schema.
|
||||||
|
/// \return Status object
|
||||||
|
Status FetchSchema(MemGuard<uint8_t> *out) const;
|
||||||
|
/// \brief Purge the content of a cache
|
||||||
|
/// \return Status object
|
||||||
|
Status Purge();
|
||||||
|
/// \brief Overload the << operator to print a cache service
|
||||||
|
/// \param out std::ostream
|
||||||
|
/// \param cs A cache service
|
||||||
|
/// \return std::ostream
|
||||||
|
friend std::ostream &operator<<(std::ostream &out, const CacheService &cs);
|
||||||
|
/// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient
|
||||||
|
/// is the creator
|
||||||
|
/// \return Cookie
|
||||||
|
std::string cookie() const { return cookie_; }
|
||||||
|
/// \brief If this cache service generates row id for buffer cached, it is divided into two phases, a build phase and
|
||||||
|
/// a read phase.
|
||||||
|
/// \return True if has two phases.
|
||||||
|
bool HasBuildPhase() const { return generate_id_; }
|
||||||
|
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
|
||||||
|
/// \return Status object
|
||||||
|
Status BuildPhaseDone();
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutable RWLock rw_lock_;
|
||||||
|
std::string root_;
|
||||||
|
uint64_t cache_mem_sz_;
|
||||||
|
std::shared_ptr<CachePool> cp_;
|
||||||
|
std::shared_ptr<row_map> map_;
|
||||||
|
std::atomic<row_id_type> next_id_;
|
||||||
|
bool generate_id_;
|
||||||
|
std::atomic<CachePool::key_type> schema_key_;
|
||||||
|
std::string cookie_;
|
||||||
|
State st_;
|
||||||
|
|
||||||
|
/// \brief Private function to generate a row id
|
||||||
|
/// \return Row id assigned.
|
||||||
|
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // DATASET_ENGINE_CACHE_SERVICE_H_
|
|
@ -0,0 +1,81 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
namespace mindspore.dataset;
|
||||||
|
|
||||||
|
/// Type of a Tensor
|
||||||
|
enum TensorType : byte {
|
||||||
|
DE_UNKNOWN = 0,
|
||||||
|
DE_BOOL = 1,
|
||||||
|
DE_INT8 = 2,
|
||||||
|
DE_UINT8 = 3,
|
||||||
|
DE_INT16 = 4,
|
||||||
|
DE_UINT16 = 5,
|
||||||
|
DE_INT32 = 6,
|
||||||
|
DE_UINT32 = 7,
|
||||||
|
DE_INT64 = 8,
|
||||||
|
DE_UINT64 = 9,
|
||||||
|
DE_FLOAT16 = 10,
|
||||||
|
DE_FLOAT32 = 11,
|
||||||
|
DE_FLOAT64 = 12,
|
||||||
|
DE_STRING = 13
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The meta information of a Tensor
|
||||||
|
/// \note Only the type and shape are considered meta information. Tensor data is excluded.
|
||||||
|
table TensorMetaMsg {
|
||||||
|
dims:[int64] (required);
|
||||||
|
type:TensorType;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This is the first buffer that is sent to a Cache server when a TensorRow is serialized.
|
||||||
|
/// \param row_id is the row id of the TensorRow.
|
||||||
|
/// \param column The meta information of each Tensor in the row
|
||||||
|
/// \param size_of_this is the size of this serialized buffer
|
||||||
|
/// \param data_sz is the size of each tensor data buffer that follows
|
||||||
|
table TensorRowHeaderMsg {
|
||||||
|
row_id:int64;
|
||||||
|
column:[TensorMetaMsg] (required);
|
||||||
|
size_of_this:int64;
|
||||||
|
data_sz:[int64] (required);
|
||||||
|
}
|
||||||
|
|
||||||
|
root_type TensorRowHeaderMsg;
|
||||||
|
|
||||||
|
/// A row of row id's
|
||||||
|
table TensorRowIds {
|
||||||
|
row_id:[int64] (required);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Statistics returned from each cache service
|
||||||
|
/// \note It must match CacheService::ServiceStat
|
||||||
|
table ServiceStatMsg {
|
||||||
|
num_mem_cached:int64;
|
||||||
|
num_disk_cached:int64;
|
||||||
|
min_row_id:int64;
|
||||||
|
max_row_id:int64;
|
||||||
|
state:int8;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Column description of each column in a schema
|
||||||
|
table ColumnNameMsg {
|
||||||
|
name:string;
|
||||||
|
id:int32;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serialized form of a schema
|
||||||
|
table SchemaMsg {
|
||||||
|
column:[ColumnNameMsg];
|
||||||
|
}
|
|
@ -24,10 +24,8 @@ namespace dataset {
|
||||||
// Description: This is the main constructor that is used for making a buffer
|
// Description: This is the main constructor that is used for making a buffer
|
||||||
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
|
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
|
||||||
|
|
||||||
// Name: print()
|
// A method for debug printing of the buffer
|
||||||
// Description: A function that prints info about the DataBuffer (base class version)
|
void DataBuffer::Print(std::ostream &out, bool show_all) const {
|
||||||
void DataBuffer::Print(std::ostream &out, // In: The output stream to print to
|
|
||||||
bool show_all) const { // In: T/F if it should show everything
|
|
||||||
out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n";
|
out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n";
|
||||||
|
|
||||||
// If the column counts are set then it means that data has been set into
|
// If the column counts are set then it means that data has been set into
|
||||||
|
@ -46,11 +44,6 @@ void DataBuffer::Print(std::ostream &out, // In: The output stream to print
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataBuffer::Load() {
|
|
||||||
std::string err_msg = "Base class load called, but it does not have an implementation!";
|
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove me!! Callers should fetch rows via pop
|
// Remove me!! Callers should fetch rows via pop
|
||||||
Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const {
|
Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const {
|
||||||
if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) {
|
if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) {
|
||||||
|
@ -92,8 +85,5 @@ Status DataBuffer::SliceOff(int64_t number_of_rows) {
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destructor
|
|
||||||
DataBuffer::~DataBuffer() {}
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -29,11 +29,9 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
// The DataBuffer class is a base class that will represent the data for n values based
|
/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between
|
||||||
// on a unique row id for each row of data.
|
/// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format
|
||||||
// There can be different types of DataBuffers to abstract over how the data is stored
|
/// where n TensorRows may consist of m tensors (columns).
|
||||||
// in memory and acquired from storage.
|
|
||||||
// Each buffer holds a range of consecutive row id's.
|
|
||||||
class DataBuffer {
|
class DataBuffer {
|
||||||
public:
|
public:
|
||||||
// Buffer flags
|
// Buffer flags
|
||||||
|
@ -47,13 +45,13 @@ class DataBuffer {
|
||||||
// Description: This is the main constructor that is used for making a buffer
|
// Description: This is the main constructor that is used for making a buffer
|
||||||
DataBuffer(int32_t id, BufferFlags flags);
|
DataBuffer(int32_t id, BufferFlags flags);
|
||||||
|
|
||||||
// Destructor
|
/// \brief default destructor
|
||||||
virtual ~DataBuffer();
|
~DataBuffer() = default;
|
||||||
|
|
||||||
// Name: print()
|
/// \brief A method for debug printing of the buffer
|
||||||
// Description: A function that prints info about the DataBuffer (base class version)
|
/// \param[inout] out The stream to write to
|
||||||
virtual void Print(std::ostream &out, // In: The output stream to print to
|
/// \param[in] show_all A boolean to toggle between details and summary printing
|
||||||
bool show_all) const; // In: T/F if it should show everything
|
void Print(std::ostream &out, bool show_all) const;
|
||||||
|
|
||||||
// Provide stream operator for displaying it
|
// Provide stream operator for displaying it
|
||||||
friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
|
friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
|
||||||
|
@ -61,10 +59,6 @@ class DataBuffer {
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name: load()
|
|
||||||
// Description: populates the DataBuffer with data based on it's id
|
|
||||||
virtual Status Load();
|
|
||||||
|
|
||||||
// Convenience getter functions for flag checking
|
// Convenience getter functions for flag checking
|
||||||
bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }
|
bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,11 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
|
||||||
take_op.cc
|
take_op.cc
|
||||||
shuffle_op.cc
|
shuffle_op.cc
|
||||||
zip_op.cc
|
zip_op.cc
|
||||||
concat_op.cc
|
concat_op.cc
|
||||||
|
cache_base_op.cc
|
||||||
|
cache_lookup_op.cc
|
||||||
|
cache_op.cc
|
||||||
|
cache_merge_op.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
if (ENABLE_PYTHON)
|
if (ENABLE_PYTHON)
|
||||||
|
|
|
@ -0,0 +1,185 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "dataset/engine/datasetops/cache_base_op.h"
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
// A print method typically used for debugging
|
||||||
|
void CacheBase::Print(std::ostream &out, bool show_all) const {
|
||||||
|
// Always show the id and name as first line regardless if this summary or detailed print
|
||||||
|
out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:";
|
||||||
|
if (!show_all) {
|
||||||
|
// Call the super class for displaying any common 1-liner info
|
||||||
|
ParallelOp::Print(out, show_all);
|
||||||
|
out << "\n";
|
||||||
|
} else {
|
||||||
|
// Call the super class for displaying any common detailed info
|
||||||
|
ParallelOp::Print(out, show_all);
|
||||||
|
// Then show any custom derived-internal stuff
|
||||||
|
out << "\nCache client:\n" << *cache_client_ << "\n\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Overrides base class reset method. When an operator does a reset, it cleans up any state
|
||||||
|
// info from it's previous execution and then initializes itself so that it can be executed
|
||||||
|
// again.
|
||||||
|
Status CacheBase::Reset() {
|
||||||
|
if (sampler_ != nullptr) {
|
||||||
|
RETURN_IF_NOT_OK(sampler_->ResetSampler());
|
||||||
|
}
|
||||||
|
// Wake up the workers to get them going again in a new epoch
|
||||||
|
MS_LOG(DEBUG) << Name() << " resetting.";
|
||||||
|
epoch_sync_.Set();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// Base class constructor.
// \param num_workers Number of parallel workers
// \param op_connector_size Connector size
// \param rows_per_buf Number of rows per buffer
// \param cache_client CacheClient used by the workers to fetch rows
// \param sampler Sampler handed to the ParallelOp base class
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
                     std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
    // Both shared_ptr parameters are taken by value, so move them into their final home
    // instead of paying for another reference-count increment/decrement pair.
    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
      cache_client_(std::move(cache_client)),
      rows_per_buffer_(rows_per_buf),
      // We can cause deadlock if this internal Connector size is too small.
      keys_miss_(num_workers_, 1, 1024) {
  io_block_queues_.Init(num_workers, op_connector_size);
}
|
||||||
|
// Common function to fetch samples from the sampler and send them using the io_block_queues to
|
||||||
|
// the parallel workers
|
||||||
|
Status CacheBase::FetchSamplesToWorkers() {
|
||||||
|
int64_t buf_cnt = 0;
|
||||||
|
int64_t wait_cnt = 0;
|
||||||
|
do {
|
||||||
|
epoch_sync_.Clear();
|
||||||
|
std::vector<row_id_type> keys;
|
||||||
|
int64_t row_cnt = 0;
|
||||||
|
keys.reserve(rows_per_buffer_);
|
||||||
|
std::unique_ptr<DataBuffer> sampler_buffer;
|
||||||
|
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||||
|
while (!sampler_buffer->eoe()) {
|
||||||
|
TensorRow sample_row;
|
||||||
|
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
|
||||||
|
std::shared_ptr<Tensor> sample_ids = sample_row[0];
|
||||||
|
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
|
||||||
|
keys.push_back(*itr);
|
||||||
|
++row_cnt;
|
||||||
|
if (row_cnt % rows_per_buffer_ == 0) {
|
||||||
|
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||||
|
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
|
||||||
|
keys.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||||
|
}
|
||||||
|
if (!keys.empty()) {
|
||||||
|
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||||
|
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
|
||||||
|
}
|
||||||
|
// send the eoe
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||||
|
// If repeat but the not last repeat, wait for reset.
|
||||||
|
if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||||
|
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
|
||||||
|
RETURN_IF_NOT_OK(epoch_sync_.Wait());
|
||||||
|
} else {
|
||||||
|
// We can break out from the loop.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (true);
|
||||||
|
// Flow the eof before exit
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
|
||||||
|
// Ask all the workers to quit.
|
||||||
|
for (int32_t i = 0; i < num_workers_; i++) {
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// Worker loop. Pops IOBlocks from this worker's queue until the quit signal (a block with no
// keys) arrives. An eof/eoe block is forwarded to the output connector as an EOF/EOE buffer;
// a block of keys is resolved through the cache client into a DataBuffer of rows.
// When AllowCacheMiss() is true, the row ids that were not found (or the eoe marker) are also
// pushed onto keys_miss_; otherwise a missing row is an error.
// \param worker_id Id of the calling worker, selects the queue and the connector slot
// \return Status object
Status CacheBase::FetchFromCache(int32_t worker_id) {
  int64_t buffer_id = worker_id;
  std::unique_ptr<IOBlock> blk;
  do {
    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
    if (blk->eof()) {
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
    } else if (blk->eoe()) {
      if (AllowCacheMiss()) {
        // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from
        // a sampler, send a eoe to physical leaf op as well.
        std::vector<row_id_type> eoe;
        eoe.push_back(eoe_row_id);
        RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe));
      }
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
    } else {
      std::vector<int64_t> keys;
      RETURN_IF_NOT_OK(blk->GetKeys(&keys));
      if (keys.empty()) {
        // empty key is a quit signal for workers
        break;
      }
      std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
      std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
      TensorTable ttbl;
      RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl));
      auto row_it = ttbl.begin();
      std::vector<row_id_type> cache_miss;
      cache_miss.reserve(keys.size());
      for (auto row_id : keys) {
        // The rows are walked in lock step with the keys. Never step past the end of the
        // table if fewer rows than keys came back.
        if (row_it == ttbl.end()) {
          std::string errMsg = "Number of rows fetched is less than the number of keys requested.";
          RETURN_STATUS_UNEXPECTED(errMsg);
        }
        auto &row = *row_it;
        if (row.empty()) {
          // An empty row marks a cache miss.
          if (AllowCacheMiss()) {
            cache_miss.push_back(row_id);
          } else {
            std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
            RETURN_STATUS_UNEXPECTED(errMsg);
          }
        }
        // The row is forwarded even when it is empty.
        que->push_back(std::move(row));
        ++row_it;
      }
      db->set_tensor_table(std::move(que));
      if (AllowCacheMiss()) {
        // Because of the way connector works, we push unconditionally even cache_miss can be empty.
        RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss));
      }
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
      // Buffer ids of one worker are spaced num_workers_ apart, starting at the worker id.
      buffer_id += num_workers_;
    }
  } while (true);
  return Status::OK();
}
|
||||||
|
// Registers epoch_sync_ and the worker io block queues against tree_->AllTasks().
// The first failure is returned as is.
Status CacheBase::RegisterResources() {
  RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
  return io_block_queues_.Register(tree_->AllTasks());
}
|
||||||
|
// Nothing to release by hand; the members clean up after themselves.
CacheBase::~CacheBase() = default;
|
||||||
|
// Fills column_name_id_map_ from the schema stored in the cache server, unless the map is
// already populated. A schema that is not in the server yet (kFileNotExist) is tolerated and
// reported as success; any other status is returned unchanged.
Status CacheBase::UpdateColumnMapFromCache() {
  Status rc;
  if (!column_name_id_map_.empty()) {
    // Nothing to fetch, the map is already known.
    return rc;
  }
  rc = cache_client_->FetchSchema(&column_name_id_map_);
  if (rc == Status(StatusCode::kFileNotExist)) {
    MS_LOG(DEBUG) << "Schema not in the server yet.";
    rc = Status::OK();
  }
  return rc;
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,108 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
|
||||||
|
#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include "dataset/engine/cache/cache_client.h"
|
||||||
|
#include "dataset/engine/cache/cache_service.h"
|
||||||
|
#include "dataset/engine/datasetops/parallel_op.h"
|
||||||
|
#include "dataset/engine/datasetops/repeat_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/io_block.h"
|
||||||
|
#include "dataset/engine/datasetops/source/sampler/sampler.h"
|
||||||
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
|
#include "dataset/util/queue.h"
|
||||||
|
#include "dataset/util/wait_post.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_base_op.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
|
||||||
|
/// \see CacheOp
|
||||||
|
/// \see CacheLookupOp
|
||||||
|
class CacheBase : public ParallelOp {
|
||||||
|
public:
|
||||||
|
/// \brief Base class constructor
|
||||||
|
/// \param num_workers Number of parallel workers
|
||||||
|
/// \param op_connector_size Connector size
|
||||||
|
/// \param rows_per_buf Number of rows per buffer
|
||||||
|
/// \param cache_client CacheClient for communication to the CacheServer
|
||||||
|
/// \param sampler Sampler which is mandatory
|
||||||
|
CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
|
||||||
|
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
|
||||||
|
/// \brief Destructor
|
||||||
|
~CacheBase();
|
||||||
|
|
||||||
|
constexpr static int eoe_row_id = -1;
|
||||||
|
|
||||||
|
/// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
|
||||||
|
/// info from it's previous execution and then initializes itself so that it can be executed
|
||||||
|
/// again.
|
||||||
|
/// \return Status - The error code return
|
||||||
|
Status Reset() override;
|
||||||
|
|
||||||
|
/// \brief A print method typically used for debugging
|
||||||
|
/// \param out The output stream to write output to
|
||||||
|
/// \param show_all A bool to control if you want to show all info or just a summary
|
||||||
|
void Print(std::ostream &out, bool show_all) const override;
|
||||||
|
|
||||||
|
/// \brief << Stream output operator overload
|
||||||
|
/// \notes This allows you to write the debug print info using stream operators
|
||||||
|
/// \param out reference to the output stream being overloaded
|
||||||
|
/// \param mo reference to the CacheOp to display
|
||||||
|
/// \return the output stream must be returned
|
||||||
|
friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) {
|
||||||
|
mo.Print(out, false);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Getter for the cache client
|
||||||
|
/// \return shared ptr to the cache client
|
||||||
|
std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
|
||||||
|
/// \brief Setter for the cache client
|
||||||
|
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
|
||||||
|
/// \brief Derived class must implement this method if a cache miss is treated as error
|
||||||
|
virtual bool AllowCacheMiss() = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<CacheClient> cache_client_;
|
||||||
|
WaitPost epoch_sync_;
|
||||||
|
int32_t rows_per_buffer_;
|
||||||
|
Connector<std::vector<row_id_type>> keys_miss_;
|
||||||
|
|
||||||
|
/// \brief Common function to register resources for interrupt
|
||||||
|
/// \note Derived should override this function for extra resources to be registered
|
||||||
|
virtual Status RegisterResources();
|
||||||
|
/// \brief This function is called by main thread to send samples to the worker thread.
|
||||||
|
/// \note It is a non-virtual function
|
||||||
|
/// \return Status object
|
||||||
|
Status FetchSamplesToWorkers();
|
||||||
|
/// \brief This function is called by each worker to fetch rows from the cache server for a given set of
|
||||||
|
/// sample row id's
|
||||||
|
/// \return Status object
|
||||||
|
Status FetchFromCache(int32_t worker_id);
|
||||||
|
/// \brief Get the column map from cache server
|
||||||
|
Status UpdateColumnMapFromCache();
|
||||||
|
|
||||||
|
private:
|
||||||
|
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
|
|
@ -0,0 +1,130 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "dataset/engine/datasetops/cache_lookup_op.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
#include "dataset/core/config_manager.h"
|
||||||
|
#include "dataset/core/constants.h"
|
||||||
|
#include "dataset/core/global_context.h"
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "utils/system/crc32c.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
// Builder constructor. Creates the builder object.
|
||||||
|
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
||||||
|
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||||
|
build_num_workers_ = cfg->num_parallel_workers();
|
||||||
|
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||||
|
build_op_connector_size_ = cfg->op_connector_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the required parameters are set by the builder.
|
||||||
|
Status CacheLookupOp::Builder::SanityCheck() const {
|
||||||
|
if (build_cache_client_ == nullptr) {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient");
|
||||||
|
}
|
||||||
|
// Make sure the cache client has a valid session
|
||||||
|
if (!build_cache_client_->session_id()) {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||||
|
"Cache client for CacheLookupOp is missing session id");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The builder "build" method creates the final object and does some init on it
|
||||||
|
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
|
||||||
|
RETURN_IF_NOT_OK(SanityCheck());
|
||||||
|
*ptr = std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
|
||||||
|
build_cache_client_, build_sampler_);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// Main entry of the operator. Requires a sampler, registers the interruptible resources,
// launches the workers, waits for the handshake with the leaf op and then starts feeding
// sample ids to the workers.
Status CacheLookupOp::operator()() {
  if (sampler_ == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
                  "CacheLookupOp requires a sampler before it can be executed!");
  }
  RETURN_IF_NOT_OK(RegisterResources());
  // Kick off the workers
  auto worker_entry = std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1);
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, worker_entry));
  // required task group sync after launching workers
  TaskManager::FindMe()->Post();
  // leaf_op_wp_ is posted by HandshakeRandomAccessOp(); do not sample before that happened.
  RETURN_IF_NOT_OK(leaf_op_wp_.Wait());
  return FetchSamplesToWorkers();
}
|
||||||
|
// Entry point of each worker thread: post the task sync, then serve this worker's queue
// through the common fetch loop until it returns.
Status CacheLookupOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  return FetchFromCache(worker_id);
}
|
||||||
|
Status CacheLookupOp::ResetSampler() { return Status::OK(); }
|
||||||
|
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
||||||
|
// We act like a sampler and as a dataset op. During handshake with leaf op,
|
||||||
|
// We must wait until the leaf op has indexed everything.
|
||||||
|
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op));
|
||||||
|
// Now we notify the main thread handshake has finished.
|
||||||
|
leaf_op_wp_.Set();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
|
||||||
|
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
|
||||||
|
// Sampler interface: produces the next buffer of sample ids, which are the row ids the
// workers could not find in the cache.
// Pops from slot 0 of keys_miss_ until a non-empty vector arrives. A vector whose first entry
// is eoe_row_id is turned into an EOE buffer; anything else becomes a one-row buffer holding
// a tensor of the missed row ids.
// \param[out] out_buffer The buffer created by this call
// \return Status object
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
  std::vector<row_id_type> cache_miss;
  // Ignore the case we have no cache miss, we can't return empty samples.
  do {
    RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
  } while (cache_miss.empty());
  // Special code for eoe
  if (cache_miss.front() == eoe_row_id) {
    *out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    return Status::OK();
  }
  std::shared_ptr<Tensor> sample_ts;
  RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size()));
  (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
  // Copy the missed row ids into the tensor. Iterating the vector directly avoids mixing an
  // int index with the unsigned size of the vector.
  auto idPtr = sample_ts->begin<int64_t>();
  for (const auto missed_id : cache_miss) {
    *idPtr = missed_id;
    ++idPtr;
  }
  TensorRow row;
  row.push_back(sample_ts);
  (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
  return Status::OK();
}
|
||||||
|
// Registers the base class resources first, then the leaf op handshake wait post.
Status CacheLookupOp::RegisterResources() {
  RETURN_IF_NOT_OK(CacheBase::RegisterResources());
  return leaf_op_wp_.Register(tree_->AllTasks());
}
|
||||||
|
Status CacheLookupOp::ComputeColMap() {
|
||||||
|
// We don't know the column map at this point unless we contact the cache server
|
||||||
|
// to fetch the schema but the cache server may not have it at this point either.
|
||||||
|
// So we will just return OK and let MergeOp (our parent) to handle it.
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status CacheLookupOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<CacheLookupOp>(), modified);
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,122 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
|
||||||
|
#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include "dataset/engine/datasetops/cache_base_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset.
|
||||||
|
/// \note For non-mappable dataset, please see CacheOp
|
||||||
|
/// \see CacheOp
|
||||||
|
class CacheLookupOp : public CacheBase, public Sampler {
|
||||||
|
public:
|
||||||
|
class Builder {
|
||||||
|
public:
|
||||||
|
/// \brief Builder constructor. Creates the builder object.
|
||||||
|
/// \note No default args
|
||||||
|
Builder();
|
||||||
|
|
||||||
|
/// Default destructor
|
||||||
|
~Builder() = default;
|
||||||
|
|
||||||
|
/// Setter method.
|
||||||
|
/// \treturn Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetNumWorkers(int32_t num_workers) {
|
||||||
|
build_num_workers_ = num_workers;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||||
|
build_op_connector_size_ = connector_size;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
||||||
|
build_cache_client_ = cache_client;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
|
||||||
|
build_sampler_ = std::move(sampler);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief The builder "build" method creates the final object and does some init on it.
|
||||||
|
/// \param ptr The shared_ptr to the new CacheLookupOp object
|
||||||
|
/// \return Status
|
||||||
|
Status Build(std::shared_ptr<CacheLookupOp> *ptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t build_num_workers_;
|
||||||
|
int32_t rows_per_buffer_;
|
||||||
|
int32_t build_op_connector_size_;
|
||||||
|
std::shared_ptr<CacheClient> build_cache_client_;
|
||||||
|
std::shared_ptr<Sampler> build_sampler_;
|
||||||
|
|
||||||
|
// Check if the required parameters are set by the builder.
|
||||||
|
// \return Status The error code return
|
||||||
|
Status SanityCheck() const;
|
||||||
|
};
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \note It takes the same argument as the base class.
|
||||||
|
/// \see CacheBase
|
||||||
|
CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
|
||||||
|
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
|
||||||
|
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
|
||||||
|
~CacheLookupOp() = default;
|
||||||
|
// As a parallel op, we override these two functions
|
||||||
|
Status operator()() override;
|
||||||
|
Status WorkerEntry(int32_t worker_id) override;
|
||||||
|
// As a sampler, we override the following functions
|
||||||
|
Status ResetSampler() override;
|
||||||
|
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
|
||||||
|
Status InitSampler() override;
|
||||||
|
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||||
|
void Print(std::ostream &out, bool show_all) const override;
|
||||||
|
bool AllowCacheMiss() override { return true; }
|
||||||
|
std::string Name() const override { return "CacheLookupOp"; }
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Status ComputeColMap() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
WaitPost leaf_op_wp_;
|
||||||
|
|
||||||
|
Status RegisterResources() override;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
|
|
@ -0,0 +1,301 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <iomanip>
|
||||||
|
#include "dataset/core/config_manager.h"
|
||||||
|
#include "dataset/core/constants.h"
|
||||||
|
#include "dataset/core/global_context.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_merge_op.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/util/task_manager.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
CacheMergeOp::~CacheMergeOp() = default;
|
||||||
|
void CacheMergeOp::Print(std::ostream &out, bool show_all)
|
||||||
|
const { // Always show the id and name as first line regardless if this is summary or detailed print
|
||||||
|
out << "(" << std::setw(2) << operator_id_ << ") <CacheMergeOp>:";
|
||||||
|
if (!show_all) {
|
||||||
|
// Call the super class for displaying any common 1-liner info
|
||||||
|
ParallelOp::Print(out, show_all);
|
||||||
|
// Then show any custom derived-internal 1-liner info for this op
|
||||||
|
out << "\n";
|
||||||
|
} else {
|
||||||
|
// Call the super class for displaying any common detailed info
|
||||||
|
ParallelOp::Print(out, show_all);
|
||||||
|
// Then show any custom derived-internal stuff
|
||||||
|
out << "\n\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Constructor of CacheMergeOp.
// \param numWorkers Number of parallel workers, forwarded to ParallelOp
// \param opConnectorSize Connector size, forwarded to ParallelOp
// \param numCleaners Number of cleaner threads launched by operator()
// \param cache_client Client used to talk to the cache server
// \param sampler Sampler forwarded to ParallelOp
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
                           std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
    : ParallelOp(numWorkers, opConnectorSize, sampler),
      // Members are listed in their declaration order (cache_client_ before num_cleaners_),
      // which is the order in which they are actually initialized.
      cache_client_(std::move(cache_client)),
      num_cleaners_(numCleaners) {}
|
||||||
|
// Master thread entry. Creates the row id queue and launches the workers of both
// streams as well as the cleaner threads.
// \return Status
Status CacheMergeOp::operator()() {
  // Capacity of the queue of row ids that the cleaners send to the cache server.
  // A small queue would block the parallel op workers, and since a row id is an
  // 8 byte integer a bigger queue doesn't consume a lot of memory.
  constexpr int32_t kIoQueueCapacity = 512;
  io_que_ = std::make_unique<Queue<row_id_type>>(kIoQueueCapacity);
  RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks()));

  // Workers that consume the cache hit stream.
  auto hit_worker = std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1);
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, hit_worker));

  // Workers that consume the cache miss stream.
  auto miss_worker = std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1);
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, miss_worker));

  // num_cleaners_ dedicated threads move TensorRow from the pool to the cache server.
  int32_t launched = 0;
  while (launched < num_cleaners_) {
    RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this)));
    ++launched;
  }

  TaskManager::FindMe()->Post();
  return Status::OK();
}
|
||||||
|
// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait
|
||||||
|
// until it shows up in the pool.
|
||||||
|
Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
|
||||||
|
TaskManager::FindMe()->Post();
|
||||||
|
std::shared_ptr<DatasetOp> cache_hit_stream = child_[kCacheHitChildIdx];
|
||||||
|
std::unique_ptr<DataBuffer> db_ptr;
|
||||||
|
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
|
||||||
|
while (!db_ptr->eof()) {
|
||||||
|
if (db_ptr->eoe()) {
|
||||||
|
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
||||||
|
db_ptr.reset();
|
||||||
|
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
|
||||||
|
} else {
|
||||||
|
// See if there is any missing row
|
||||||
|
auto tbl = std::make_unique<TensorQTable>();
|
||||||
|
while (db_ptr->NumRows() > 0) {
|
||||||
|
TensorRow row;
|
||||||
|
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
|
||||||
|
if (row.empty()) {
|
||||||
|
auto row_id = row.getId();
|
||||||
|
TensorRowRequest *rq = nullptr;
|
||||||
|
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
|
||||||
|
// Block until the row shows up in the pool.
|
||||||
|
RETURN_IF_NOT_OK(rq->Wait(&row));
|
||||||
|
}
|
||||||
|
tbl->push_back(std::move(row));
|
||||||
|
}
|
||||||
|
db_ptr->set_tensor_table(std::move(tbl));
|
||||||
|
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
|
||||||
|
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// Worker of the cache miss stream. Rows popped from the stream are put into the
// pool, which wakes up any worker waiting for them, and their ids are queued for
// the cleaners. An eoe buffer is ignored; the worker exits on eof.
// \param workerId Id of this worker
// \return Status
Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
  TaskManager::FindMe()->Post();
  std::shared_ptr<DatasetOp> cache_missing_stream = child_[kCacheMissChildIdx];
  // Before we start, cache the schema at the server. Only worker 0 does it.
  // The schema should have been done at prepare time.
  if (workerId == 0) {
    RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
  }
  std::unique_ptr<DataBuffer> db_ptr;
  RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
  while (!db_ptr->eof()) {
    if (!db_ptr->eoe()) {
      while (db_ptr->NumRows() > 0) {
        TensorRow missed;
        RETURN_IF_NOT_OK(db_ptr->PopRow(&missed));
        const row_id_type missed_id = missed.getId();
        if (missed_id < 0) {
          std::string errMsg = "Expect positive row id: " + std::to_string(missed_id);
          RETURN_STATUS_UNEXPECTED(errMsg);
        }
        TensorRowRequest *request = nullptr;
        RETURN_IF_NOT_OK(GetRq(missed_id, &request));
        request->WakeUpAny(std::move(missed));
        // Let the cleaner to flush out this row (async) to the cache server.
        RETURN_IF_NOT_OK(io_que_->EmplaceBack(missed_id));
      }
    } else {
      // Ignore it.
      MS_LOG(DEBUG) << "Ignore eoe";
    }
    RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
  }
  return Status::OK();
}
|
||||||
|
// Entry function of a cleaner thread. Row ids are popped from io_que_ and the
// cleaner copy of each row is written to the cache server. A negative row id
// ends the loop.
// \return Status
Status CacheMergeOp::Cleaner() {
  TaskManager::FindMe()->Post();
  while (true) {
    row_id_type row_id = -1;
    RETURN_IF_NOT_OK(io_que_->PopFront(&row_id));
    if (row_id < 0) {
      break;
    }
    TensorRowRequest *rq = nullptr;
    RETURN_IF_NOT_OK(GetRq(row_id, &rq));
    if (rq->GetState() == TensorRowRequest::State::kClean) {
      // If already flushed, move on to the next one.
      continue;
    }
    TensorRow row;
    RETURN_IF_NOT_OK(rq->Release(&row));
    if (row.empty()) {
      // Release() hands the copy only to the caller that wins the kDirty -> kClean
      // exchange. The same row id can be queued more than once, so with several
      // cleaners another one may have taken the copy between the state check above
      // and the call to Release(). That is not an error and must not bring down
      // the pipeline; there is simply nothing left to write for this id.
      MS_LOG(DEBUG) << "Nothing to flush for row id " << row_id;
      continue;
    }
    Status rc = cache_client_->WriteRow(row);
    // Bad rc should not bring down the pipeline
    if (rc.IsError()) {
      MS_LOG(WARNING) << "Cache not successful." << rc.ToString();
    }
    rq->SetState(TensorRowRequest::State::kClean);
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Looks up the TensorRowRequest of a row id, creating it on first use.
// The whole lookup/insert is done while holding mux_.
// \param row_id The row id to look up
// \param[out] out Receives the pointer to the request; the object stays owned by cache_miss_map_
// \return Status Error if out is null or if the request cannot be created
Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  std::unique_lock<std::mutex> lck(mux_);
  auto it = cache_miss_map_.find(row_id);
  if (it != cache_miss_map_.end()) {
    *out = it->second.GetMutablePointer();
    return Status::OK();
  }
  // We will create a new one.
  auto alloc = Services::GetAllocator<TensorRowRequest>();
  auto r = cache_miss_map_.emplace(row_id, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>(alloc));
  if (!r.second) {
    RETURN_STATUS_UNEXPECTED("Map insert fail.");
  }
  auto &mem = r.first->second;
  Status rc = mem.allocate(1, row_id);
  if (rc.IsError()) {
    // Do not leave an entry without a request object behind: a later lookup of
    // this row id would hand out its pointer together with an OK status.
    cache_miss_map_.erase(r.first);
    return rc;
  }
  *out = mem.GetMutablePointer();
  return Status::OK();
}
|
||||||
|
// Prepare-time hook. Requires exactly two children, runs the super class logic
// and then creates the cache at the server, keyed by the check sum of the ops
// in the cache miss stream.
// \return Status
Status CacheMergeOp::PrepareNodePostAction() {
  CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
  // Run any common code from super class first before adding our own specific logic
  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
  // Get the computed check sum from all ops in the cache miss class
  const uint32_t miss_stream_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]);
  // Construct the cache. The cache is asked not to generate row ids.
  const bool generate_ids = false;
  Status rc = cache_client_->CreateCache(miss_stream_crc, generate_ids);
  if (rc.get_code() == StatusCode::kDuplicateKey) {
    // We are told the cache has been created already.
    MS_LOG(INFO) << "Cache created already";
    return Status::OK();
  }
  RETURN_IF_NOT_OK(rc);
  return Status::OK();
}
|
||||||
|
// Takes the column map over from the cache miss child when this op has none yet.
// \return Status Error if the cache miss child is null or no column map is available afterwards
Status CacheMergeOp::ComputeColMap() {
  const auto &miss_child = child_[kCacheMissChildIdx];
  CHECK_FAIL_RETURN_UNEXPECTED(miss_child != nullptr, "Cache miss stream empty");
  const bool have_map = !column_name_id_map().empty();
  if (!have_map) {
    column_name_id_map_ = miss_child->column_name_id_map();
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected");
  return Status::OK();
}
|
||||||
|
// Blocks on use_count_ until the missing row is in the pool, then hands out the
// row at the front of the deque.
// \param[out] out Receives the row
// \return Status
Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  RETURN_IF_NOT_OK(use_count_.P());
  // The deque is only touched while holding dq_mux_.
  std::lock_guard<std::mutex> guard(dq_mux_);
  CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
  TensorRow &head = row_.front();
  *out = std::move(head);
  row_.pop_front();
  return Status::OK();
}
|
||||||
|
// Puts a row from the cache miss stream into the deque and bumps use_count_ by 1.
// The first row that arrives (state kEmpty) is additionally deep copied for the cleaner.
// \param row The row to hand over; it is moved into the deque
void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) {
  std::lock_guard<std::mutex> guard(dq_mux_);
  // Technically number of this row shows up in the cache miss stream is equal to the number
  // of P() call. However the cleaner wants it too. So we need an extra copy.
  const bool first_arrival = (GetState() == State::kEmpty);
  if (first_arrival) {
    // Deep copy: every tensor gets a buffer of its own.
    for (const auto &src : row) {
      cleaner_copy_.push_back(
        std::make_shared<Tensor>(src->shape(), src->type(), src->GetBuffer(), src->SizeInBytes()));
    }
    cleaner_copy_.setId(row.getId());
    // The copy is waiting to be flushed now.
    SetState(State::kDirty);
  }
  row_.push_back(std::move(row));
  // This wake up any parallel worker which is waiting for this row.
  use_count_.V();
}
|
||||||
|
// Hands the cleaner copy to the caller that switches the state from kDirty to kClean.
// No mutex is held here because the cleaner isn't really touching the deque row_.
// In case several cleaners all see the copy, the atomic exchange makes sure only
// one of them gets it; for the others *out is left untouched.
// \param[out] out Receives the cleaner copy
// \return Status
Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  State seen = State::kDirty;
  const bool won = st_.compare_exchange_strong(seen, State::kClean);
  if (!won) {
    return Status::OK();
  }
  *out = std::move(cleaner_copy_);
  return Status::OK();
}
|
||||||
|
// Builder constructor. Creates the builder object.
|
||||||
|
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
||||||
|
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||||
|
build_num_workers_ = cfg->num_parallel_workers();
|
||||||
|
build_op_connector_size_ = cfg->op_connector_size();
|
||||||
|
build_num_cleaners_ = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the required parameters are set by the builder.
|
||||||
|
Status CacheMergeOp::Builder::SanityCheck() const {
|
||||||
|
if (build_cache_client_ == nullptr) {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient");
|
||||||
|
}
|
||||||
|
// Make sure the cache client has a valid session
|
||||||
|
if (!build_cache_client_->session_id()) {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||||
|
"Cache client for CacheMergeOp is missing session id");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The builder "build" method creates the final object and does some init on it
|
||||||
|
Status CacheMergeOp::Builder::Build(std::shared_ptr<CacheMergeOp> *ptr) {
|
||||||
|
RETURN_IF_NOT_OK(SanityCheck());
|
||||||
|
*ptr = std::make_shared<CacheMergeOp>(build_num_workers_, build_op_connector_size_, build_num_cleaners_,
|
||||||
|
build_cache_client_, build_sampler_);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-Visitor accept method for NodePass
|
||||||
|
Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call the pre-visitation
|
||||||
|
return p->PreRunOnNode(shared_from_base<CacheMergeOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<CacheMergeOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Eoe handling: the eoe is only sent up when this op is in a repeat path.
// \param worker_id Id of the worker that received the eoe
// \return Status
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
  const bool in_repeat_path = BitTest(op_ctrl_flags_, kDeOpRepeated);
  if (!in_repeat_path) {
    // Not repeated: ignore it.
    return Status::OK();
  }
  return DatasetOp::EoeReceived(worker_id);
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,196 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
|
||||||
|
#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <deque>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include "dataset/core/tensor_row.h"
|
||||||
|
#include "dataset/engine/cache/cache_client.h"
|
||||||
|
#include "dataset/engine/datasetops/parallel_op.h"
|
||||||
|
#include "dataset/engine/dataset_iterator.h"
|
||||||
|
#include "dataset/util/queue.h"
|
||||||
|
#include "dataset/util/semaphore.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single
|
||||||
|
/// stream
|
||||||
|
class CacheMergeOp : public ParallelOp {
|
||||||
|
public:
|
||||||
|
// Some handshake structures among the main thread, cleaner threads and parallel op threads.
|
||||||
|
class TensorRowRequest {
|
||||||
|
public:
|
||||||
|
enum class State : uint8_t {
|
||||||
|
kEmpty = 0, // No row in the deque
|
||||||
|
kDirty = 1, // Cleaner hasn't flushed it to the cache server yet.
|
||||||
|
kClean = 2 // The row has been flushed already.
|
||||||
|
};
|
||||||
|
explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {}
|
||||||
|
~TensorRowRequest() = default;
|
||||||
|
State GetState() const { return st_; }
|
||||||
|
void SetState(State newState) { st_ = newState; }
|
||||||
|
Status Wait(TensorRow *out);
|
||||||
|
void WakeUpAny(TensorRow &&row);
|
||||||
|
Status Release(TensorRow *out);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex dq_mux_;
|
||||||
|
std::atomic<State> st_;
|
||||||
|
Semaphore use_count_;
|
||||||
|
std::deque<TensorRow> row_;
|
||||||
|
TensorRow cleaner_copy_;
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
|
||||||
|
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
|
||||||
|
|
||||||
|
/// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
|
||||||
|
/// the arguments for constructing it. Use the builder by setting each argument
|
||||||
|
/// with the provided set methods, and then finally call the build method to execute
|
||||||
|
/// the actual construction.
|
||||||
|
class Builder {
|
||||||
|
public:
|
||||||
|
/// Builder constructor. Creates the builder object.
|
||||||
|
/// \note No default args
|
||||||
|
Builder();
|
||||||
|
|
||||||
|
/// Default destructor
|
||||||
|
~Builder() = default;
|
||||||
|
|
||||||
|
/// Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetNumWorkers(int32_t num_workers) {
|
||||||
|
build_num_workers_ = num_workers;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||||
|
build_op_connector_size_ = connector_size;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
||||||
|
build_cache_client_ = cache_client;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Setter method
|
||||||
|
/// \param sampler
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
|
||||||
|
build_sampler_ = std::move(sampler);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Setter method
|
||||||
|
/// \param num_cleaners
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetNumCleaner(int32_t num_cleaners) {
|
||||||
|
build_num_cleaners_ = num_cleaners;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The builder "build" method creates the final object and does some init on it.
|
||||||
|
/// \param ptr The shared_ptr to the new CacheMergeOp object
|
||||||
|
/// \return Status
|
||||||
|
Status Build(std::shared_ptr<CacheMergeOp> *ptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t build_num_workers_;
|
||||||
|
int32_t build_op_connector_size_;
|
||||||
|
int32_t build_num_cleaners_;
|
||||||
|
std::shared_ptr<CacheClient> build_cache_client_;
|
||||||
|
std::shared_ptr<Sampler> build_sampler_;
|
||||||
|
|
||||||
|
/// Check if the required parameters are set by the builder.
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status SanityCheck() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp
|
||||||
|
/// \param opConnector Size Connector size as a derived class of ParallelOp
|
||||||
|
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server
|
||||||
|
/// \param cache_client CacheClient to commmunicate with the Cache server
|
||||||
|
/// \param sampler as a derived class of ParallelOp
|
||||||
|
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
|
||||||
|
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
|
||||||
|
~CacheMergeOp();
|
||||||
|
void Print(std::ostream &out, bool show_all) const override;
|
||||||
|
friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
|
||||||
|
mo.Print(out, false);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
/// \brief Master thread responsible to spawn all the necessary worker threads for the two streams and
|
||||||
|
/// the threads for the cleaners.
|
||||||
|
/// \return
|
||||||
|
Status operator()() override;
|
||||||
|
/// \brief Entry function for worker thread that fetch rows from CacheLookupOp
|
||||||
|
/// \param workerId
|
||||||
|
/// \return Status object
|
||||||
|
Status WorkerEntry(int32_t workerId) override;
|
||||||
|
Status PrepareNodePostAction() override;
|
||||||
|
/// \brief Entry function for worker thread that fetch rows from the cache miss stream
|
||||||
|
/// \param workerId
|
||||||
|
/// \return Status object
|
||||||
|
Status CacheMissWorkerEntry(int32_t workerId);
|
||||||
|
Status GetRq(row_id_type row_id, TensorRowRequest **);
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status PreAccept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for eoe handling
|
||||||
|
/// \param worker_id
|
||||||
|
/// \return Status object
|
||||||
|
Status EoeReceived(int32_t worker_id) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Status ComputeColMap() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex mux_;
|
||||||
|
std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_;
|
||||||
|
std::unique_ptr<Queue<row_id_type>> io_que_;
|
||||||
|
std::shared_ptr<CacheClient> cache_client_;
|
||||||
|
int32_t num_cleaners_;
|
||||||
|
|
||||||
|
/// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for
|
||||||
|
/// moving cache miss TensorRow into the CacheServer.
|
||||||
|
/// \return Status object
|
||||||
|
Status Cleaner();
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
|
|
@ -0,0 +1,219 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "dataset/engine/datasetops/cache_op.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "dataset/core/config_manager.h"
|
||||||
|
#include "dataset/core/constants.h"
|
||||||
|
#include "dataset/core/global_context.h"
|
||||||
|
#include "dataset/engine/datasetops/repeat_op.h"
|
||||||
|
#include "dataset/engine/data_buffer.h"
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
#include "dataset/util/task_manager.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
// Builder constructor. Creates the builder object.
|
||||||
|
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
||||||
|
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||||
|
build_num_workers_ = cfg->num_parallel_workers();
|
||||||
|
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||||
|
build_op_connector_size_ = cfg->op_connector_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the required parameters are set by the builder.
|
||||||
|
Status CacheOp::Builder::SanityCheck() const {
|
||||||
|
if (build_cache_client_ == nullptr) {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
|
||||||
|
}
|
||||||
|
// Make sure the cache client has a valid session
|
||||||
|
if (!build_cache_client_->session_id()) {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The builder "build" method creates the final object and does some init on it
|
||||||
|
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
|
||||||
|
RETURN_IF_NOT_OK(SanityCheck());
|
||||||
|
*ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_,
|
||||||
|
build_sampler_);
|
||||||
|
RETURN_IF_NOT_OK((*ptr)->InitCache());
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constructor of CacheOp
|
||||||
|
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
|
||||||
|
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
|
||||||
|
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
|
||||||
|
num_guys_in_(0),
|
||||||
|
phase_(Phase::kBuildPhase) {}
|
||||||
|
|
||||||
|
// Destructor
|
||||||
|
CacheOp::~CacheOp() = default;
|
||||||
|
|
||||||
|
// Private function for cache setup/init work just after construction
|
||||||
|
Status CacheOp::InitCache() { return Status::OK(); }
|
||||||
|
|
||||||
|
// This class functor will provide the master loop that drives the logic for performing the work
|
||||||
|
Status CacheOp::operator()() {
|
||||||
|
if (!sampler_) {
|
||||||
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||||
|
"CacheOp requires a sampler before it can be executed!");
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(RegisterResources());
|
||||||
|
// Kick off the workers
|
||||||
|
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
|
||||||
|
// required task group sync after launching workers
|
||||||
|
TaskManager::FindMe()->Post();
|
||||||
|
// Wait for the workers to finish caching the rows.
|
||||||
|
RETURN_IF_NOT_OK(WaitForCachingAllRows());
|
||||||
|
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status CacheOp::CacheAllRows(int32_t worker_id) {
|
||||||
|
// If the current phase is to fill the cache, do it then.
|
||||||
|
if (phase_ == Phase::kBuildPhase) {
|
||||||
|
// We will take the chance to cache the schema at the server.
|
||||||
|
// Just do it once and pick one worker to do it.
|
||||||
|
if (worker_id == 0) {
|
||||||
|
RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id;
|
||||||
|
// SAVE mode loop
|
||||||
|
std::unique_ptr<DataBuffer> db_ptr;
|
||||||
|
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
|
||||||
|
while (!db_ptr->eof()) {
|
||||||
|
if (!db_ptr->eoe()) {
|
||||||
|
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
|
||||||
|
} else {
|
||||||
|
// In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
|
||||||
|
// as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the
|
||||||
|
// eoe to indicate the end of the epoch, we should next expect to get the eof.
|
||||||
|
// Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch
|
||||||
|
// from again.
|
||||||
|
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
|
||||||
|
if (!db_ptr->eof()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Let the main guy know we are done.
|
||||||
|
auto last_guy_in = num_guys_in_.fetch_add(1);
|
||||||
|
if ((last_guy_in + 1) == num_workers_) {
|
||||||
|
rows_cache_done_.Set();
|
||||||
|
} else {
|
||||||
|
// Let's do a sync up here.
|
||||||
|
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status CacheOp::WaitForCachingAllRows() {
|
||||||
|
// Wait for the workers to finish caching the rows.
|
||||||
|
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
|
||||||
|
// Move from build phase to fetch phase if we are the one to fill the cache
|
||||||
|
if (phase_ == Phase::kBuildPhase) {
|
||||||
|
RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
|
||||||
|
// Move to the next phase
|
||||||
|
phase_ = Phase::kFetchPhase;
|
||||||
|
}
|
||||||
|
// Get statistics from the server, and if we are not the one to create the cache,
|
||||||
|
// wait until the state has changed from build phase to fetch phase.
|
||||||
|
CacheClient::ServiceStat stat{};
|
||||||
|
bool BuildPhaseDone = true;
|
||||||
|
do {
|
||||||
|
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
|
||||||
|
BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase);
|
||||||
|
if (!BuildPhaseDone) {
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
}
|
||||||
|
} while (!BuildPhaseDone);
|
||||||
|
const row_id_type min_key = stat.min_row_id;
|
||||||
|
const row_id_type max_key = stat.max_row_id;
|
||||||
|
num_rows_ = max_key - min_key + 1;
|
||||||
|
MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
|
||||||
|
MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
|
||||||
|
MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
|
||||||
|
// Now all rows are cached and we have done a sync point check up. Next phase is
|
||||||
|
// to pick up fetch input from the sampler and pass it up to the caller.
|
||||||
|
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status CacheOp::WorkerEntry(int32_t worker_id) {
|
||||||
|
TaskManager::FindMe()->Post();
|
||||||
|
RETURN_IF_NOT_OK(CacheAllRows(worker_id));
|
||||||
|
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status CacheOp::RegisterResources() {
|
||||||
|
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
|
||||||
|
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
|
||||||
|
RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base-class override for setting specific CacheOp configurations. This code will be called
|
||||||
|
// during the execution tree prepare phase BEFORE traversing down to child operators.
|
||||||
|
uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; }
|
||||||
|
// Base-class override for special eoe handler.
|
||||||
|
// CacheOp must override this because it shall not perform default handling of eoe. Instead
|
||||||
|
// the CacheOp manages actions related to the end of the epoch.
|
||||||
|
Status CacheOp::EoeReceived(int32_t worker_id) {
|
||||||
|
state_ = OpState::kDeOpIdle;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// Base-class override for handling cases when an eof is received.
|
||||||
|
Status CacheOp::EofReceived(int32_t worker_id) {
|
||||||
|
// EofReceived is overridden because we want to manually handle this eof.
|
||||||
|
// Specifically, the default behaviour is to pack it and flow it up to the next connection.
|
||||||
|
// In this case, we want a no-op behaviour so that we can perform correct action.
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-Visitor accept method for NodePass
|
||||||
|
Status CacheOp::PreAccept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call the pre-visitation
|
||||||
|
return p->PreRunOnNode(shared_from_base<CacheOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status CacheOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<CacheOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
|
// A public wrapper for creating the cache through the client
|
||||||
|
Status CacheOp::CreateCache(uint32_t cache_crc) {
|
||||||
|
// This is a non-mappable cache op so the id's need to be generated.
|
||||||
|
// Construct the cache
|
||||||
|
const bool generate_ids = true;
|
||||||
|
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
|
||||||
|
if (rc.get_code() == StatusCode::kDuplicateKey) {
|
||||||
|
// We are told the cache has been created already. So we skip the build phase.
|
||||||
|
phase_ = Phase::kFetchPhase;
|
||||||
|
rc = Status::OK();
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(rc);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,168 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
|
||||||
|
#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <memory>
|
||||||
|
#include "dataset/engine/datasetops/cache_base_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
/// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset.
|
||||||
|
/// \note For mappable dataset, please see CacheLookupOp.
|
||||||
|
/// \see CacheLookupOp
|
||||||
|
class CacheOp : public CacheBase, public RandomAccessOp {
|
||||||
|
public:
|
||||||
|
// This CacheOp is for non-mappable case where it is divided into two phases.
|
||||||
|
// The first phase is we cache all the rows from the child (and let the cache server
|
||||||
|
// assign row ids). No read access in the first phase. Once the cache is fully built,
|
||||||
|
// we switch to second phase and fetch requests from the sampler.
|
||||||
|
enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
|
||||||
|
|
||||||
|
/// \brief The nested builder class inside of the CacheOp is used to help manage all of
|
||||||
|
/// the arguments for constructing it. Use the builder by setting each argument
|
||||||
|
/// with the provided set methods, and then finally call the build method to execute
|
||||||
|
/// the actual construction.
|
||||||
|
class Builder {
|
||||||
|
public:
|
||||||
|
// Builder constructor. Creates the builder object.
|
||||||
|
// @note No default args
|
||||||
|
// @return This is a constructor.
|
||||||
|
Builder();
|
||||||
|
|
||||||
|
// Default destructor
|
||||||
|
~Builder() = default;
|
||||||
|
|
||||||
|
/// \brief Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetNumWorkers(int32_t num_workers) {
|
||||||
|
build_num_workers_ = num_workers;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||||
|
build_op_connector_size_ = connector_size;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Setter method.
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
||||||
|
build_cache_client_ = cache_client;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Setter method
|
||||||
|
/// \param rows_per_buffer
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
|
||||||
|
rows_per_buffer_ = rows_per_buffer;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Setter method
|
||||||
|
/// \param sampler
|
||||||
|
/// \return Builder setter method returns reference to the builder.
|
||||||
|
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
|
||||||
|
build_sampler_ = std::move(sampler);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief The builder "build" method creates the final object and does some init on it.
|
||||||
|
/// \param ptr The shared_ptr to the new CacheOp object
|
||||||
|
/// \return Status
|
||||||
|
Status Build(std::shared_ptr<CacheOp> *ptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t build_num_workers_;
|
||||||
|
int32_t rows_per_buffer_;
|
||||||
|
int32_t build_op_connector_size_;
|
||||||
|
std::shared_ptr<CacheClient> build_cache_client_;
|
||||||
|
std::shared_ptr<Sampler> build_sampler_;
|
||||||
|
|
||||||
|
/// \brief Check if the required parameters are set by the builder.
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status SanityCheck() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// \brief Constructor of CacheOp
|
||||||
|
/// \note The builder class should be used to call it.
|
||||||
|
/// \param num_workers The number of worker threads.
|
||||||
|
/// \param op_connector_size The size of each queue in the connector.
|
||||||
|
CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
|
||||||
|
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
|
||||||
|
|
||||||
|
// Destructor
|
||||||
|
~CacheOp();
|
||||||
|
|
||||||
|
/// \brief Base-class override for setting specific CacheOp configurations. This code will be called
|
||||||
|
/// during the execution tree prepare phase BEFORE traversing down to child operators.
|
||||||
|
uint32_t PrepareFlags() const override;
|
||||||
|
/// \brief Base-class override for special eoe handler.
|
||||||
|
/// CacheOp must override this because it shall not perform default handling of eoe. Instead
|
||||||
|
/// the CacheOp manages actions related to the end of the epoch.
|
||||||
|
/// \return Status - The error code return
|
||||||
|
Status EoeReceived(int32_t worker_id) override;
|
||||||
|
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status PreAccept(NodePass *p, bool *modified) override;
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
/// \brief Base-class override for handling cases when an eof is received.
|
||||||
|
/// \param worker_id - The worker id
|
||||||
|
/// \return Status - The error code return
|
||||||
|
Status EofReceived(int32_t worker_id) override;
|
||||||
|
Status operator()() override;
|
||||||
|
Status WorkerEntry(int32_t worker_id) override;
|
||||||
|
/// \brief Base-class override for handling cases if we allow cache miss
|
||||||
|
bool AllowCacheMiss() override { return false; }
|
||||||
|
/// \brief Base-class override for the name of this operator
|
||||||
|
std::string Name() const override { return "CacheOp"; }
|
||||||
|
/// \brief A public wrapper for creating the cache through the client
|
||||||
|
/// \param[in] cache_crc The crc that identifies the cache
|
||||||
|
/// \see cache_pass.cc
|
||||||
|
/// \return Status return code
|
||||||
|
Status CreateCache(uint32_t cache_crc);
|
||||||
|
|
||||||
|
private:
|
||||||
|
WaitPost rows_cache_done_;
|
||||||
|
std::atomic<int64_t> num_guys_in_;
|
||||||
|
Phase phase_;
|
||||||
|
/// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler.
|
||||||
|
/// \return Status object
|
||||||
|
Status WaitForCachingAllRows();
|
||||||
|
/// \brief For non-mappable dataset, there is a build phase where we cache all the rows.
|
||||||
|
/// \return Status object
|
||||||
|
Status CacheAllRows(int32_t worker_id);
|
||||||
|
Status RegisterResources() override;
|
||||||
|
/// \brief Private function for cache setup/init work just after construction
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status InitCache();
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
|
|
@ -61,46 +61,39 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const {
|
||||||
Status ConcatOp::operator()() {
|
Status ConcatOp::operator()() {
|
||||||
// The children_num_ parameter needs to be put here
|
// The children_num_ parameter needs to be put here
|
||||||
children_num_ = static_cast<int32_t>(child_.size());
|
children_num_ = static_cast<int32_t>(child_.size());
|
||||||
|
|
||||||
TaskManager::FindMe()->Post();
|
TaskManager::FindMe()->Post();
|
||||||
std::unique_ptr<DataBuffer> buf;
|
std::unique_ptr<DataBuffer> buf;
|
||||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
|
|
||||||
|
|
||||||
int eof_count = 0;
|
int eof_count = 0;
|
||||||
while (eof_count != children_num_) {
|
while (eof_count == 0) {
|
||||||
for (int i = 0; i < children_num_; i++) {
|
for (int i = 0; i < children_num_; i++) {
|
||||||
// 1. Throw the eof buffer when meet it
|
// 1. Read the first buffer
|
||||||
if (buf->eof() || buf->eoe()) {
|
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
|
||||||
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
|
if (buf->eof()) {
|
||||||
|
eof_count++;
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
// 2. Do verification as for column name, column data type and rank of column data
|
// 2. Do verification as for column name, column data type and rank of column data
|
||||||
RETURN_IF_NOT_OK(Verify(i, buf));
|
if (!buf->eoe()) {
|
||||||
|
RETURN_IF_NOT_OK(Verify(i, buf));
|
||||||
|
}
|
||||||
// 3. Put the data into output_connector
|
// 3. Put the data into output_connector
|
||||||
while (!buf->eoe() && !buf->eof()) {
|
while (!buf->eoe() && !buf->eof()) {
|
||||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
|
||||||
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
|
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// 4. Throw the eoe buffer when meet it
|
// 4. Add eoe buffer after getting buffers from all children
|
||||||
if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) {
|
if (eof_count == 0) {
|
||||||
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
|
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||||
}
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||||
// 5. Add eoe buffer after get buffer from all child
|
|
||||||
if (i == (children_num_ - 1)) {
|
|
||||||
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
|
||||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
|
||||||
}
|
|
||||||
if (buf->eof()) {
|
|
||||||
eof_count++;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 6. Add eof buffer in the end manually
|
CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
|
||||||
|
"Something went wrong, eof count does not match the number of children.");
|
||||||
|
// 5. Add eof buffer in the end manually
|
||||||
MS_LOG(DEBUG) << "Add the eof buffer manualy in the end.";
|
MS_LOG(DEBUG) << "Add the eof buffer manualy in the end.";
|
||||||
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,12 +119,6 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConcatOp::PrepareNodePostAction() {
|
|
||||||
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
|
|
||||||
tree_->AddToEOEOpStack(shared_from_this());
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
|
// We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
|
||||||
Status ConcatOp::ComputeColMap() {
|
Status ConcatOp::ComputeColMap() {
|
||||||
if (column_name_id_map_.empty()) {
|
if (column_name_id_map_.empty()) {
|
||||||
|
|
|
@ -75,12 +75,6 @@ class ConcatOp : public PipelineOp {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status operator()() override;
|
Status operator()() override;
|
||||||
|
|
||||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
|
||||||
// their role.
|
|
||||||
// @notes Derived versions of this function should always call it's superclass version first
|
|
||||||
// before providing their own implementations.
|
|
||||||
Status PrepareNodePostAction() override;
|
|
||||||
|
|
||||||
// Op name getter
|
// Op name getter
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "ConcatOp"; }
|
std::string Name() const override { return "ConcatOp"; }
|
||||||
|
|
|
@ -153,16 +153,38 @@ Status DatasetOp::Remove() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Finally, clear "this" op's parent and child pointers since we have just
|
||||||
|
// disconnected it from the tree and invalidate its fields.
|
||||||
|
child_.clear();
|
||||||
|
parent_.clear();
|
||||||
|
operator_id_ = kInvalidOperatorId;
|
||||||
|
tree_ = nullptr;
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Getter function to get a shared pointer to our childAdds a operator to become our child.
|
// Getter function to get a shared pointer to our child
|
||||||
std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
|
std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
|
||||||
|
std::shared_ptr<DatasetOp> return_op = nullptr;
|
||||||
|
if (child_.empty()) {
|
||||||
|
return return_op;
|
||||||
|
}
|
||||||
MS_ASSERT(child_index < static_cast<int>(child_.size()));
|
MS_ASSERT(child_index < static_cast<int>(child_.size()));
|
||||||
// Return a shared pointer
|
// Return a shared pointer
|
||||||
return child_[child_index];
|
return child_[child_index];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Getter function to get the parent pointer
|
||||||
|
void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
|
||||||
|
if (parent_.empty()) {
|
||||||
|
// common case if this is a root node
|
||||||
|
*parent = nullptr;
|
||||||
|
} else {
|
||||||
|
MS_ASSERT(parent_index < static_cast<int>(parent_.size()));
|
||||||
|
*parent = parent_[parent_index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Creates the connector within this operator
|
// Creates the connector within this operator
|
||||||
void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
|
void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
|
||||||
MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
|
MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
|
||||||
|
@ -264,19 +286,11 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
|
||||||
|
|
||||||
// During tree prepare phase, operators may have specific pre-operations to perform depending on
|
// During tree prepare phase, operators may have specific pre-operations to perform depending on
|
||||||
// their role.
|
// their role.
|
||||||
Status DatasetOp::PrepareNodePreAction() {
|
Status DatasetOp::PrepareNodePreAction() { return Status::OK(); }
|
||||||
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||||
// their role.
|
// their role.
|
||||||
Status DatasetOp::PrepareNodePostAction() {
|
Status DatasetOp::PrepareNodePostAction() {
|
||||||
// If this op does not have any children and it is in a repeat path of the tree...
|
|
||||||
if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
|
|
||||||
// push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator
|
|
||||||
// above us will consume them.
|
|
||||||
tree_->AddToEOEOpStack(shared_from_this());
|
|
||||||
}
|
|
||||||
// Creating Connector object for each op.
|
// Creating Connector object for each op.
|
||||||
// The consumer of the root node is assumed to be one thread.
|
// The consumer of the root node is assumed to be one thread.
|
||||||
// If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
|
// If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
|
||||||
|
@ -346,34 +360,13 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
|
||||||
return p->RunOnNode(shared_from_this(), modified);
|
return p->RunOnNode(shared_from_this(), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
// A helper function with some common code that leaf nodes can use during
|
// Getter for the sampler, and it also removes the sampler from the op
|
||||||
// prepare phase for checking if they need to assign a sampler to the cache.
|
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
|
||||||
Status DatasetOp::SaveSamplerForCache(bool random_access_op) {
|
*sampler = sampler_; // It's okay if sampler_ points to nullptr
|
||||||
// If we are a descendant under a cache op and we have a sampler, then save this sampler
|
sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler
|
||||||
// to a stack so that the cache can pick it up during it's processing above us.
|
|
||||||
if (sampler_) {
|
|
||||||
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
|
|
||||||
// use move semantic to set our sampler_ to null after the move. This is okay because a sampler is
|
|
||||||
// useless to a random data op. It was only being used as a temporary holding until the cache can
|
|
||||||
// be created
|
|
||||||
tree_->AddToSamplerStack(sampler_);
|
|
||||||
MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
|
|
||||||
} else if (!random_access_op) {
|
|
||||||
// A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
|
|
||||||
// This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
|
|
||||||
RETURN_STATUS_UNEXPECTED(
|
|
||||||
"Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!random_access_op) {
|
|
||||||
// Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache
|
|
||||||
// we can remove it now from the base.
|
|
||||||
sampler_.reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
|
uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
op->tree_->Print(ss, op);
|
op->tree_->Print(ss, op);
|
||||||
|
|
|
@ -45,10 +45,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
public:
|
public:
|
||||||
static constexpr int32_t kInvalidOperatorId = -1;
|
static constexpr int32_t kInvalidOperatorId = -1;
|
||||||
|
|
||||||
// Flags that control operator runtime behaviours
|
// Operator control flags
|
||||||
enum OpControlFlags {
|
enum OpControlFlags {
|
||||||
kDeOpNone = 0,
|
kDeOpNone = 0,
|
||||||
kDeOpRepeated = 1, // Operator is a leaf node in a repeat path
|
kDeOpRepeated = 1, // Operator is a node in a repeat path
|
||||||
kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop
|
kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -71,17 +71,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
/// \param child - shared pointer to the child to remove.
|
/// \param child - shared pointer to the child to remove.
|
||||||
Status RemoveChild(std::shared_ptr<DatasetOp> child);
|
Status RemoveChild(std::shared_ptr<DatasetOp> child);
|
||||||
|
|
||||||
/// \brief Removes this node from the tree and connects it's parent/child together.
|
/// \brief Removes this node from the tree and connects its parent/child together
|
||||||
/// \return Status error code returned
|
/// \return Status error code returned
|
||||||
Status Remove();
|
Status Remove();
|
||||||
|
|
||||||
/// \brief Getter function to get a shared pointer to our child
|
/// \brief Getter function to get a shared pointer to our child
|
||||||
/// \param child_index - An operator can have n children. Indicates choose which child to return.
|
/// \param[in] child_index An operator can have n children. Indicates which child to return.
|
||||||
|
/// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
|
||||||
std::shared_ptr<DatasetOp> child(int32_t child_index) const;
|
std::shared_ptr<DatasetOp> child(int32_t child_index) const;
|
||||||
|
|
||||||
/// \brief Inserts a operator as the parent current op.
|
/// \brief Getter function to get the pointer to our parent
|
||||||
/// Inserted op will become the sole parent of the current op.
|
/// If there are no parents, it returns null regardless of the given index
|
||||||
/// The existing parent of the current op will be transferred to the inserted op.
|
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
|
||||||
|
void Parent(DatasetOp **parent, int32_t parent_index) const;
|
||||||
|
|
||||||
|
// Inserts an operator as the parent of the current op.
|
||||||
|
// Inserted op will become the sole parent of the current op.
|
||||||
|
// The existing parent of the current op will be transferred to the inserted op.
|
||||||
Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
|
Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
|
||||||
|
|
||||||
/// \brief Creates the connector within this operator
|
/// \brief Creates the connector within this operator
|
||||||
|
@ -161,16 +167,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
/// \return Status - The error code return
|
/// \return Status - The error code return
|
||||||
virtual Status Reset();
|
virtual Status Reset();
|
||||||
|
|
||||||
/// \brief This calls the reset function on this subtree in pre-order
|
|
||||||
/// \return Status - The error code return
|
|
||||||
virtual Status ResetSubtree() {
|
|
||||||
RETURN_IF_NOT_OK(Reset());
|
|
||||||
for (const auto &c : child_) {
|
|
||||||
RETURN_IF_NOT_OK(c->ResetSubtree());
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
|
/// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
|
||||||
/// their role.
|
/// their role.
|
||||||
/// \notes Derived versions of this function should always call its superclass version first
|
/// \notes Derived versions of this function should always call its superclass version first
|
||||||
|
@ -296,7 +292,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
/// \return Shared pointer to the sampler (may return nullptr)
|
/// \return Shared pointer to the sampler (may return nullptr)
|
||||||
std::shared_ptr<Sampler> sampler() { return sampler_; }
|
std::shared_ptr<Sampler> sampler() { return sampler_; }
|
||||||
|
|
||||||
/// Computes a CRC value for the operator
|
/// \brief Getter for the sampler, and it also removes the sampler from the op
|
||||||
|
/// \param[out] sampler A pointer to the output sampler that was removed
|
||||||
|
/// \return Status error code
|
||||||
|
Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
|
||||||
|
|
||||||
|
// Computes a CRC value for the operator
|
||||||
static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
|
static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
|
||||||
|
|
||||||
/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
|
/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
|
||||||
|
@ -307,17 +308,24 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
return std::static_pointer_cast<Derived>(shared_from_this());
|
return std::static_pointer_cast<Derived>(shared_from_this());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
/// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
|
||||||
/// Adds a parent operator to this operator
|
void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
|
||||||
/// \notes External callers do not have access to this function.
|
|
||||||
/// \param parent - The parent node to add
|
|
||||||
void AddParent(DatasetOp *parent);
|
|
||||||
|
|
||||||
/// Removes a parent operator from this operator
|
/// \brief Checks if this is a leaf node (0 children)
|
||||||
/// \notes External callers do not have access to this function.
|
/// \return boolean returns true if it's a leaf
|
||||||
/// \param parent - The parent node to remove
|
bool IsLeaf() { return (child_.empty()); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// \brief Removes a parent operator from this operator
|
||||||
|
/// \notes External callers do not have access to this function
|
||||||
|
/// \param[in] parent The parent node to remove
|
||||||
void RemoveParent(const DatasetOp *parent);
|
void RemoveParent(const DatasetOp *parent);
|
||||||
|
|
||||||
|
/// \brief Adds a parent operator to this operator
|
||||||
|
/// \notes External callers do not have access to this function
|
||||||
|
/// \param[in] parent The parent node to add
|
||||||
|
void AddParent(DatasetOp *parent);
|
||||||
|
|
||||||
/// Compute the current op's column map using its child's column map.
|
/// Compute the current op's column map using its child's column map.
|
||||||
/// Get called during the tree post-prepare phase in PrepareNodePostAction.
|
/// Get called during the tree post-prepare phase in PrepareNodePostAction.
|
||||||
/// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
|
/// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
|
||||||
|
@ -325,12 +333,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
||||||
/// \return - Status
|
/// \return - Status
|
||||||
virtual Status ComputeColMap();
|
virtual Status ComputeColMap();
|
||||||
|
|
||||||
/// A helper function with some common code that leaf nodes can use during
|
|
||||||
/// pre/pare phase for checking if they need to assign a sampler to the cache.
|
|
||||||
/// \param random_access_op - indicate if this is a mappable random access leaf or not
|
|
||||||
/// \return - Status
|
|
||||||
Status SaveSamplerForCache(bool random_access_op);
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
|
std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
|
||||||
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
|
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
|
||||||
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
|
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
|
||||||
|
|
|
@ -77,26 +77,6 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Base-class override for executing specific RepeatOp configurations. This code will be called
|
|
||||||
// during the execution tree prepare phase when it is visiting this operator.
|
|
||||||
Status RepeatOp::PrepareNodePostAction() {
|
|
||||||
// Run any common code from super class first before adding our own specific logic
|
|
||||||
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
|
|
||||||
std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack();
|
|
||||||
while (leaf_op != nullptr) {
|
|
||||||
// Track the leaf operators that are under this repeat op.
|
|
||||||
eoe_ops_.push_back(leaf_op);
|
|
||||||
leaf_op = tree_->PopFromEOEOpStack();
|
|
||||||
}
|
|
||||||
// Push ourselves to the stack in case one of our ascendants is repeat too.
|
|
||||||
tree_->AddToEOEOpStack(shared_from_this());
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Base-class override for setting specific RepeatOp configurations. This code will be called
|
|
||||||
// during the execution tree prepare phase BEFORE traversing down to child operators.
|
|
||||||
uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; }
|
|
||||||
|
|
||||||
// This function returns the buffer that is at the top of our output connector. The caller is
|
// This function returns the buffer that is at the top of our output connector. The caller is
|
||||||
// typically our parent node, when the parent is asking us to provide the next buffer of data.
|
// typically our parent node, when the parent is asking us to provide the next buffer of data.
|
||||||
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
|
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
|
||||||
|
@ -130,7 +110,8 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
|
||||||
// Base-class override for handling cases when an eoe is received.
|
// Base-class override for handling cases when an eoe is received.
|
||||||
Status RepeatOp::EoeReceived(int32_t worker_id) {
|
Status RepeatOp::EoeReceived(int32_t worker_id) {
|
||||||
repeat_count_++;
|
repeat_count_++;
|
||||||
MS_LOG(DEBUG) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
|
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_
|
||||||
|
<< ") end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
|
||||||
bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
|
bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
|
||||||
bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
|
bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
|
||||||
// If we've reached the requested repeat count, then flag the eoe nodes
|
// If we've reached the requested repeat count, then flag the eoe nodes
|
||||||
|
@ -149,8 +130,12 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// base-class ResetSubtree
|
// Invoke a reset against the eoe nodes only.
|
||||||
return (DatasetOp::ResetSubtree());
|
for (auto &eoe_op : eoe_ops_) {
|
||||||
|
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Class functor operator () override.
|
// Class functor operator () override.
|
||||||
|
@ -178,6 +163,18 @@ int32_t RepeatOp::num_consumers() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Drive reset actions if needed
|
||||||
|
Status RepeatOp::Reset() {
|
||||||
|
// If there are nested repeats, an ascendant repeat may have us listed as an eoe op.
|
||||||
|
// In that case, we now have to bounce the reset down to our own eoe ops.
|
||||||
|
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
|
||||||
|
for (auto &eoe_op : eoe_ops_) {
|
||||||
|
RETURN_IF_NOT_OK(eoe_op->Reset());
|
||||||
|
}
|
||||||
|
state_ = OpState::kDeOpRunning;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
int32_t RepeatOp::num_producers() const {
|
int32_t RepeatOp::num_producers() const {
|
||||||
if (child_.empty() || child_[0] == nullptr) {
|
if (child_.empty() || child_[0] == nullptr) {
|
||||||
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
|
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
|
||||||
|
@ -187,6 +184,12 @@ int32_t RepeatOp::num_producers() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre-Visitor accept method for NodePass
|
||||||
|
Status RepeatOp::PreAccept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call the pre-visitation
|
||||||
|
return p->PreRunOnNode(shared_from_base<RepeatOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
// Visitor accept method for NodePass
|
// Visitor accept method for NodePass
|
||||||
Status RepeatOp::Accept(NodePass *p, bool *modified) {
|
Status RepeatOp::Accept(NodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "dataset/engine/datasetops/pipeline_op.h"
|
#include "dataset/engine/datasetops/pipeline_op.h"
|
||||||
|
|
||||||
|
@ -82,14 +83,6 @@ class RepeatOp : public PipelineOp {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status operator()() override;
|
Status operator()() override;
|
||||||
|
|
||||||
// Base-class override for setting specific RepeatOp configurations. This code will be called
|
|
||||||
// during the execution tree prepare phase BEFORE traversing down to child operators.
|
|
||||||
uint32_t PrepareFlags() const override;
|
|
||||||
|
|
||||||
// Base-class override for executing specific RepeatOp configurations. This code will be called
|
|
||||||
// during the execution tree post-prepare phase when it is visiting this operator.
|
|
||||||
Status PrepareNodePostAction() override;
|
|
||||||
|
|
||||||
// This function returns the buffer that is at the top of our output connector. The caller is
|
// This function returns the buffer that is at the top of our output connector. The caller is
|
||||||
// typically our parent node, when the parent is asking us to provide the next buffer of data.
|
// typically our parent node, when the parent is asking us to provide the next buffer of data.
|
||||||
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
|
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
|
||||||
|
@ -110,6 +103,10 @@ class RepeatOp : public PipelineOp {
|
||||||
// @param worker_id - The worker id
|
// @param worker_id - The worker id
|
||||||
Status EofReceived(int32_t worker_id) override;
|
Status EofReceived(int32_t worker_id) override;
|
||||||
|
|
||||||
|
/// \brief reset Op
|
||||||
|
/// \return Status - The error code return
|
||||||
|
Status Reset() override;
|
||||||
|
|
||||||
// Base-class override. Return the number of workers in the first parent.
|
// Base-class override. Return the number of workers in the first parent.
|
||||||
// @param workerId - The worker id
|
// @param workerId - The worker id
|
||||||
int32_t num_consumers() const override;
|
int32_t num_consumers() const override;
|
||||||
|
@ -118,16 +115,26 @@ class RepeatOp : public PipelineOp {
|
||||||
// @param workerId - The worker id
|
// @param workerId - The worker id
|
||||||
int32_t num_producers() const override;
|
int32_t num_producers() const override;
|
||||||
|
|
||||||
// Base-class override for NodePass visitor acceptor.
|
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||||
// @param p - Pointer to the NodePass to be accepted.
|
/// \param[in] p The node to visit
|
||||||
// @param modified - Whether this node visit modified the pipeline.
|
/// \param[out] modified Indicator if the node was modified
|
||||||
// @return - Status of the node visit.
|
/// \return Status of the node visit
|
||||||
|
Status PreAccept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p The node to visit
|
||||||
|
/// \param[out] modified Indicator if the node was modified
|
||||||
|
/// \return Status of the node visit
|
||||||
Status Accept(NodePass *p, bool *modified) override;
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
// Op name getter
|
// Op name getter
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "RepeatOp"; }
|
std::string Name() const override { return "RepeatOp"; }
|
||||||
|
|
||||||
|
/// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
|
||||||
|
/// \param[in] eoe_op The input leaf/eoe operator to add to the list
|
||||||
|
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32_t max_repeats_; // The number of repeats that the user requested
|
int32_t max_repeats_; // The number of repeats that the user requested
|
||||||
int32_t repeat_count_; // A counter for the current number of executed repeats
|
int32_t repeat_count_; // A counter for the current number of executed repeats
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
#include "dataset/engine/data_schema.h"
|
#include "dataset/engine/data_schema.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "dataset/kernels/image/image_utils.h"
|
#include "dataset/kernels/image/image_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -408,6 +409,12 @@ Status CelebAOp::Reset() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status CelebAOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<CelebAOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
Status CelebAOp::ComputeColMap() {
|
Status CelebAOp::ComputeColMap() {
|
||||||
// Set the column name map (base class field)
|
// Set the column name map (base class field)
|
||||||
if (column_name_id_map_.empty()) {
|
if (column_name_id_map_.empty()) {
|
||||||
|
|
|
@ -169,6 +169,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);
|
Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p Pointer to the NodePass to be accepted
|
||||||
|
/// \param[out] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
// Op name getter
|
// Op name getter
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const { return "CelebAOp"; }
|
std::string Name() const { return "CelebAOp"; }
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -450,6 +451,12 @@ Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status CifarOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<CifarOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
Status CifarOp::ComputeColMap() {
|
Status CifarOp::ComputeColMap() {
|
||||||
// set the column name map (base class field)
|
// set the column name map (base class field)
|
||||||
if (column_name_id_map_.empty()) {
|
if (column_name_id_map_.empty()) {
|
||||||
|
|
|
@ -155,6 +155,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
||||||
// @return
|
// @return
|
||||||
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
|
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p Pointer to the NodePass to be accepted
|
||||||
|
/// \param[out] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
// Op name getter
|
// Op name getter
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "CifarOp"; }
|
std::string Name() const override { return "CifarOp"; }
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -624,6 +625,12 @@ Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file,
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status CocoOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<CocoOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
Status CocoOp::ComputeColMap() {
|
Status CocoOp::ComputeColMap() {
|
||||||
// Set the column name map (base class field)
|
// Set the column name map (base class field)
|
||||||
if (column_name_id_map_.empty()) {
|
if (column_name_id_map_.empty()) {
|
||||||
|
|
|
@ -200,6 +200,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
||||||
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||||
std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
|
std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p Pointer to the NodePass to be accepted
|
||||||
|
/// \param[out] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Initialize Sampler, calls sampler->Init() within
|
// Initialize Sampler, calls sampler->Init() within
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -416,6 +417,12 @@ Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dic
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status ManifestOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<ManifestOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
Status ManifestOp::ComputeColMap() {
|
Status ManifestOp::ComputeColMap() {
|
||||||
// Set the column name map (base class field)
|
// Set the column name map (base class field)
|
||||||
if (column_name_id_map_.empty()) {
|
if (column_name_id_map_.empty()) {
|
||||||
|
|
|
@ -172,6 +172,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
||||||
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
|
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
|
||||||
std::map<std::string, int32_t> *output_class_indexing);
|
std::map<std::string, int32_t> *output_class_indexing);
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p Pointer to the NodePass to be accepted
|
||||||
|
/// \param[out] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
// Op name getter
|
// Op name getter
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "ManifestOp"; }
|
std::string Name() const override { return "ManifestOp"; }
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -428,6 +429,12 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status MnistOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<MnistOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
Status MnistOp::ComputeColMap() {
|
Status MnistOp::ComputeColMap() {
|
||||||
// set the column name map (base class field)
|
// set the column name map (base class field)
|
||||||
if (column_name_id_map_.empty()) {
|
if (column_name_id_map_.empty()) {
|
||||||
|
|
|
@ -152,6 +152,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
||||||
// @return
|
// @return
|
||||||
static Status CountTotalRows(const std::string &dir, int64_t *count);
|
static Status CountTotalRows(const std::string &dir, int64_t *count);
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p Pointer to the NodePass to be accepted
|
||||||
|
/// \param[out] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
// Op name getter
|
// Op name getter
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "MnistOp"; }
|
std::string Name() const override { return "MnistOp"; }
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "dataset/util/random.h"
|
#include "dataset/util/random.h"
|
||||||
#include "dataset/util/wait_post.h"
|
#include "dataset/util/wait_post.h"
|
||||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -406,6 +407,12 @@ Status RandomDataOp::Reset() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status RandomDataOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<RandomDataOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
Status RandomDataOp::ComputeColMap() {
|
Status RandomDataOp::ComputeColMap() {
|
||||||
// Extract the column name mapping from the schema and save it in the class.
|
// Extract the column name mapping from the schema and save it in the class.
|
||||||
if (column_name_id_map_.empty()) {
|
if (column_name_id_map_.empty()) {
|
||||||
|
@ -415,15 +422,5 @@ Status RandomDataOp::ComputeColMap() {
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
|
||||||
// their role.
|
|
||||||
Status RandomDataOp::PrepareNodePostAction() {
|
|
||||||
// Run common code from super class before adding RandomDataOp specific handling
|
|
||||||
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
|
|
||||||
// Specific handling for this op, we need to do cache op work to assign the sampler to the cache.
|
|
||||||
RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -203,12 +203,6 @@ class RandomDataOp : public ParallelOp {
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "RandomDataOp"; }
|
std::string Name() const override { return "RandomDataOp"; }
|
||||||
|
|
||||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
|
||||||
// their role.
|
|
||||||
// @notes Derived versions of this function should always call its superclass version first
|
|
||||||
// before providing their own implementations.
|
|
||||||
Status PrepareNodePostAction() override;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* The entry point code for when workers are launched
|
* The entry point code for when workers are launched
|
||||||
|
@ -266,6 +260,12 @@ class RandomDataOp : public ParallelOp {
|
||||||
return ++buffer_id_;
|
return ++buffer_id_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Base-class override for NodePass visitor acceptor.
|
||||||
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
// @return - Status of the node visit.
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
// Private function for computing the assignment of the column name map.
|
// Private function for computing the assignment of the column name map.
|
||||||
// @return - Status
|
// @return - Status
|
||||||
Status ComputeColMap() override;
|
Status ComputeColMap() override;
|
||||||
|
|
|
@ -1019,31 +1019,28 @@ Status TFReaderOp::ComputeColMap() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
|
||||||
|
// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
|
||||||
|
// that this tf reader will produce the full set of data into the cache.
|
||||||
|
void TFReaderOp::MakeSimpleProducer() {
|
||||||
|
device_id_ = 0;
|
||||||
|
num_devices_ = 1;
|
||||||
|
total_rows_ = 0;
|
||||||
|
shuffle_files_ = false;
|
||||||
|
equal_rows_per_shard_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||||
// their role.
|
// their role.
|
||||||
Status TFReaderOp::PrepareNodePostAction() {
|
Status TFReaderOp::PrepareNodePostAction() {
|
||||||
// Run common code from super class before adding TFReaderOp specific handling
|
// Run common code from super class before adding TFReaderOp specific handling
|
||||||
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
|
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
|
||||||
|
|
||||||
// Specific handling for this op, we need to do cache op work so assign the sampler to the cache
|
|
||||||
// TF is a special case because it can support file-based sharding/shuffling, or, if there
|
|
||||||
// is a cache, then it can also do row-based sampler using the sampler on the cache.
|
|
||||||
// Thus, pass true for random access op flag when saving the sampler. This is a special case,
|
|
||||||
// since usually a non-mappable dataset would pass false here.
|
|
||||||
RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true));
|
|
||||||
|
|
||||||
// Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into
|
// Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into
|
||||||
// a simpler producer of all data (no shuffling or sharding or anything)
|
// a simpler producer of all data (no shuffling or sharding or anything)
|
||||||
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
|
if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
|
||||||
device_id_ = 0;
|
|
||||||
num_devices_ = 1;
|
|
||||||
total_rows_ = 0;
|
|
||||||
shuffle_files_ = false;
|
|
||||||
equal_rows_per_shard_ = false;
|
|
||||||
sampler_.reset(); // Normally SaveSampler code did this for us, but we passed in true above (See comment)
|
|
||||||
} else {
|
|
||||||
// This sanity check had been delayed until now in the prepare loop.
|
// This sanity check had been delayed until now in the prepare loop.
|
||||||
// If we are not in a cache path, then we can validate the the file-based sharding config.
|
// If we are not in a cache path, then we can validate the file-based sharding config.
|
||||||
// If we are in a cache path, there is no file-based sharding so the check is not correct in that
|
// If we are in a cache path, there is no file-based sharding so the check is not correct in that
|
||||||
// situation.
|
// situation.
|
||||||
if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) {
|
if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) {
|
||||||
|
|
|
@ -246,6 +246,11 @@ class TFReaderOp : public ParallelOp {
|
||||||
// @return Vector of the input file names
|
// @return Vector of the input file names
|
||||||
std::vector<std::string> FileNames() { return dataset_files_list_; }
|
std::vector<std::string> FileNames() { return dataset_files_list_; }
|
||||||
|
|
||||||
|
/// \brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
|
||||||
|
/// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
|
||||||
|
/// that this tf reader will produce the full set of data into the cache.
|
||||||
|
void MakeSimpleProducer();
|
||||||
|
|
||||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
||||||
// their role.
|
// their role.
|
||||||
// @notes Derived versions of this function should always call its superclass version first
|
// @notes Derived versions of this function should always call its superclass version first
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||||
#include "dataset/engine/db_connector.h"
|
#include "dataset/engine/db_connector.h"
|
||||||
#include "dataset/engine/execution_tree.h"
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
using tinyxml2::XMLDocument;
|
using tinyxml2::XMLDocument;
|
||||||
using tinyxml2::XMLElement;
|
using tinyxml2::XMLElement;
|
||||||
|
@ -449,6 +450,11 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
// Visitor accept method for NodePass
|
||||||
|
Status VOCOp::Accept(NodePass *p, bool *modified) {
|
||||||
|
// Downcast shared pointer then call visitor
|
||||||
|
return p->RunOnNode(shared_from_base<VOCOp>(), modified);
|
||||||
|
}
|
||||||
|
|
||||||
Status VOCOp::ComputeColMap() {
|
Status VOCOp::ComputeColMap() {
|
||||||
// Set the column name map (base class field)
|
// Set the column name map (base class field)
|
||||||
|
|
|
@ -205,6 +205,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
||||||
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||||
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);
|
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);
|
||||||
|
|
||||||
|
/// \brief Base-class override for NodePass visitor acceptor
|
||||||
|
/// \param[in] p Pointer to the NodePass to be accepted
|
||||||
|
/// \param[out] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status of the node visit
|
||||||
|
Status Accept(NodePass *p, bool *modified) override;
|
||||||
|
|
||||||
// Op name getter
|
// Op name getter
|
||||||
// @return Name of the current Op
|
// @return Name of the current Op
|
||||||
std::string Name() const override { return "VOCOp"; }
|
std::string Name() const override { return "VOCOp"; }
|
||||||
|
|
|
@ -127,12 +127,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TakeOp::PrepareNodePostAction() {
|
|
||||||
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
|
|
||||||
tree_->AddToEOEOpStack(shared_from_this());
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Visitor accept method for NodePass
|
// Visitor accept method for NodePass
|
||||||
Status TakeOp::Accept(NodePass *p, bool *modified) {
|
Status TakeOp::Accept(NodePass *p, bool *modified) {
|
||||||
// Downcast shared pointer then call visitor
|
// Downcast shared pointer then call visitor
|
||||||
|
|
|
@ -78,12 +78,6 @@ class TakeOp : public PipelineOp {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status operator()() override;
|
Status operator()() override;
|
||||||
|
|
||||||
// During tree prepare phase, operators may have specific post-operations to perform depending on
|
|
||||||
// their role.
|
|
||||||
// @notes Derived versions of this function should always call it's superclass version first
|
|
||||||
// before providing their own implementations.
|
|
||||||
Status PrepareNodePostAction() override;
|
|
||||||
|
|
||||||
// Base-class override for NodePass visitor acceptor.
|
// Base-class override for NodePass visitor acceptor.
|
||||||
// @param p - Pointer to the NodePass to be accepted.
|
// @param p - Pointer to the NodePass to be accepted.
|
||||||
// @param modified - Whether this node visit modified the pipeline.
|
// @param modified - Whether this node visit modified the pipeline.
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
#include "dataset/util/task_manager.h"
|
#include "dataset/util/task_manager.h"
|
||||||
#include "dataset/engine/opt/pass.h"
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "dataset/engine/opt/pre/removal_pass.h"
|
#include "dataset/engine/opt/pre/removal_pass.h"
|
||||||
|
#include "dataset/engine/opt/pre/cache_transform_pass.h"
|
||||||
|
#include "dataset/engine/opt/post/repeat_pass.h"
|
||||||
#include "dataset/engine/perf/profiling.h"
|
#include "dataset/engine/perf/profiling.h"
|
||||||
#include "dataset/engine/perf/monitor.h"
|
#include "dataset/engine/perf/monitor.h"
|
||||||
|
|
||||||
|
@ -215,18 +217,33 @@ Status ExecutionTree::PrepareTreePreAction() {
|
||||||
bool modified = false;
|
bool modified = false;
|
||||||
std::vector<std::unique_ptr<Pass>> pre_actions;
|
std::vector<std::unique_ptr<Pass>> pre_actions;
|
||||||
// Construct pre actions
|
// Construct pre actions
|
||||||
MS_LOG(INFO) << "Running pre pass";
|
MS_LOG(INFO) << "Running pre pass loops.";
|
||||||
pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass()));
|
pre_actions.push_back(std::make_unique<RemovalPass>());
|
||||||
|
pre_actions.push_back(std::make_unique<CacheTransformPass>());
|
||||||
// Apply pre action passes
|
// Apply pre action passes
|
||||||
for (auto &pass : pre_actions) {
|
for (auto &pass : pre_actions) {
|
||||||
RETURN_IF_NOT_OK(pass->Run(this, &modified));
|
RETURN_IF_NOT_OK(pass->Run(this, &modified));
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "Pre passes complete.";
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ExecutionTree::PrepareTreePostAction() {
|
Status ExecutionTree::PrepareTreePostAction() {
|
||||||
// The tree is ready to be prepared.
|
// The tree is ready to be prepared.
|
||||||
tree_state_ = kDeTStatePrepare;
|
tree_state_ = kDeTStatePrepare;
|
||||||
|
|
||||||
|
bool modified = false;
|
||||||
|
std::vector<std::unique_ptr<Pass>> post_actions;
|
||||||
|
// Construct post actions
|
||||||
|
MS_LOG(INFO) << "Running post pass loops.";
|
||||||
|
post_actions.push_back(std::make_unique<RepeatPass>());
|
||||||
|
|
||||||
|
// Apply post action passes
|
||||||
|
for (auto &pass : post_actions) {
|
||||||
|
RETURN_IF_NOT_OK(pass->Run(this, &modified));
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Post passes complete.";
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -280,31 +297,5 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op)
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adds an operator to the eoe operator stack during prepare phase.
|
|
||||||
void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
|
|
||||||
|
|
||||||
// Pops an operator from the eoe operator stack during prepare phase.
|
|
||||||
std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() {
|
|
||||||
std::shared_ptr<DatasetOp> top_op = nullptr;
|
|
||||||
if (!eoe_stack_.empty()) {
|
|
||||||
top_op = eoe_stack_.top();
|
|
||||||
eoe_stack_.pop();
|
|
||||||
}
|
|
||||||
return top_op;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adds a sampler to the sampler stack during prepare phase.
|
|
||||||
void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); }
|
|
||||||
|
|
||||||
// Pops an operator from the sampler stack during prepare phase.
|
|
||||||
std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() {
|
|
||||||
std::shared_ptr<Sampler> top_sampler = nullptr;
|
|
||||||
if (!sampler_stack_.empty()) {
|
|
||||||
top_sampler = sampler_stack_.top();
|
|
||||||
sampler_stack_.pop();
|
|
||||||
}
|
|
||||||
return top_sampler;
|
|
||||||
}
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -200,24 +200,6 @@ class ExecutionTree {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
|
Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
|
||||||
|
|
||||||
/// Adds an operator to the eoe operator stack during prepare phase.
|
|
||||||
/// \param op - The dataset op to work add to eoe stack
|
|
||||||
/// \return Status - The error code return
|
|
||||||
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
|
|
||||||
|
|
||||||
/// Pops an operator from the eoe operator stack during prepare phase.
|
|
||||||
/// \return shared_ptr to the popped operator
|
|
||||||
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
|
|
||||||
|
|
||||||
/// Adds a sampler to the sampler stack during prepare phase.
|
|
||||||
/// \param samplerop - The dataset op to work add to eoe stack
|
|
||||||
/// \return Status - The error code return
|
|
||||||
void AddToSamplerStack(std::shared_ptr<Sampler> sampler);
|
|
||||||
|
|
||||||
/// Pops an operator from the sampler stack during prepare phase.
|
|
||||||
/// \return shared_ptr to the popped operator
|
|
||||||
std::shared_ptr<Sampler> PopFromSamplerStack();
|
|
||||||
|
|
||||||
// Return the pointer to the TaskGroup
|
// Return the pointer to the TaskGroup
|
||||||
// @return raw pointer to the TaskGroup
|
// @return raw pointer to the TaskGroup
|
||||||
TaskGroup *AllTasks() const { return tg_.get(); }
|
TaskGroup *AllTasks() const { return tg_.get(); }
|
||||||
|
@ -248,8 +230,6 @@ class ExecutionTree {
|
||||||
TreeState tree_state_; // Tracking the current tree state
|
TreeState tree_state_; // Tracking the current tree state
|
||||||
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
|
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
|
||||||
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
|
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
|
||||||
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A stack used during prepare phase
|
|
||||||
std::stack<std::shared_ptr<Sampler>> sampler_stack_; // A stack used during prepare phase
|
|
||||||
};
|
};
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -2,6 +2,9 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
||||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||||
add_library(engine-opt OBJECT
|
add_library(engine-opt OBJECT
|
||||||
pass.cc
|
pass.cc
|
||||||
|
post/repeat_pass.cc
|
||||||
|
pre/cache_pass.cc
|
||||||
|
pre/cache_transform_pass.cc
|
||||||
pre/removal_nodes.cc
|
pre/removal_nodes.cc
|
||||||
pre/removal_pass.cc
|
pre/removal_pass.cc
|
||||||
util/printer_pass.cc
|
util/printer_pass.cc
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
|
|
||||||
#include "dataset/engine/opt/pass.h"
|
#include "dataset/engine/opt/pass.h"
|
||||||
#include "dataset/engine/datasetops/batch_op.h"
|
#include "dataset/engine/datasetops/batch_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_merge_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_lookup_op.h"
|
||||||
#include "dataset/engine/datasetops/dataset_op.h"
|
#include "dataset/engine/datasetops/dataset_op.h"
|
||||||
#include "dataset/engine/datasetops/device_queue_op.h"
|
#include "dataset/engine/datasetops/device_queue_op.h"
|
||||||
#include "dataset/engine/datasetops/map_op.h"
|
#include "dataset/engine/datasetops/map_op.h"
|
||||||
|
@ -24,8 +27,15 @@
|
||||||
#include "dataset/engine/datasetops/repeat_op.h"
|
#include "dataset/engine/datasetops/repeat_op.h"
|
||||||
#include "dataset/engine/datasetops/skip_op.h"
|
#include "dataset/engine/datasetops/skip_op.h"
|
||||||
#include "dataset/engine/datasetops/shuffle_op.h"
|
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/celeba_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/coco_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/mnist_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||||
#ifdef ENABLE_PYTHON
|
#ifdef ENABLE_PYTHON
|
||||||
#include "dataset/engine/datasetops/filter_op.h"
|
#include "dataset/engine/datasetops/filter_op.h"
|
||||||
#include "dataset/engine/datasetops/source/generator_op.h"
|
#include "dataset/engine/datasetops/source/generator_op.h"
|
||||||
|
@ -145,6 +155,11 @@ Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
  // No RandomDataOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
|
Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
|
||||||
// Fallback to base class visitor by default
|
// Fallback to base class visitor by default
|
||||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
|
@ -164,5 +179,70 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified)
|
||||||
// Fallback to base class visitor by default
|
// Fallback to base class visitor by default
|
||||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  // No CacheOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
  // No MnistOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
  // No ManifestOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
  // No CifarOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
  // No VOCOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
  // No CelebAOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
  // No CocoOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
  // No RepeatOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
  // No CacheMergeOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
  // No CacheLookupOp-specific handling in the base pass: forward to the generic DatasetOp visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return RunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
  // No RepeatOp-specific pre-visit in the base pass: forward to the generic DatasetOp pre-visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return PreRunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  // No CacheOp-specific pre-visit in the base pass: forward to the generic DatasetOp pre-visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return PreRunOnNode(base_op, modified);
}
|
||||||
|
|
||||||
|
Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
  // No CacheMergeOp-specific pre-visit in the base pass: forward to the generic DatasetOp pre-visitor.
  std::shared_ptr<DatasetOp> base_op = std::static_pointer_cast<DatasetOp>(node);
  return PreRunOnNode(base_op, modified);
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -47,6 +47,10 @@ class FilterOp;
|
||||||
class GeneratorOp;
|
class GeneratorOp;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
class RandomDataOp;
|
||||||
|
|
||||||
|
class RepeatOp;
|
||||||
|
|
||||||
class TakeOp;
|
class TakeOp;
|
||||||
|
|
||||||
class ZipOp;
|
class ZipOp;
|
||||||
|
@ -55,6 +59,24 @@ class DeviceQueueOp;
|
||||||
|
|
||||||
class ImageFolderOp;
|
class ImageFolderOp;
|
||||||
|
|
||||||
|
class CacheOp;
|
||||||
|
|
||||||
|
class MnistOp;
|
||||||
|
|
||||||
|
class ManifestOp;
|
||||||
|
|
||||||
|
class CifarOp;
|
||||||
|
|
||||||
|
class VOCOp;
|
||||||
|
|
||||||
|
class CocoOp;
|
||||||
|
|
||||||
|
class CelebAOp;
|
||||||
|
|
||||||
|
class CacheMergeOp;
|
||||||
|
|
||||||
|
class CacheLookupOp;
|
||||||
|
|
||||||
// The base class Pass is the basic unit of tree transformation.
|
// The base class Pass is the basic unit of tree transformation.
|
||||||
// The actual implementation of the passes will be derived from here.
|
// The actual implementation of the passes will be derived from here.
|
||||||
class Pass : public std::enable_shared_from_this<Pass> {
|
class Pass : public std::enable_shared_from_this<Pass> {
|
||||||
|
@ -138,14 +160,42 @@ class NodePass : public Pass {
|
||||||
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
|
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
|
||||||
|
|
||||||
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
|
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
|
||||||
|
|
||||||
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
|
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
|
||||||
|
|
||||||
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
|
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||||
|
|
||||||
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
|
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
|
||||||
|
|
||||||
|
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Helper function to perform DFS visit
|
// Helper function to perform DFS visit
|
||||||
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
|
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
|
||||||
|
|
|
@ -0,0 +1,161 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "dataset/engine/opt/post/repeat_pass.h"
|
||||||
|
#include "dataset/engine/datasetops/repeat_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_lookup_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_merge_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// Constructor: the pass starts outside of any repeat or cache merge path, with no nested repeats counted
// and no lookup op saved.
RepeatPass::RepeatPass()
    : is_repeated_(false),
      nested_repeats_(0),
      is_merge_(false),
      cache_lookup_(nullptr) {}
|
||||||
|
|
||||||
|
// Identifies the subtree below this node as being in a repeated path of the tree.
|
||||||
|
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||||
|
// If we are already repeated, then this is a nested repeat.
|
||||||
|
if (is_repeated_) {
|
||||||
|
nested_repeats_++;
|
||||||
|
}
|
||||||
|
is_repeated_ = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identifies the subtree below this node as being in a cache merge path
|
||||||
|
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||||
|
// Turn on the flag that we're under a merge op
|
||||||
|
is_merge_ = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hooks up any identified eoe nodes under this repeat.
|
||||||
|
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
||||||
|
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
|
||||||
|
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
||||||
|
while (leaf_op != nullptr) {
|
||||||
|
node->AddToEoeList(leaf_op);
|
||||||
|
leaf_op = PopFromEOEOpStack();
|
||||||
|
}
|
||||||
|
|
||||||
|
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
|
||||||
|
// and add it to the list of eoe/leaf ops for the repeat, removing it from the save area.
|
||||||
|
if (is_merge_ && cache_lookup_) {
|
||||||
|
cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||||
|
node->AddToEoeList(std::move(cache_lookup_));
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
|
||||||
|
// A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
|
||||||
|
if (nested_repeats_ > 0) {
|
||||||
|
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||||
|
AddToEOEOpStack(node);
|
||||||
|
nested_repeats_--;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are not nested, or we were the top-most repeat, now we clear the flag
|
||||||
|
if (nested_repeats_ == 0) {
|
||||||
|
is_repeated_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheOp removes previous leaf ops and replaces them with itself
|
||||||
|
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||||
|
if (is_repeated_) {
|
||||||
|
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||||
|
// if we are a cache within a repeat path of the tree, then there will be
|
||||||
|
// eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the
|
||||||
|
// repeat or epoch ctrl operators can work with them for repeat activity during runtime.
|
||||||
|
// However, since a cache is present:
|
||||||
|
// - unflag those ops as being repeated ops
|
||||||
|
// - remove them from the eoe op stack so that repeat op above in the tree won't know about them
|
||||||
|
// - add ourself (the cache op), as an eoe op
|
||||||
|
// We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead
|
||||||
|
// the repeating behaviours shall be invoked against the cache op.
|
||||||
|
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
|
||||||
|
while (leaf_op != nullptr) {
|
||||||
|
leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat);
|
||||||
|
leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated);
|
||||||
|
leaf_op = PopFromEOEOpStack();
|
||||||
|
}
|
||||||
|
AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||||
|
// for use with a controlling repeat above it.
|
||||||
|
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
||||||
|
// If we are in a repeat path, then set our repeated flag
|
||||||
|
if (is_repeated_) {
|
||||||
|
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||||
|
|
||||||
|
// if we are a leaf node then save ourself in a stack for the repeat operator above us
|
||||||
|
if (node->IsLeaf()) {
|
||||||
|
AddToEOEOpStack(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Turns off the tracking for operations under merge op
|
||||||
|
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||||
|
// Setting the flag is needed since we didn't call the base class DatasetOp version
|
||||||
|
if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||||
|
is_merge_ = false;
|
||||||
|
cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Saves the lookup up in case it needs to be referenced by a repeat
|
||||||
|
Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
|
||||||
|
if (!node->IsLeaf()) {
|
||||||
|
// By definition, the CacheLookup must be a leaf op. Make that clear here.
|
||||||
|
RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are in a repeat path already, then there must be a repeat above the merge op
|
||||||
|
// In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
|
||||||
|
if (is_repeated_) {
|
||||||
|
node->set_control_flag(DatasetOp::kDeOpRepeated);
|
||||||
|
AddToEOEOpStack(node);
|
||||||
|
} else {
|
||||||
|
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
|
||||||
|
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
|
||||||
|
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
|
||||||
|
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds an operator to the eoe operator stack save area
|
||||||
|
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
|
||||||
|
|
||||||
|
// Pops an operator from the eoe operator stack save area
|
||||||
|
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
|
||||||
|
std::shared_ptr<DatasetOp> top_op = nullptr;
|
||||||
|
if (!eoe_stack_.empty()) {
|
||||||
|
top_op = eoe_stack_.top();
|
||||||
|
eoe_stack_.pop();
|
||||||
|
}
|
||||||
|
return top_op;
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,98 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
|
||||||
|
#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <stack>
|
||||||
|
#include <utility>
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
/// \class RepeatPass repeat_pass.h
|
||||||
|
/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references
|
||||||
|
/// to the eoe-producing (typically leaf) nodes underneath it.
|
||||||
|
class RepeatPass : public NodePass {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
RepeatPass();
|
||||||
|
|
||||||
|
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Identifies the subtree below this node as being in a cache merge path
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Hooks up any identified eoe nodes under this repeat.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief CacheOp removes previous leaf ops and replaces them with itself
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Turns of the tracking for operations under merge op
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Saves the lookup up in case it needs to be referenced by a repeat
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||||
|
/// for use with a controlling repeat above it.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// \brief Adds an operator to the eoe operator stack save area
|
||||||
|
/// \param op - The dataset op to work add to eoe stack
|
||||||
|
/// \return Status - The error code return
|
||||||
|
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
|
||||||
|
|
||||||
|
/// \brief Pops an operator from the eoe operator stack save area
|
||||||
|
/// \return shared_ptr to the popped operator
|
||||||
|
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
|
||||||
|
|
||||||
|
bool is_repeated_; // T/F if we are processing under a repeat
|
||||||
|
bool is_merge_; // T/F if we are processing under a cache merge op
|
||||||
|
int32_t nested_repeats_; // A counter for nested repeats
|
||||||
|
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A save area for leaf/eoe ops
|
||||||
|
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
|
|
@ -0,0 +1,181 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include "dataset/engine/opt/pre/cache_pass.h"
|
||||||
|
#include "dataset/engine/opt/pre/cache_transform_pass.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/celeba_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/generator_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/mnist_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/coco_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// Constructor
|
||||||
|
CachePass::CachePass(CacheTransformPass *transform_pass)
|
||||||
|
: transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {}
|
||||||
|
|
||||||
|
// Identifies the subtree below this node as a cached descendant tree.
|
||||||
|
Status CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
|
||||||
|
if (is_caching_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
|
||||||
|
}
|
||||||
|
is_caching_ = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
|
||||||
|
// transformation
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||||
|
*modified = false;
|
||||||
|
is_caching_ = false; // We a no longer in a cache subtree. clear the flag.
|
||||||
|
if (leaf_op_) {
|
||||||
|
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
|
||||||
|
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
|
||||||
|
// using base class pointers.
|
||||||
|
transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node);
|
||||||
|
} else {
|
||||||
|
// If there was no leaf_op set, then this is a non-mappable scenario.
|
||||||
|
|
||||||
|
if (sampler_) {
|
||||||
|
// Grab the sampler that was saved from the leaf and plug it into the cache op
|
||||||
|
node->SetSampler(std::move(sampler_));
|
||||||
|
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
|
||||||
|
} else {
|
||||||
|
// We're a cache op but no sampler was saved from leaf, so create a default sampler
|
||||||
|
int64_t num_samples = 0;
|
||||||
|
int64_t start_index = 0;
|
||||||
|
sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||||
|
node->SetSampler(std::move(sampler_));
|
||||||
|
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
|
||||||
|
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
|
||||||
|
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common code for mappable leaf setup.
|
||||||
|
Status CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
|
||||||
|
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||||
|
if (is_caching_ && leaf_op_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are a leaf in the caching path, then save this leaf.
|
||||||
|
if (is_caching_) {
|
||||||
|
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
|
||||||
|
leaf_op_ = std::move(leaf_op);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common code for non mappable leaf setup.
|
||||||
|
Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
|
||||||
|
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||||
|
if (is_caching_ && leaf_op_) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf
|
||||||
|
// as save it for use by cache op in ascendant tree.
|
||||||
|
if (is_caching_) {
|
||||||
|
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
|
||||||
|
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
|
||||||
|
} else {
|
||||||
|
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
|
||||||
|
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
|
||||||
|
std::shared_ptr<Sampler> sampler_from_leaf;
|
||||||
|
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||||
|
if (is_caching_) {
|
||||||
|
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
|
||||||
|
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
|
||||||
|
node->MakeSimpleProducer();
|
||||||
|
}
|
||||||
|
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
|
||||||
|
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform leaf node cache tranform identifications
|
||||||
|
Status CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||||
|
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,138 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
|
||||||
|
#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class CacheTransformPass;
|
||||||
|
|
||||||
|
/// \class CachePass cache_pass.h
|
||||||
|
/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache
|
||||||
|
/// transformation. It works in conjunction with the CacheTransformPass
|
||||||
|
class CachePass : public NodePass {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
/// \param[in] transform_pass Raw pointer back to controlling tree pass
|
||||||
|
explicit CachePass(CacheTransformPass *transform_pass);
|
||||||
|
|
||||||
|
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
|
||||||
|
/// transformation
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Perform leaf node cache tranform identifications
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// \brief Common code for mappable leaf setup.
|
||||||
|
/// \param[in] node The leaf node performing setup work.
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
|
||||||
|
|
||||||
|
/// \brief Common code for non-mappable leaf setup.
|
||||||
|
/// \param[in] node The leaf node performing setup work.
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
|
||||||
|
|
||||||
|
bool is_caching_;
|
||||||
|
std::shared_ptr<DatasetOp> leaf_op_;
|
||||||
|
std::shared_ptr<Sampler> sampler_;
|
||||||
|
CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
|
|
@ -0,0 +1,108 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "dataset/engine/opt/pre/cache_pass.h"
|
||||||
|
#include "dataset/engine/opt/pre/cache_transform_pass.h"
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/cache/cache_client.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_lookup_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_merge_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// constructor
|
||||||
|
CacheTransformPass::CacheTransformPass() {}
|
||||||
|
|
||||||
|
// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
|
||||||
|
Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||||
|
MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
|
||||||
|
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
|
||||||
|
// use to execute a transform.
|
||||||
|
std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this);
|
||||||
|
RETURN_IF_NOT_OK(cache_pass->Run(tree, modified));
|
||||||
|
|
||||||
|
// Then, execute the transform for each pair
|
||||||
|
for (auto cache_pair : cache_pairs_) {
|
||||||
|
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
|
||||||
|
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client());
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to execute the cache transformation.
|
||||||
|
Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
|
||||||
|
std::shared_ptr<DatasetOp> cache_op,
|
||||||
|
std::shared_ptr<CacheClient> cache_client) {
|
||||||
|
// Get local pointers the child/parent of the cache op. It's possible that the parent is null if the cache was
|
||||||
|
// the root node. It is also possible that cache_child == leaf_op
|
||||||
|
std::shared_ptr<DatasetOp> cache_child = cache_op->child(0);
|
||||||
|
DatasetOp *cache_parent = nullptr;
|
||||||
|
cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent
|
||||||
|
|
||||||
|
// Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later.
|
||||||
|
std::shared_ptr<Sampler> leaf_sampler = leaf_op->sampler();
|
||||||
|
|
||||||
|
// Construct the merge op with defaults
|
||||||
|
std::shared_ptr<CacheMergeOp> merge_op;
|
||||||
|
CacheMergeOp::Builder merge_builder;
|
||||||
|
RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op));
|
||||||
|
RETURN_IF_NOT_OK(tree->AssociateNode(merge_op));
|
||||||
|
|
||||||
|
// Construct the cache lookup op with defaults
|
||||||
|
std::shared_ptr<CacheLookupOp> cache_lookup_op;
|
||||||
|
CacheLookupOp::Builder lookup_builder;
|
||||||
|
RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op));
|
||||||
|
RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op));
|
||||||
|
|
||||||
|
// Overwrite the old sampler in this leaf op to become the lookup op
|
||||||
|
leaf_op->SetSampler(cache_lookup_op);
|
||||||
|
|
||||||
|
// If the cache had a parent, then go into that parent to remove the cache from it's child list and then
|
||||||
|
// replace it with the merge op.
|
||||||
|
if (cache_parent != nullptr) {
|
||||||
|
RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op));
|
||||||
|
RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op));
|
||||||
|
} else {
|
||||||
|
// If we didn't have a parent, then the merge op is the root node
|
||||||
|
RETURN_IF_NOT_OK(tree->AssignRoot(merge_op));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the cache op to no longer be a parent over it's child. This will fully disconnect the old cache op.
|
||||||
|
// We maintain a local pointer to the old child though.
|
||||||
|
RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child));
|
||||||
|
|
||||||
|
// Connect the merge op
|
||||||
|
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op)));
|
||||||
|
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child)));
|
||||||
|
|
||||||
|
// At this point, the cache op has already had it's children and parents taken away. Calling remove
|
||||||
|
// on it at this point will not do any node hookups, and instead set internal fields to invalid.
|
||||||
|
RETURN_IF_NOT_OK(cache_op->Remove());
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assigns the leaf and cache operators that are involved in a cache transformation
|
||||||
|
void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
|
||||||
|
std::shared_ptr<CacheOp> cache_op) {
|
||||||
|
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,79 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
|
||||||
|
#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include "dataset/engine/opt/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class DatasetOp;
|
||||||
|
|
||||||
|
class CacheClient;
|
||||||
|
|
||||||
|
/// \class CacheTransformPass cache_transform_pass.h
|
||||||
|
/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
|
||||||
|
/// operations
|
||||||
|
class CacheTransformPass : public TreePass {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
CacheTransformPass();
|
||||||
|
|
||||||
|
/// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
|
||||||
|
/// \param[inout] tree The tree to operate on.
|
||||||
|
/// \param[inout] Indicate of the tree was modified.
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
|
||||||
|
/// \param[in] leaf_op The leaf operator involved in the cache transform
|
||||||
|
/// \param[in] cache_op The cache operator involved in the cache transform
|
||||||
|
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// \brief Helper function to execute the cache transformation.
|
||||||
|
///
|
||||||
|
/// Input:
|
||||||
|
/// Sampler
|
||||||
|
/// |
|
||||||
|
/// LeafOp --> OtherOps --> CacheOp
|
||||||
|
///
|
||||||
|
/// Transformed:
|
||||||
|
/// Sampler --> CacheLookupOp ---------------->
|
||||||
|
/// | |
|
||||||
|
/// | MergeOp
|
||||||
|
/// | |
|
||||||
|
/// LeafOp --> OtherOps -->
|
||||||
|
///
|
||||||
|
/// \param[in] leaf_op The leaf node in the transform
|
||||||
|
/// \param[in] cache_op The cache op in the transform (will get removed)
|
||||||
|
/// \param[in] cache_client The cache client
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
|
||||||
|
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
|
||||||
|
|
||||||
|
// The two operators that work together to establish the cache transform
|
||||||
|
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
|
|
@ -24,12 +24,28 @@ namespace dataset {
|
||||||
|
|
||||||
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
|
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
|
||||||
|
|
||||||
|
// Identifies the subtree below this node as a cached descendant tree.
|
||||||
|
Status RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  // Seeing a CacheOp on the way down: flag that the nodes visited from here on
  // sit underneath a cache, so the ShuffleOp check can act on it.
  is_caching_ = true;
  MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";

  // This visit only records state; it never changes the node itself.
  *modified = false;
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Resets the tracking of the cache within the tree
|
||||||
|
Status RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  // The CacheOp's subtree has been fully visited: clear the flag that
  // PreRunOnNode raised so later nodes are no longer treated as cached.
  is_caching_ = false;
  MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";

  // This visit only resets state; it never changes the node itself.
  *modified = false;
  return Status::OK();
}
|
||||||
|
|
||||||
// Perform ShuffleOp removal check.
|
// Perform ShuffleOp removal check.
|
||||||
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||||
*modified = false;
|
*modified = false;
|
||||||
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
||||||
if (is_caching_) {
|
if (is_caching_) {
|
||||||
MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
|
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
|
||||||
if (removal_pass_) {
|
if (removal_pass_) {
|
||||||
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
|
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -34,6 +34,18 @@ class RemovalNodes : public NodePass {
|
||||||
/// \param[in] removal_pass Raw pointer back to controlling tree pass
|
/// \param[in] removal_pass Raw pointer back to controlling tree pass
|
||||||
explicit RemovalNodes(RemovalPass *removal_pass);
|
explicit RemovalNodes(RemovalPass *removal_pass);
|
||||||
|
|
||||||
|
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||||
|
|
||||||
|
/// \brief Resets the tracking of the cache within the tree
|
||||||
|
/// \param[in] node The node being visited
|
||||||
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
/// \return Status The error code return
|
||||||
|
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||||
|
|
||||||
/// \brief Perform ShuffleOp removal check
|
/// \brief Perform ShuffleOp removal check
|
||||||
/// \param[in] node The node being visited
|
/// \param[in] node The node being visited
|
||||||
/// \param[inout] modified Indicator if the node was changed at all
|
/// \param[inout] modified Indicator if the node was changed at all
|
||||||
|
|
|
@ -28,6 +28,7 @@ RemovalPass::RemovalPass() {}
|
||||||
|
|
||||||
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||||
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||||
|
MS_LOG(INFO) << "Pre pass: removal pass started.";
|
||||||
// Create the removal node pass which can identify which nodes need to be removed.
|
// Create the removal node pass which can identify which nodes need to be removed.
|
||||||
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
|
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
|
||||||
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
|
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
|
||||||
|
@ -36,6 +37,7 @@ Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||||
for (auto node : removal_nodes_) {
|
for (auto node : removal_nodes_) {
|
||||||
node->Remove();
|
node->Remove();
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "Pre pass: removal pass complete.";
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -87,8 +87,9 @@ class Allocator {
|
||||||
std::shared_ptr<MemoryPool> pool_;
|
std::shared_ptr<MemoryPool> pool_;
|
||||||
};
|
};
|
||||||
/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will
|
/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will
|
||||||
/// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator.
|
/// be released when the object goes out of scope
|
||||||
/// Default to std::allocator
|
/// \tparam T The type of object to be allocated
|
||||||
|
/// \tparam C Allocator. Default to std::allocator
|
||||||
template <typename T, typename C = std::allocator<T>>
|
template <typename T, typename C = std::allocator<T>>
|
||||||
class MemGuard {
|
class MemGuard {
|
||||||
public:
|
public:
|
||||||
|
@ -168,7 +169,7 @@ class MemGuard {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
allocator alloc_;
|
allocator alloc_;
|
||||||
std::unique_ptr<T[], std::function<void(T *)>> ptr_;
|
std::unique_ptr<T[]> ptr_;
|
||||||
size_t n_;
|
size_t n_;
|
||||||
};
|
};
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
|
@ -98,11 +98,6 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t
|
||||||
} catch (std::bad_alloc &e) {
|
} catch (std::bad_alloc &e) {
|
||||||
if (sm_ != nullptr) {
|
if (sm_ != nullptr) {
|
||||||
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
|
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
|
||||||
// We have an assumption 0 is not a valid key from the design of AutoIndexObj.
|
|
||||||
// Make sure it is not 0.
|
|
||||||
if (bl.storage_key == 0) {
|
|
||||||
RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected");
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
|
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,11 +22,11 @@
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#endif
|
#endif
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
#include "dataset/engine/cache/cache_server.h"
|
||||||
#include "dataset/util/circular_pool.h"
|
#include "dataset/util/circular_pool.h"
|
||||||
#include "dataset/util/random.h"
|
#include "dataset/util/random.h"
|
||||||
#include "dataset/util/task_manager.h"
|
#include "dataset/util/task_manager.h"
|
||||||
|
|
||||||
#define SLOT_TASK_MGR 0
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
std::unique_ptr<Services> Services::instance_ = nullptr;
|
std::unique_ptr<Services> Services::instance_ = nullptr;
|
||||||
|
@ -61,15 +61,25 @@ std::string Services::GetUniqueID() {
|
||||||
|
|
||||||
TaskManager &Services::getTaskMgrInstance() {
|
TaskManager &Services::getTaskMgrInstance() {
|
||||||
Services &sm = GetInstance();
|
Services &sm = GetInstance();
|
||||||
return *(static_cast<TaskManager *>(sm.sa_[SLOT_TASK_MGR]));
|
return *(static_cast<TaskManager *>(sm.sa_[kSlotTaskMgr_]));
|
||||||
|
}
|
||||||
|
|
||||||
|
CacheServer &Services::getCacheServer() {
|
||||||
|
Services &sm = GetInstance();
|
||||||
|
return *(static_cast<CacheServer *>(sm.sa_[kSlotCacheMgr_]));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Services::CreateAllInstances() {
|
Status Services::CreateAllInstances() {
|
||||||
// In order, TaskMgr, BufferMgr
|
// In order: TaskMgr, CacheServer
|
||||||
Status rc;
|
Status rc;
|
||||||
sa_[SLOT_TASK_MGR] = new (&rc, pool_) TaskManager();
|
sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager();
|
||||||
RETURN_IF_NOT_OK(rc);
|
RETURN_IF_NOT_OK(rc);
|
||||||
rc = sa_[SLOT_TASK_MGR]->ServiceStart();
|
rc = sa_[kSlotTaskMgr_]->ServiceStart();
|
||||||
|
RETURN_IF_NOT_OK(rc);
|
||||||
|
// TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers
|
||||||
|
sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3);
|
||||||
|
RETURN_IF_NOT_OK(rc);
|
||||||
|
rc = sa_[kSlotCacheMgr_]->ServiceStart();
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,8 +93,14 @@ Services::Services() : pool_(nullptr), sa_{nullptr} {
|
||||||
Services::~Services() noexcept {
|
Services::~Services() noexcept {
|
||||||
try {
|
try {
|
||||||
// In reverse order
|
// In reverse order
|
||||||
TaskManager *tm = static_cast<TaskManager *>(sa_[SLOT_TASK_MGR]);
|
CacheServer *cs = static_cast<CacheServer *>(sa_[kSlotCacheMgr_]);
|
||||||
if (tm) {
|
if (cs != nullptr) {
|
||||||
|
(void)cs->ServiceStop();
|
||||||
|
cs->~CacheServer();
|
||||||
|
pool_->Deallocate(cs);
|
||||||
|
}
|
||||||
|
TaskManager *tm = static_cast<TaskManager *>(sa_[kSlotTaskMgr_]);
|
||||||
|
if (tm != nullptr) {
|
||||||
(void)tm->ServiceStop();
|
(void)tm->ServiceStop();
|
||||||
tm->~TaskManager();
|
tm->~TaskManager();
|
||||||
pool_->Deallocate(tm);
|
pool_->Deallocate(tm);
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
class TaskManager;
|
class TaskManager;
|
||||||
|
class CacheServer;
|
||||||
class Services {
|
class Services {
|
||||||
public:
|
public:
|
||||||
static Status CreateInstance() {
|
static Status CreateInstance() {
|
||||||
|
@ -61,6 +61,8 @@ class Services {
|
||||||
|
|
||||||
static TaskManager &getTaskMgrInstance();
|
static TaskManager &getTaskMgrInstance();
|
||||||
|
|
||||||
|
static CacheServer &getCacheServer();
|
||||||
|
|
||||||
std::shared_ptr<MemoryPool> GetServiceMemPool() { return pool_; }
|
std::shared_ptr<MemoryPool> GetServiceMemPool() { return pool_; }
|
||||||
|
|
||||||
#if !defined(_WIN32) && !defined(_WIN64)
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
@ -87,7 +89,9 @@ class Services {
|
||||||
// We use pointers here instead of unique_ptr because we
|
// We use pointers here instead of unique_ptr because we
|
||||||
// want to have ultimate control on the order of
|
// want to have ultimate control on the order of
|
||||||
// construction and destruction.
|
// construction and destruction.
|
||||||
static constexpr int kNumServices_ = 1;
|
static constexpr int kSlotTaskMgr_ = 0;
|
||||||
|
static constexpr int kSlotCacheMgr_ = 1;
|
||||||
|
static constexpr int kNumServices_ = 2;
|
||||||
Service *sa_[kNumServices_];
|
Service *sa_[kNumServices_];
|
||||||
|
|
||||||
Services();
|
Services();
|
||||||
|
|
|
@ -24,6 +24,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
|
||||||
TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
|
TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
|
||||||
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
||||||
WeightedRandomSampler, Sampler
|
WeightedRandomSampler, Sampler
|
||||||
|
from .engine.cache_client import DatasetCache
|
||||||
from .engine.serializer_deserializer import serialize, deserialize, show
|
from .engine.serializer_deserializer import serialize, deserialize, show
|
||||||
from .engine.graphdata import GraphData
|
from .engine.graphdata import GraphData
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Cache client
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from mindspore._c_dataengine import CacheClient
|
||||||
|
|
||||||
|
class DatasetCache:
    """
    A client to interface with tensor caching service

    Args:
        session_id: Identifier of the cache session. Required; session
            generation is not implemented, so None is rejected.
        size (int, optional): Cache size forwarded to the underlying
            CacheClient (default=None, which is normalised to 0). Must not be
            negative.
        spilling (bool, optional): Spilling flag forwarded to the underlying
            CacheClient (default=False).

    Raises:
        RuntimeError: If session_id is None.
        ValueError: If size is negative or spilling is not a bool.
    """

    def __init__(self, session_id=None, size=None, spilling=False):
        if session_id is None:
            raise RuntimeError("Session generation is not implemented yet. session id required")
        # Normalise first, then validate and forward the normalised value.
        # Comparing the raw argument would raise TypeError for the default
        # size=None, and None would otherwise be handed to CacheClient.
        self.size = size if size is not None else 0
        if self.size < 0:
            raise ValueError("cache size should be 0 or positive integer value but got: size={}".format(size))
        if not isinstance(spilling, bool):
            raise ValueError(
                "spilling argument for cache should be a boolean value but got: spilling={}".format(spilling))
        self.session_id = session_id
        self.spilling = spilling
        self.cache_client = CacheClient(session_id, self.size, spilling)

    def __deepcopy__(self, memodict):
        """
        Deep-copy the plain attributes while sharing the cache client.

        Args:
            memodict (dict): Memo dictionary supplied by copy.deepcopy.

        Returns:
            DatasetCache, the copy (or the previously memoised copy).
        """
        if id(self) in memodict:
            return memodict[id(self)]
        cls = self.__class__
        new_cache = cls.__new__(cls)
        memodict[id(self)] = new_cache
        new_cache.session_id = copy.deepcopy(self.session_id, memodict)
        new_cache.spilling = copy.deepcopy(self.spilling, memodict)
        new_cache.size = copy.deepcopy(self.size, memodict)
        # The client object is assigned by reference, not copied: the copy
        # shares the same cache_client instance as the original.
        new_cache.cache_client = self.cache_client
        return new_cache
|
|
@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
||||||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||||
check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
|
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
|
||||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -386,7 +386,7 @@ class Dataset:
|
||||||
|
|
||||||
@check_map
|
@check_map
|
||||||
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
|
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
|
||||||
num_parallel_workers=None, python_multiprocessing=False):
|
num_parallel_workers=None, python_multiprocessing=False, cache=None):
|
||||||
"""
|
"""
|
||||||
Apply each operation in operations to this dataset.
|
Apply each operation in operations to this dataset.
|
||||||
|
|
||||||
|
@ -427,6 +427,7 @@ class Dataset:
|
||||||
parallel (default=None, the value from the config will be used).
|
parallel (default=None, the value from the config will be used).
|
||||||
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
|
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
|
||||||
option could be beneficial if the python operation is computational heavy (default=False).
|
option could be beneficial if the python operation is computational heavy (default=False).
|
||||||
|
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MapDataset, dataset after mapping operation.
|
MapDataset, dataset after mapping operation.
|
||||||
|
@ -541,7 +542,7 @@ class Dataset:
|
||||||
>>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
|
>>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
|
||||||
"""
|
"""
|
||||||
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
|
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
|
||||||
python_multiprocessing)
|
python_multiprocessing, cache)
|
||||||
|
|
||||||
@check_filter
|
@check_filter
|
||||||
def filter(self, predicate, input_columns=None, num_parallel_workers=1):
|
def filter(self, predicate, input_columns=None, num_parallel_workers=1):
|
||||||
|
@ -1868,13 +1869,14 @@ class MapDataset(DatasetOp):
|
||||||
in parallel (default=None).
|
in parallel (default=None).
|
||||||
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
|
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
|
||||||
option could be beneficial if the python operation is computational heavy (default=False).
|
option could be beneficial if the python operation is computational heavy (default=False).
|
||||||
|
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
|
ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
|
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
|
||||||
num_parallel_workers=None, python_multiprocessing=False):
|
num_parallel_workers=None, python_multiprocessing=False, cache=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
self.children.append(input_dataset)
|
self.children.append(input_dataset)
|
||||||
if input_columns is not None and not isinstance(input_columns, list):
|
if input_columns is not None and not isinstance(input_columns, list):
|
||||||
|
@ -1886,6 +1888,7 @@ class MapDataset(DatasetOp):
|
||||||
if output_columns is not None and not isinstance(output_columns, list):
|
if output_columns is not None and not isinstance(output_columns, list):
|
||||||
output_columns = [output_columns]
|
output_columns = [output_columns]
|
||||||
self.output_columns = output_columns
|
self.output_columns = output_columns
|
||||||
|
self.cache = cache
|
||||||
self.columns_order = columns_order
|
self.columns_order = columns_order
|
||||||
|
|
||||||
if self.input_columns and self.output_columns \
|
if self.input_columns and self.output_columns \
|
||||||
|
@ -1904,6 +1907,7 @@ class MapDataset(DatasetOp):
|
||||||
args["operations"] = self.operations
|
args["operations"] = self.operations
|
||||||
args["output_columns"] = self.output_columns
|
args["output_columns"] = self.output_columns
|
||||||
args["columns_order"] = self.columns_order
|
args["columns_order"] = self.columns_order
|
||||||
|
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def get_dataset_size(self):
|
def get_dataset_size(self):
|
||||||
|
@ -1929,6 +1933,7 @@ class MapDataset(DatasetOp):
|
||||||
new_op.parent = copy.deepcopy(self.parent, memodict)
|
new_op.parent = copy.deepcopy(self.parent, memodict)
|
||||||
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
|
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
|
||||||
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
|
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
|
||||||
|
new_op.cache = copy.deepcopy(self.cache, memodict)
|
||||||
new_op.operations = self.operations
|
new_op.operations = self.operations
|
||||||
return new_op
|
return new_op
|
||||||
|
|
||||||
|
@ -2346,7 +2351,7 @@ class RangeDataset(MappableDataset):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
|
||||||
"""
|
"""
|
||||||
Create sampler based on user input.
|
Create sampler based on user input.
|
||||||
|
|
||||||
|
@ -2356,7 +2361,11 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
||||||
shuffle (bool): Shuffle.
|
shuffle (bool): Shuffle.
|
||||||
num_shards (int): Number of shard for sharding.
|
num_shards (int): Number of shard for sharding.
|
||||||
shard_id (int): Shard ID.
|
shard_id (int): Shard ID.
|
||||||
|
non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
|
||||||
"""
|
"""
|
||||||
|
if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
|
||||||
|
return None
|
||||||
|
|
||||||
if input_sampler is not None:
|
if input_sampler is not None:
|
||||||
# If the user provided a sampler, then it doesn't matter what the other args are because
|
# If the user provided a sampler, then it doesn't matter what the other args are because
|
||||||
# we are being asked specifically to use the given sampler.
|
# we are being asked specifically to use the given sampler.
|
||||||
|
@ -2369,7 +2378,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
||||||
if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||||
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
||||||
samplers.WeightedRandomSampler, samplers.Sampler)) and
|
samplers.WeightedRandomSampler, samplers.Sampler)) and
|
||||||
(num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)):
|
(any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
|
'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
|
||||||
' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle))
|
' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle))
|
||||||
|
@ -2458,6 +2467,7 @@ class ImageFolderDatasetV2(MappableDataset):
|
||||||
into (default=None).
|
into (default=None).
|
||||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||||
argument should be specified only when num_shards is also specified.
|
argument should be specified only when num_shards is also specified.
|
||||||
|
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||||
|
@ -2482,7 +2492,7 @@ class ImageFolderDatasetV2(MappableDataset):
|
||||||
@check_imagefolderdatasetv2
|
@check_imagefolderdatasetv2
|
||||||
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
|
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
|
||||||
shuffle=None, sampler=None, extensions=None, class_indexing=None,
|
shuffle=None, sampler=None, extensions=None, class_indexing=None,
|
||||||
decode=False, num_shards=None, shard_id=None):
|
decode=False, num_shards=None, shard_id=None, cache=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
|
|
||||||
self.dataset_dir = dataset_dir
|
self.dataset_dir = dataset_dir
|
||||||
|
@ -2494,6 +2504,7 @@ class ImageFolderDatasetV2(MappableDataset):
|
||||||
self.decode = decode
|
self.decode = decode
|
||||||
self.num_shards = num_shards
|
self.num_shards = num_shards
|
||||||
self.shard_id = shard_id
|
self.shard_id = shard_id
|
||||||
|
self.cache = cache
|
||||||
|
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
args = super().get_args()
|
args = super().get_args()
|
||||||
|
@ -2506,6 +2517,7 @@ class ImageFolderDatasetV2(MappableDataset):
|
||||||
args["decode"] = self.decode
|
args["decode"] = self.decode
|
||||||
args["num_shards"] = self.num_shards
|
args["num_shards"] = self.num_shards
|
||||||
args["shard_id"] = self.shard_id
|
args["shard_id"] = self.shard_id
|
||||||
|
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def get_dataset_size(self):
|
def get_dataset_size(self):
|
||||||
|
@ -3251,6 +3263,7 @@ class TFRecordDataset(SourceDataset):
|
||||||
argument should be specified only when num_shards is also specified.
|
argument should be specified only when num_shards is also specified.
|
||||||
shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number
|
shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number
|
||||||
of rows of each shard may be not equal.
|
of rows of each shard may be not equal.
|
||||||
|
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore.dataset as ds
|
>>> import mindspore.dataset as ds
|
||||||
>>> import mindspore.common.dtype as mstype
|
>>> import mindspore.common.dtype as mstype
|
||||||
|
@ -3268,7 +3281,7 @@ class TFRecordDataset(SourceDataset):
|
||||||
|
|
||||||
@check_tfrecorddataset
|
@check_tfrecorddataset
|
||||||
def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
|
def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
|
||||||
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False):
|
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
self.dataset_files = self._find_files(dataset_files)
|
self.dataset_files = self._find_files(dataset_files)
|
||||||
self.dataset_files.sort()
|
self.dataset_files.sort()
|
||||||
|
@ -3280,6 +3293,7 @@ class TFRecordDataset(SourceDataset):
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.columns_list = columns_list
|
self.columns_list = columns_list
|
||||||
self.num_samples = num_samples
|
self.num_samples = num_samples
|
||||||
|
self.cache = cache
|
||||||
if schema_obj is not None and num_samples is None:
|
if schema_obj is not None and num_samples is None:
|
||||||
self.num_samples = schema_obj.num_rows
|
self.num_samples = schema_obj.num_rows
|
||||||
|
|
||||||
|
@ -3295,6 +3309,14 @@ class TFRecordDataset(SourceDataset):
|
||||||
else:
|
else:
|
||||||
self.shuffle_level = shuffle
|
self.shuffle_level = shuffle
|
||||||
self.shuffle_files = True
|
self.shuffle_files = True
|
||||||
|
|
||||||
|
# The TF record dataset does not directly support a sampler. It has provided sampling arguments
|
||||||
|
# (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
|
||||||
|
# the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
|
||||||
|
sampler_shuffle = self.shuffle_files
|
||||||
|
sampler = None
|
||||||
|
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
|
||||||
|
non_mappable=True)
|
||||||
self.shard_equal_rows = shard_equal_rows
|
self.shard_equal_rows = shard_equal_rows
|
||||||
|
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
|
@ -3318,6 +3340,8 @@ class TFRecordDataset(SourceDataset):
|
||||||
args["num_shards"] = self.num_shards
|
args["num_shards"] = self.num_shards
|
||||||
args["shard_id"] = self.shard_id
|
args["shard_id"] = self.shard_id
|
||||||
args["shard_equal_rows"] = self.shard_equal_rows
|
args["shard_equal_rows"] = self.shard_equal_rows
|
||||||
|
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||||
|
args["sampler"] = self.sampler
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def get_dataset_size(self, estimate=False):
|
def get_dataset_size(self, estimate=False):
|
||||||
|
@ -3803,43 +3827,61 @@ class RandomDataset(SourceDataset):
|
||||||
A source dataset that generates random data.
|
A source dataset that generates random data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_samples (int): number of samples to generate.
|
total_rows (int): number of rows for the dataset to generate (default=None, number of rows is random)
|
||||||
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
|
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
|
||||||
If the schema is not provided, the random dataset generates a random schema.
|
If the schema is not provided, the random dataset generates a random schema.
|
||||||
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
|
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
|
||||||
|
num_samples (int): number of samples to draw from the total. (default=None, which means all rows)
|
||||||
num_parallel_workers (int, optional): number of workers to read the data
|
num_parallel_workers (int, optional): number of workers to read the data
|
||||||
(default=None, number set in the config).
|
(default=None, number set in the config).
|
||||||
|
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
|
||||||
|
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||||
|
(default=None, expected order behavior shown in the table).
|
||||||
|
num_shards (int, optional): Number of shards that the dataset should be divided
|
||||||
|
into (default=None).
|
||||||
|
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||||
|
argument should be specified only when num_shards is also specified.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None):
|
@check_random_dataset
|
||||||
|
def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
|
||||||
|
cache=None, shuffle=None, num_shards=None, shard_id=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
schema_obj = None
|
schema_obj = None
|
||||||
if (schema is not None) and (not isinstance(schema, Schema)):
|
if (schema is not None) and (not isinstance(schema, Schema)):
|
||||||
schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
|
schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.columns_list = columns_list
|
self.columns_list = columns_list
|
||||||
if schema_obj is not None and num_samples is None:
|
sampler = None
|
||||||
self.num_samples = schema_obj.num_rows
|
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True)
|
||||||
elif num_samples is None:
|
self.num_samples = num_samples
|
||||||
self.num_samples = 0
|
self.cache = cache
|
||||||
|
if schema_obj is not None and total_rows is None:
|
||||||
|
self.total_rows = schema_obj.num_rows
|
||||||
|
elif total_rows is None:
|
||||||
|
self.total_rows = 0
|
||||||
else:
|
else:
|
||||||
self.num_samples = num_samples
|
self.total_rows = total_rows
|
||||||
|
self.num_shards = num_shards
|
||||||
|
self.shard_id = shard_id
|
||||||
|
self.shuffle_level = shuffle
|
||||||
|
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
args = super().get_args()
|
args = super().get_args()
|
||||||
if self.schema is not None:
|
if self.schema is not None:
|
||||||
if isinstance(self.schema, Schema):
|
if isinstance(self.schema, Schema):
|
||||||
self.schema.datasetType = 'Random'
|
self.schema.datasetType = 'Random'
|
||||||
if self.num_samples is not None:
|
if self.total_rows is not None:
|
||||||
self.schema.num_rows = self.num_samples
|
self.schema.num_rows = self.total_rows
|
||||||
args["schema_json_string"] = self.schema.to_json()
|
args["schema_json_string"] = self.schema.to_json()
|
||||||
else:
|
else:
|
||||||
args["schema_file_path"] = self.schema
|
args["schema_file_path"] = self.schema
|
||||||
args["schema"] = self.schema
|
args["schema"] = self.schema
|
||||||
if self.columns_list is not None:
|
args["columns_list"] = self.columns_list
|
||||||
args["columns_list"] = self.columns_list
|
args["num_samples"] = self.num_samples
|
||||||
if self.num_samples is not None:
|
args["total_rows"] = self.total_rows
|
||||||
args["num_samples"] = self.num_samples
|
args["cache"] = self.cache.cache_client if self.cache is not None else None
|
||||||
|
args["sampler"] = self.sampler
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def get_dataset_size(self):
|
def get_dataset_size(self):
|
||||||
|
@ -3849,18 +3891,29 @@ class RandomDataset(SourceDataset):
|
||||||
Return:
|
Return:
|
||||||
Number, number of batches.
|
Number, number of batches.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
|
||||||
|
|
||||||
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||||
rows_from_sampler = self._get_sampler_dataset_size()
|
rows_from_sampler = self._get_sampler_dataset_size()
|
||||||
|
|
||||||
if rows_from_sampler is None:
|
if rows_from_sampler is None:
|
||||||
return self.num_samples
|
return rows_per_shard
|
||||||
|
|
||||||
return min(rows_from_sampler, self.num_samples)
|
return min(rows_from_sampler, rows_per_shard)
|
||||||
|
|
||||||
def is_shuffled(self):
|
def is_shuffled(self):
|
||||||
return True
|
if self.shuffle_level is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return self.shuffle_level or self.sampler.is_shuffled()
|
||||||
|
|
||||||
def is_sharded(self):
|
def is_sharded(self):
|
||||||
return False
|
if self.num_shards is not None:
|
||||||
|
return self.num_shards > 1
|
||||||
|
|
||||||
|
return self.sampler.is_sharded()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Schema:
|
class Schema:
|
||||||
|
|
|
@ -173,7 +173,9 @@ def traverse(node):
|
||||||
# num_samples, shard_id, num_shards, shuffle
|
# num_samples, shard_id, num_shards, shuffle
|
||||||
# These arguments get moved into the sampler itself, so they are no longer needed to
|
# These arguments get moved into the sampler itself, so they are no longer needed to
|
||||||
# be set at the dataset level.
|
# be set at the dataset level.
|
||||||
if 'sampler' in node_args.keys():
|
# TF Record is a special case because it uses both the dataset and sampler arguments
|
||||||
|
# which is not decided until later during tree preparation phase.
|
||||||
|
if node_repr['op_type'] != 'TFRecordDataset' and 'sampler' in node_args.keys():
|
||||||
if 'num_samples' in node_repr.keys():
|
if 'num_samples' in node_repr.keys():
|
||||||
node_repr['num_samples'] = None
|
node_repr['num_samples'] = None
|
||||||
if 'shuffle' in node_repr.keys():
|
if 'shuffle' in node_repr.keys():
|
||||||
|
|
|
@ -29,10 +29,11 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
|
||||||
|
|
||||||
from . import datasets
|
from . import datasets
|
||||||
from . import samplers
|
from . import samplers
|
||||||
|
from . import cache_client
|
||||||
|
|
||||||
|
|
||||||
def check_imagefolderdatasetv2(method):
|
def check_imagefolderdatasetv2(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -58,7 +59,7 @@ def check_imagefolderdatasetv2(method):
|
||||||
|
|
||||||
|
|
||||||
def check_mnist_cifar_dataset(method):
|
def check_mnist_cifar_dataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -81,7 +82,7 @@ def check_mnist_cifar_dataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_manifestdataset(method):
|
def check_manifestdataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -108,7 +109,7 @@ def check_manifestdataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_tfrecorddataset(method):
|
def check_tfrecorddataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -134,7 +135,7 @@ def check_tfrecorddataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_vocdataset(method):
|
def check_vocdataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(VOCDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(VOCDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -175,7 +176,7 @@ def check_vocdataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_cocodataset(method):
|
def check_cocodataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(CocoDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(CocoDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -211,7 +212,7 @@ def check_cocodataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_celebadataset(method):
|
def check_celebadataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(CelebADataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(CelebADataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -247,7 +248,7 @@ def check_celebadataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_minddataset(method):
|
def check_minddataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(MindDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -279,7 +280,7 @@ def check_minddataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_generatordataset(method):
|
def check_generatordataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -344,6 +345,27 @@ def check_generatordataset(method):
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
def check_random_dataset(method):
|
||||||
|
"""A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
|
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows']
|
||||||
|
nreq_param_bool = ['shuffle']
|
||||||
|
nreq_param_list = ['columns_list']
|
||||||
|
|
||||||
|
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||||
|
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||||
|
validate_dataset_param_value(nreq_param_list, param_dict, list)
|
||||||
|
|
||||||
|
check_sampler_shuffle_shard_options(param_dict)
|
||||||
|
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
def check_pad_info(key, val):
|
def check_pad_info(key, val):
|
||||||
"""check the key and value pair of pad_info in batch"""
|
"""check the key and value pair of pad_info in batch"""
|
||||||
|
@ -506,7 +528,7 @@ def check_map(method):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
[input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing], _ = \
|
[input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \
|
||||||
parse_user_args(method, *args, **kwargs)
|
parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
nreq_param_columns = ['input_columns', 'output_columns']
|
nreq_param_columns = ['input_columns', 'output_columns']
|
||||||
|
@ -516,6 +538,8 @@ def check_map(method):
|
||||||
if num_parallel_workers is not None:
|
if num_parallel_workers is not None:
|
||||||
check_num_parallel_workers(num_parallel_workers)
|
check_num_parallel_workers(num_parallel_workers)
|
||||||
type_check(python_multiprocessing, (bool,), "python_multiprocessing")
|
type_check(python_multiprocessing, (bool,), "python_multiprocessing")
|
||||||
|
if cache is not None:
|
||||||
|
type_check(cache, (cache_client.DatasetCache,), "cache")
|
||||||
|
|
||||||
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
|
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
|
||||||
if param is not None:
|
if param is not None:
|
||||||
|
@ -720,7 +744,7 @@ def check_add_column(method):
|
||||||
|
|
||||||
|
|
||||||
def check_cluedataset(method):
|
def check_cluedataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -750,7 +774,7 @@ def check_cluedataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_textfiledataset(method):
|
def check_textfiledataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -823,7 +847,7 @@ def check_gnn_graphdata(method):
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_get_all_nodes(method):
|
def check_gnn_get_all_nodes(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
|
"""A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -836,7 +860,7 @@ def check_gnn_get_all_nodes(method):
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_get_all_edges(method):
|
def check_gnn_get_all_edges(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_edges` function."""
|
"""A wrapper that wraps a parameter checker to the GNN `get_all_edges` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -849,7 +873,7 @@ def check_gnn_get_all_edges(method):
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_get_nodes_from_edges(method):
|
def check_gnn_get_nodes_from_edges(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function."""
|
"""A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -862,7 +886,7 @@ def check_gnn_get_nodes_from_edges(method):
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_get_all_neighbors(method):
|
def check_gnn_get_all_neighbors(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function."""
|
"""A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -877,7 +901,7 @@ def check_gnn_get_all_neighbors(method):
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_get_sampled_neighbors(method):
|
def check_gnn_get_sampled_neighbors(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function."""
|
"""A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -905,7 +929,7 @@ def check_gnn_get_sampled_neighbors(method):
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_get_neg_sampled_neighbors(method):
|
def check_gnn_get_neg_sampled_neighbors(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
|
"""A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -921,7 +945,7 @@ def check_gnn_get_neg_sampled_neighbors(method):
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_random_walk(method):
|
def check_gnn_random_walk(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `random_walk` function."""
|
"""A wrapper that wraps a parameter checker to the GNN `random_walk` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -968,7 +992,7 @@ def check_aligned_list(param, param_name, member_type):
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_get_node_feature(method):
|
def check_gnn_get_node_feature(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
|
"""A wrapper that wraps a parameter checker to the GNN `get_node_feature` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -1012,7 +1036,7 @@ def check_gnn_get_edge_feature(method):
|
||||||
|
|
||||||
|
|
||||||
def check_numpyslicesdataset(method):
|
def check_numpyslicesdataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset)."""
|
"""A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
|
|
@ -39,7 +39,7 @@ def check_unique_list_of_words(words, arg_name):
|
||||||
|
|
||||||
|
|
||||||
def check_lookup(method):
|
def check_lookup(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function."""
|
"""A wrapper that wraps a parameter checker to the original function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -56,7 +56,7 @@ def check_lookup(method):
|
||||||
|
|
||||||
|
|
||||||
def check_from_file(method):
|
def check_from_file(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function."""
|
"""A wrapper that wraps a parameter checker to the original function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -74,7 +74,7 @@ def check_from_file(method):
|
||||||
|
|
||||||
|
|
||||||
def check_from_list(method):
|
def check_from_list(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function."""
|
"""A wrapper that wraps a parameter checker to the original function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -97,7 +97,7 @@ def check_from_list(method):
|
||||||
|
|
||||||
|
|
||||||
def check_from_dict(method):
|
def check_from_dict(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function."""
|
"""A wrapper that wraps a parameter checker to the original function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -285,7 +285,7 @@ def check_bert_tokenizer(method):
|
||||||
|
|
||||||
|
|
||||||
def check_from_dataset(method):
|
def check_from_dataset(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function."""
|
"""A wrapper that wraps a parameter checker to the original function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -328,7 +328,7 @@ def check_from_dataset(method):
|
||||||
|
|
||||||
|
|
||||||
def check_ngram(method):
|
def check_ngram(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function."""
|
"""A wrapper that wraps a parameter checker to the original function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
|
|
@ -114,7 +114,7 @@ def check_erasing_value(value):
|
||||||
|
|
||||||
|
|
||||||
def check_crop(method):
|
def check_crop(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function(crop operation)."""
|
"""A wrapper that wraps a parameter checker to the original function(crop operation)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -127,7 +127,7 @@ def check_crop(method):
|
||||||
|
|
||||||
|
|
||||||
def check_resize_interpolation(method):
|
def check_resize_interpolation(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function(resize interpolation operation)."""
|
"""A wrapper that wraps a parameter checker to the original function(resize interpolation operation)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -142,7 +142,7 @@ def check_resize_interpolation(method):
|
||||||
|
|
||||||
|
|
||||||
def check_resize(method):
|
def check_resize(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function(resize operation)."""
|
"""A wrapper that wraps a parameter checker to the original function(resize operation)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -155,7 +155,7 @@ def check_resize(method):
|
||||||
|
|
||||||
|
|
||||||
def check_random_resize_crop(method):
|
def check_random_resize_crop(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function(random resize crop operation)."""
|
"""A wrapper that wraps a parameter checker to the original function(random resize crop operation)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -178,7 +178,7 @@ def check_random_resize_crop(method):
|
||||||
|
|
||||||
|
|
||||||
def check_prob(method):
|
def check_prob(method):
|
||||||
"""A wrapper that wrap a parameter checker(check the probability) to the original function."""
|
"""A wrapper that wraps a parameter checker(check the probability) to the original function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -192,7 +192,7 @@ def check_prob(method):
|
||||||
|
|
||||||
|
|
||||||
def check_normalize_c(method):
|
def check_normalize_c(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function(normalize operation written in C++)."""
|
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in C++)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
@ -205,7 +205,7 @@ def check_normalize_c(method):
|
||||||
|
|
||||||
|
|
||||||
def check_normalize_py(method):
|
def check_normalize_py(method):
|
||||||
"""A wrapper that wrap a parameter checker to the original function(normalize operation written in Python)."""
|
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in Python)."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
|
|
|
@ -738,7 +738,7 @@ TEST_F(MindDataTestPipeline, TestProjectMap) {
|
||||||
EXPECT_TRUE(ds != nullptr);
|
EXPECT_TRUE(ds != nullptr);
|
||||||
|
|
||||||
// Create a Project operation on ds
|
// Create a Project operation on ds
|
||||||
std::vector<std::string> column_project = {"label"};
|
std::vector<std::string> column_project = {"image"};
|
||||||
ds = ds->Project(column_project);
|
ds = ds->Project(column_project);
|
||||||
EXPECT_TRUE(ds != nullptr);
|
EXPECT_TRUE(ds != nullptr);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,579 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <string>
|
||||||
|
#include "dataset/core/client.h"
|
||||||
|
#include "dataset/engine/cache/cache_client.h"
|
||||||
|
#include "dataset/engine/execution_tree.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_lookup_op.h"
|
||||||
|
#include "dataset/engine/datasetops/cache_merge_op.h"
|
||||||
|
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "dataset/util/storage_container.h" // lint !e322
|
||||||
|
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||||
|
#include "dataset/engine/data_schema.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::LogStream;
|
||||||
|
using mindspore::dataset::CacheClient;
|
||||||
|
using mindspore::dataset::TaskGroup;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
|
||||||
|
class MindDataTestCacheOp : public UT::DatasetOpTesting {
|
||||||
|
public:
|
||||||
|
void SetUp() override {
|
||||||
|
DatasetOpTesting::SetUp();
|
||||||
|
GlobalInit();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MindDataTestCacheOp, TestCacheServer) {
|
||||||
|
Status rc;
|
||||||
|
CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true
|
||||||
|
// cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
|
||||||
|
rc = myClient.CreateCache(1, true);
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
std::cout << myClient << std::endl;
|
||||||
|
|
||||||
|
// Create a schema using the C api's
|
||||||
|
int32_t rank = 0; // not used
|
||||||
|
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
|
||||||
|
// 2 columns. First column is an "image" 640,480,3
|
||||||
|
TensorShape c1Shape({640, 480, 3});
|
||||||
|
ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
|
||||||
|
rank, // not used
|
||||||
|
&c1Shape);
|
||||||
|
// Column 2 will just be a scalar label number
|
||||||
|
TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
|
||||||
|
ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
|
||||||
|
|
||||||
|
testSchema->AddColumn(c1);
|
||||||
|
testSchema->AddColumn(c2);
|
||||||
|
|
||||||
|
std::unordered_map<std::string, int32_t> map;
|
||||||
|
rc = testSchema->GetColumnNameMap(&map);
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
|
||||||
|
// Test the CacheSchema api
|
||||||
|
rc = myClient.CacheSchema(map);
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
|
||||||
|
// Create a tensor, take a snapshot and restore it back, and compare.
|
||||||
|
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
|
||||||
|
t->SetItemAt<uint64_t>({0, 0}, 1);
|
||||||
|
t->SetItemAt<uint64_t>({0, 1}, 2);
|
||||||
|
t->SetItemAt<uint64_t>({0, 2}, 3);
|
||||||
|
t->SetItemAt<uint64_t>({1, 0}, 4);
|
||||||
|
t->SetItemAt<uint64_t>({1, 1}, 5);
|
||||||
|
t->SetItemAt<uint64_t>({1, 2}, 6);
|
||||||
|
std::cout << *t << std::endl;
|
||||||
|
TensorTable tbl;
|
||||||
|
TensorRow row;
|
||||||
|
row.push_back(t);
|
||||||
|
int64_t row_id;
|
||||||
|
rc = myClient.WriteRow(row, &row_id);
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
|
||||||
|
// Switch off build phase.
|
||||||
|
rc = myClient.BuildPhaseDone();
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
|
||||||
|
// Now restore from cache.
|
||||||
|
row.clear();
|
||||||
|
rc = myClient.GetRows({row_id}, &tbl);
|
||||||
|
row = tbl.front();
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
auto r = row.front();
|
||||||
|
std::cout << *r << std::endl;
|
||||||
|
// Compare
|
||||||
|
bool cmp = (*t == *r);
|
||||||
|
EXPECT_TRUE(cmp);
|
||||||
|
|
||||||
|
// Get back the schema and verify
|
||||||
|
std::unordered_map<std::string, int32_t> map_out;
|
||||||
|
rc = myClient.FetchSchema(&map_out);
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
cmp = (map_out == map);
|
||||||
|
EXPECT_TRUE(cmp);
|
||||||
|
|
||||||
|
// Test Purge and Destroy
|
||||||
|
rc = myClient.PurgeCache();
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
rc = myClient.DestroyCache();
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) {
  // Writes the same tensor row into one cache from 10 concurrent tasks (500 writes each),
  // then reads every row back by id and compares it with the original tensor.

  // Clear the rc of the master thread if any
  (void)TaskManager::GetMasterThreadRc();
  TaskGroup vg;
  Status rc;
  CacheClient myClient(1, 1, true);  // use arbitrary session of 1, size 1, spilling is true
  // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
  rc = myClient.CreateCache(1, true);
  EXPECT_TRUE(rc.IsOk());
  std::cout << myClient << std::endl;

  // Build a 2x3 uint64 tensor holding the values 1..6.
  std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
  t->SetItemAt<uint64_t>({0, 0}, 1);
  t->SetItemAt<uint64_t>({0, 1}, 2);
  t->SetItemAt<uint64_t>({0, 2}, 3);
  t->SetItemAt<uint64_t>({1, 0}, 4);
  t->SetItemAt<uint64_t>({1, 1}, 5);
  t->SetItemAt<uint64_t>({1, 2}, 6);
  TensorTable tbl;
  TensorRow row;
  row.push_back(t);

  // Cache tensor row t 5000 times using 10 threads.
  for (auto k = 0; k < 10; ++k) {
    Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status {
      TaskManager::FindMe()->Post();
      for (auto i = 0; i < 500; i++) {
        RETURN_IF_NOT_OK(myClient.WriteRow(row));
      }
      return Status::OK();
    });
    EXPECT_TRUE(vg_rc.IsOk());
  }
  ASSERT_TRUE(vg.join_all().IsOk());
  ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk());
  rc = myClient.BuildPhaseDone();
  ASSERT_TRUE(rc.IsOk());

  // Get statistics from the server.
  CacheClient::ServiceStat stat{};
  rc = myClient.GetStat(&stat);
  ASSERT_TRUE(rc.IsOk());
  std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached
            << "\n";
  // Expect there are 5000 rows there.
  EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1);

  // Get them all back using row id and compare with tensor t.
  for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) {
    tbl.clear();
    row.clear();
    rc = myClient.GetRows({i}, &tbl);
    EXPECT_TRUE(rc.IsOk());
    // front() on an empty container is undefined behavior, so never dereference the
    // result of a failed or empty lookup. Stop here so the cache is still destroyed below.
    if (rc.IsError() || tbl.empty()) {
      ADD_FAILURE() << "No row returned for row id " << i;
      break;
    }
    row = tbl.front();
    if (row.empty()) {
      ADD_FAILURE() << "Empty row returned for row id " << i;
      break;
    }
    auto r = row.front();
    bool cmp = (*t == *r);
    EXPECT_TRUE(cmp);
  }
  rc = myClient.DestroyCache();
  EXPECT_TRUE(rc.IsOk());
}
|
||||||
|
|
||||||
|
// Simple test with a repeated cache op over random data producer
|
||||||
|
//
|
||||||
|
// RepeatOp
|
||||||
|
// |
|
||||||
|
// CacheOp
|
||||||
|
// |
|
||||||
|
// RandomDataOp
|
||||||
|
//
|
||||||
|
TEST_F(MindDataTestCacheOp, TestRandomDataCache1) {
  // Builds RepeatOp -> CacheOp -> RandomDataOp, runs it, and checks that
  // 50 rows repeated 4 times yield 200 fetched rows.
  Status rc;
  int32_t rank = 0;  // not used
  MS_LOG(INFO) << "UT test TestRandomDataCache1";
  // Start with an empty execution tree
  auto myTree = std::make_shared<ExecutionTree>();

  // Create a schema using the C api's
  std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();

  // 2 columns. First column is an "image" 640,480,3
  TensorShape c1Shape({640, 480, 3});
  ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
                   rank,  // not used
                   &c1Shape);

  // Column 2 will just be a scalar label number
  TensorShape c2Shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);

  testSchema->AddColumn(c1);
  testSchema->AddColumn(c2);

  // RandomDataOp
  std::shared_ptr<RandomDataOp> myRandomDataOp;
  rc = RandomDataOp::Builder()
         .SetRowsPerBuffer(4)
         .SetNumWorkers(4)
         .SetDataSchema(std::move(testSchema))
         .SetTotalRows(50)  // 50 samples for now
         .Build(&myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());

  // CacheOp
  // size of 0, spilling is true
  std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
  std::shared_ptr<CacheOp> myCacheOp;

  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  rc = CacheOp::Builder()
         .SetNumWorkers(5)
         .SetClient(myClient)
         .SetRowsPerBuffer(4)
         .SetSampler(std::move(seq_sampler))
         .Build(&myCacheOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myCacheOp);
  EXPECT_TRUE(rc.IsOk());

  // RepeatOp
  uint32_t numRepeats = 4;
  std::shared_ptr<RepeatOp> myRepeatOp;
  rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRepeatOp);
  EXPECT_TRUE(rc.IsOk());

  // Assign tree relations and root
  rc = myRepeatOp->AddChild(myCacheOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssignRoot(myRepeatOp);
  EXPECT_TRUE(rc.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration";
  rc = myTree->Prepare();
  EXPECT_TRUE(rc.IsOk());

  // quick check to see what tree looks like
  std::ostringstream ss;
  ss << *myTree;  // some funny const error if I try to write directly to ms log stream
  MS_LOG(INFO) << "Here's the tree:\n" << ss.str();

  std::cout << *myClient << std::endl;

  rc = myTree->Launch();
  EXPECT_TRUE(rc.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator dI(myTree);
  TensorRow tensorList;
  rc = dI.FetchNextTensorRow(&tensorList);
  EXPECT_TRUE(rc.IsOk());
  int rowCount = 0;
  while (!tensorList.empty()) {
    // Don't display these rows, just count them
    MS_LOG(INFO) << "Row fetched #: " << rowCount;
    rc = dI.FetchNextTensorRow(&tensorList);
    EXPECT_TRUE(rc.IsOk());
    if (rc.IsError()) {
      // A failed fetch gives no guarantee (visible here) that tensorList was emptied,
      // so report the error and stop instead of looping on stale data.
      std::cout << rc << std::endl;
      break;
    }
    rowCount++;
  }
  ASSERT_EQ(rowCount, 200);
  rc = myClient->DestroyCache();
  EXPECT_TRUE(rc.IsOk());
}
|
||||||
|
|
||||||
|
//// Simple test with a repeated cache op over random data producer.
|
||||||
|
//// This one will exceed memory and require a spill.
|
||||||
|
////
|
||||||
|
//// RepeatOp
|
||||||
|
//// |
|
||||||
|
//// CacheOp
|
||||||
|
//// |
|
||||||
|
//// RandomDataOp
|
||||||
|
////
|
||||||
|
TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) {
  // Builds RepeatOp -> CacheOp -> RandomDataOp with a cache client of size 4 and spilling on,
  // then checks that 10 rows repeated 4 times yield 40 fetched rows.
  Status rc;
  int32_t rank = 0;  // not used
  MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
  // Start with an empty execution tree
  auto myTree = std::make_shared<ExecutionTree>();

  // Create a schema using the C api's
  std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();

  // 2 columns. First column is an "image" 640,480,3
  TensorShape c1Shape({640, 480, 3});
  ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
                   rank,  // not used
                   &c1Shape);

  // Column 2 will just be a scalar label number
  TensorShape c2Shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);

  testSchema->AddColumn(c1);
  testSchema->AddColumn(c2);

  // RandomDataOp
  std::shared_ptr<RandomDataOp> myRandomDataOp;
  rc = RandomDataOp::Builder()
         .SetRowsPerBuffer(2)
         .SetNumWorkers(4)
         .SetDataSchema(std::move(testSchema))
         .SetTotalRows(10)
         .Build(&myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());

  // CacheOp
  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
  std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true);
  std::shared_ptr<CacheOp> myCacheOp;
  rc = CacheOp::Builder()
         .SetNumWorkers(4)
         .SetClient(myClient)
         .SetRowsPerBuffer(3)
         .SetSampler(std::move(seq_sampler))
         .Build(&myCacheOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myCacheOp);
  EXPECT_TRUE(rc.IsOk());

  // RepeatOp
  uint32_t numRepeats = 4;
  std::shared_ptr<RepeatOp> myRepeatOp;
  rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRepeatOp);
  EXPECT_TRUE(rc.IsOk());

  // Assign tree relations and root
  rc = myRepeatOp->AddChild(myCacheOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssignRoot(myRepeatOp);
  EXPECT_TRUE(rc.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration";
  rc = myTree->Prepare();
  EXPECT_TRUE(rc.IsOk());

  std::cout << *myClient << std::endl;

  rc = myTree->Launch();
  EXPECT_TRUE(rc.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator dI(myTree);
  TensorRow tensorList;
  rc = dI.FetchNextTensorRow(&tensorList);
  EXPECT_TRUE(rc.IsOk());
  int rowCount = 0;
  while (!tensorList.empty()) {
    // Don't display these rows, just count them
    MS_LOG(INFO) << "Row fetched #: " << rowCount;
    rc = dI.FetchNextTensorRow(&tensorList);
    EXPECT_TRUE(rc.IsOk());
    if (rc.IsError()) {
      // A failed fetch gives no guarantee (visible here) that tensorList was emptied,
      // so report the error and stop instead of looping on stale data.
      std::cout << rc << std::endl;
      break;
    }
    rowCount++;
  }
  ASSERT_EQ(rowCount, 40);
  rc = myClient->DestroyCache();
  EXPECT_TRUE(rc.IsOk());
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
  // Pipeline under test: RepeatOp over a CacheMergeOp whose two children are a
  // CacheLookupOp and an ImageFolderOp that uses the lookup op as its sampler.
  Status rc;
  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);

  std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);

  // Merge op
  std::shared_ptr<CacheMergeOp> mergeOp;
  rc = CacheMergeOp::Builder()
         .SetNumWorkers(3)
         .SetOpConnectorSize(3)
         .SetNumCleaner(2)
         .SetClient(myClient)
         .Build(&mergeOp);
  EXPECT_TRUE(rc.IsOk());

  // Lookup op, driven by the sequential sampler
  std::shared_ptr<CacheLookupOp> lookupOp;
  rc = CacheLookupOp::Builder()
         .SetNumWorkers(3)
         .SetOpConnectorSize(3)
         .SetClient(myClient)
         .SetSampler(seq_sampler)
         .Build(&lookupOp);
  EXPECT_TRUE(rc.IsOk());

  // Image folder leaf; the lookup op is handed over as its sampler
  std::shared_ptr<ImageFolderOp> imageFolderOp;
  ImageFolderOp::Builder folderBuilder;
  folderBuilder.SetSampler(lookupOp)
    .SetOpConnectorSize(3)
    .SetNumWorkers(3)
    .SetRowsPerBuffer(2)
    .SetExtensions({".jpg", ".JPEG"})
    .SetRecursive(true)
    .SetImageFolderDir(datasets_root_path_ + "/testPK/data");
  rc = folderBuilder.Build(&imageFolderOp);
  EXPECT_TRUE(rc.IsOk());

  // RepeatOp
  uint32_t repeatCount = 4;
  std::shared_ptr<RepeatOp> repeatOp;
  rc = RepeatOp::Builder(repeatCount).Build(&repeatOp);
  EXPECT_TRUE(rc.IsOk());

  // Register every op with the tree, then pick the root
  auto tree = std::make_shared<ExecutionTree>();
  rc = tree->AssociateNode(imageFolderOp);
  EXPECT_TRUE(rc.IsOk());
  rc = tree->AssociateNode(lookupOp);
  EXPECT_TRUE(rc.IsOk());
  rc = tree->AssociateNode(mergeOp);
  EXPECT_TRUE(rc.IsOk());
  rc = tree->AssociateNode(repeatOp);
  EXPECT_TRUE(rc.IsOk());
  rc = tree->AssignRoot(repeatOp);
  EXPECT_TRUE(rc.IsOk());

  // Wire up parent/child relations
  rc = repeatOp->AddChild(mergeOp);
  EXPECT_TRUE(rc.IsOk());
  rc = mergeOp->AddChild(lookupOp);
  EXPECT_TRUE(rc.IsOk());
  rc = mergeOp->AddChild(imageFolderOp);
  EXPECT_TRUE(rc.IsOk());

  rc = tree->Prepare();
  EXPECT_TRUE(rc.IsOk());
  rc = tree->Launch();
  EXPECT_TRUE(rc.IsOk());

  // Drain the pipeline, counting one row per successful fetch
  DatasetIterator rowIter(tree);
  TensorRow fetchedRow;
  rc = rowIter.FetchNextTensorRow(&fetchedRow);
  EXPECT_TRUE(rc.IsOk());
  int rowCount = 0;
  for (; !fetchedRow.empty(); ++rowCount) {
    rc = rowIter.FetchNextTensorRow(&fetchedRow);
    EXPECT_TRUE(rc.IsOk());
    if (rc.IsError()) {
      // Leave without counting the failed fetch
      std::cout << rc << std::endl;
      break;
    }
  }
  ASSERT_EQ(rowCount, 176);
  std::cout << "Row count : " << rowCount << std::endl;
  rc = myClient->DestroyCache();
  EXPECT_TRUE(rc.IsOk());
}
|
||||||
|
|
||||||
|
//// Simple test with a repeated cache op over random data producer.
|
||||||
|
//// The difference in this one is that you do not add the sampler to the cache op directly.
|
||||||
|
//// Instead, the sampler is added as part of the leaf op construction. Then, the prepare
|
||||||
|
//// phase will pull this up from the leaf and into the cache.
|
||||||
|
//// It removes the sampler from the leaf op, which doesn't make sense there anyway for
|
||||||
|
//// the RandomDataOp which doesn't support sampling without a cache.
|
||||||
|
////
|
||||||
|
//// RepeatOp
|
||||||
|
//// |
|
||||||
|
//// CacheOp
|
||||||
|
//// |
|
||||||
|
//// RandomDataOp
|
||||||
|
////
|
||||||
|
TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) {
  // Builds RepeatOp -> CacheOp -> RandomDataOp where the sampler is given to the
  // RandomDataOp builder rather than to the CacheOp builder, then checks that
  // 10 rows repeated 4 times yield 40 fetched rows.
  Status rc;
  int32_t rank = 0;  // not used
  MS_LOG(INFO) << "UT test TestCacheInheritSampler";

  int64_t num_samples = 0;
  int64_t start_index = 0;
  auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);

  // Start with an empty execution tree
  auto myTree = std::make_shared<ExecutionTree>();

  // Create a schema using the C api's
  std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();

  // 2 columns. First column is an "image" 640,480,3
  TensorShape c1Shape({640, 480, 3});
  ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
                   rank,  // not used
                   &c1Shape);

  // Column 2 will just be a scalar label number
  TensorShape c2Shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);

  testSchema->AddColumn(c1);
  testSchema->AddColumn(c2);

  // RandomDataOp
  std::shared_ptr<RandomDataOp> myRandomDataOp;
  rc = RandomDataOp::Builder()
         .SetRowsPerBuffer(2)
         .SetNumWorkers(4)
         .SetDataSchema(std::move(testSchema))
         .SetTotalRows(10)
         .SetSampler(std::move(seq_sampler))
         .Build(&myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());

  // CacheOp
  std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true);
  std::shared_ptr<CacheOp> myCacheOp;
  rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myCacheOp);
  EXPECT_TRUE(rc.IsOk());

  // RepeatOp
  uint32_t numRepeats = 4;
  std::shared_ptr<RepeatOp> myRepeatOp;
  rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssociateNode(myRepeatOp);
  EXPECT_TRUE(rc.IsOk());

  // Assign tree relations and root
  rc = myRepeatOp->AddChild(myCacheOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myCacheOp->AddChild(myRandomDataOp);
  EXPECT_TRUE(rc.IsOk());
  rc = myTree->AssignRoot(myRepeatOp);
  EXPECT_TRUE(rc.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration";
  rc = myTree->Prepare();
  EXPECT_TRUE(rc.IsOk());

  std::cout << *myClient << std::endl;

  rc = myTree->Launch();
  EXPECT_TRUE(rc.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator dI(myTree);
  TensorRow tensorList;
  rc = dI.FetchNextTensorRow(&tensorList);
  EXPECT_TRUE(rc.IsOk());
  int rowCount = 0;
  while (!tensorList.empty()) {
    // Don't display these rows, just count them
    MS_LOG(INFO) << "Row fetched #: " << rowCount;
    rc = dI.FetchNextTensorRow(&tensorList);
    EXPECT_TRUE(rc.IsOk());
    if (rc.IsError()) {
      // A failed fetch gives no guarantee (visible here) that tensorList was emptied,
      // so report the error and stop instead of looping on stale data.
      std::cout << rc << std::endl;
      break;
    }
    rowCount++;
  }
  ASSERT_EQ(rowCount, 40);
  rc = myClient->DestroyCache();
  EXPECT_TRUE(rc.IsOk());
}
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,157 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Testing cache operator with mappable datasets
|
||||||
|
"""
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||||
|
from mindspore import log as logger
|
||||||
|
from util import save_and_check_md5
|
||||||
|
|
||||||
|
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||||
|
|
||||||
|
GENERATE_GOLDEN = False
|
||||||
|
|
||||||
|
def test_cache_map_basic1():
    """
    Mappable leaf with the cache attached directly to the leaf dataset.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map basic 1")

    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=cache)
    # Decode above the cache, then repeat the whole pipeline 4 times
    ds1 = ds1.map(input_columns=["image"], operations=c_vision.Decode()).repeat(4)

    # Compare the pipeline output against the stored golden file
    save_and_check_md5(ds1, "cache_map_01_result.npz", generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic1 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_map_basic2():
    """
    Mappable leaf with the cache attached to the map(decode) op instead of the leaf.

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map basic 2")

    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR)
    # The cache is passed to map(), then the pipeline is repeated 4 times
    ds1 = ds1.map(input_columns=["image"], operations=c_vision.Decode(), cache=cache).repeat(4)

    # Compare the pipeline output against the stored golden file
    save_and_check_md5(ds1, "cache_map_02_result.npz", generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic2 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_map_basic3():
    """
    Test a repeat under mappable cache

        Cache
          |
      Map(decode)
          |
        Repeat
          |
      ImageFolder
    """

    # Log text matches the function name, like the other tests in this file
    logger.info("Test cache map basic 3")

    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    # 2 images repeated 4 times
    assert num_iter == 8
    logger.info('test_cache_map_basic3 Ended.\n')
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_map_failure1():
    """
    Test nested cache (failure)

        Repeat
          |
        Cache
          |
      Map(decode)
          |
        Cache
          |
      ImageFolder

    """
    # Log text matches the function name, like the other tests in this file
    logger.info("Test cache map failure 1")

    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    # Initialised before the try so the final assert never depends on how far the try body got
    num_iter = 0
    try:
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Nested cache operations is not supported!" in str(e)

    # No row may be produced; this also fails the test if no exception was raised
    assert num_iter == 0
    logger.info('test_cache_map_failure1 Ended.\n')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Run every test in this file, in definition order, when executed as a script
    for run_test in (test_cache_map_basic1,
                     test_cache_map_basic2,
                     test_cache_map_basic3,
                     test_cache_map_failure1):
        run_test()
|
|
@ -0,0 +1,429 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Testing cache operator with non-mappable datasets
|
||||||
|
"""
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||||
|
from mindspore import log as logger
|
||||||
|
|
||||||
|
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||||
|
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||||
|
|
||||||
|
GENERATE_GOLDEN = False
|
||||||
|
|
||||||
|
def test_cache_nomap_basic1():
    """
    Random dataset (non-mappable) with the cache attached to the leaf, repeated 4 times.
    """

    logger.info("Test cache nomap basic 1")

    schema = ds.Schema()
    # 640 * 480 * 3 uint8 values = 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('image', de_type=mstype.uint8, shape=[640, 480, 3])
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # create a cache.  arbitrary session_id for now
    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # No sampler-related argument is passed to the dataset here
    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=cache).repeat(4)

    # Log each label; num_iter ends up as the number of rows seen (0 if none)
    num_iter = 0
    for num_iter, data in enumerate(ds1.create_dict_iterator(), start=1):
        logger.info("printing the label: {}".format(data["label"]))

    logger.info("Number of data in ds1: {} ".format(num_iter))
    # 10 rows repeated 4 times
    assert num_iter == 40
    logger.info("test_cache_nomap_basic1 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_basic2():
    """
    Random dataset (non-mappable) with the cache attached to the leaf, where
    num_samples is passed to the dataset, repeated 2 times.
    """

    logger.info("Test cache nomap basic 2")

    schema = ds.Schema()
    # 640 * 480 * 3 uint8 values = 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('image', de_type=mstype.uint8, shape=[640, 480, 3])
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # create a cache.  arbitrary session_id for now
    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # sampler arg not given directly, however any of these args will auto-generate an appropriate sampler:
    # num_samples, shuffle, num_shards, shard_id
    # In this case, the presence of num_samples chooses a sampler.
    ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4,
                           cache=cache).repeat(2)

    # Log each label; num_iter ends up as the number of rows seen (0 if none)
    num_iter = 0
    for num_iter, data in enumerate(ds1.create_dict_iterator(), start=1):
        logger.info("printing the label: {}".format(data["label"]))

    logger.info("Number of data in ds1: {} ".format(num_iter))
    # 20 rows repeated 2 times
    assert num_iter == 40
    logger.info("test_cache_nomap_basic2 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_basic3():
    """
    TF reader dataset (non-mappable) with shuffle disabled and the cache attached to the leaf.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """

    logger.info("Test cache nomap basic 3")

    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=cache)
    # Decode above the cache, then repeat the pipeline 4 times
    ds1 = ds1.map(input_columns=["image"], operations=c_vision.Decode()).repeat(4)

    # Count the rows produced by the pipeline
    num_iter = sum(1 for _ in ds1.create_dict_iterator())

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic3 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_basic4():
    """
    TF reader dataset (non-mappable) with global shuffle, followed by map(decode) with the cache on the map.
    A global shuffle on the tf reader injects a shuffle op over it, but once a cache sits
    above it that shuffle is no longer valid and should be removed.

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """

    logger.info("Test cache nomap basic 4")

    # This dataset has 3 records in it only
    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
    # When shuffle is not set, TF defaults to a "global" shuffle if no cache is involved, which
    # injects a shuffle op over the TF reader. For clarity, this test passes the global option
    # explicitly even though it is the python default.
    # With a cache above the TF reader, global shuffling is done by the sampler over the cache
    # rather than by a shuffle op, so tree prepare removes the shuffle op that the initial
    # tree creation injected.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL)

    ds1 = ds1.map(input_columns=["image"], operations=c_vision.Decode(), cache=cache).repeat(4)

    # Count the rows produced by the pipeline
    num_iter = sum(1 for _ in ds1.create_dict_iterator())

    logger.info("Number of data in ds1: {} ".format(num_iter))
    # 3 records repeated 4 times
    assert num_iter == 12
    logger.info("test_cache_nomap_basic4 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_basic5():
    """
    TF reader dataset (non-mappable) with the cache attached to the leaf and no shuffle argument.
    Without a shuffle arg the tf reader defaults to global shuffle, which would try to inject a
    shuffle operator. Because a cache is present, global shuffle is not needed and the shuffle
    is not built. The resulting tree is the same as in test basic 3, reached through a
    different codepath (without a cache, the shuffle IS built).

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """

    logger.info("Test cache nomap basic 5")

    # This dataset has 3 records in it only
    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=cache)
    # Decode above the cache, then repeat the pipeline 4 times
    ds1 = ds1.map(input_columns=["image"], operations=c_vision.Decode()).repeat(4)

    # Count the rows produced by the pipeline
    num_iter = sum(1 for _ in ds1.create_dict_iterator())

    logger.info("Number of data in ds1: {} ".format(num_iter))
    # 3 records repeated 4 times
    assert num_iter == 12
    logger.info("test_cache_nomap_basic5 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_basic6():
    """
    TF reader dataset (non-mappable) with the cache attached to the leaf and a sharding configuration.
    Because a cache is used, tree prepare should undo the sharding configuration on the
    tf dataset and choose a distributed sampler with the same shard config instead.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """

    logger.info("Test cache nomap basic 6")

    # This dataset has 3 records in it only
    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # 3 records split across 3 shards: only 1 record is expected for this shard.
    # The sharding is performed by the sampler rather than by the tf record leaf node, so it is
    # row-based, not the file-based sharding that would happen without a cache.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=cache)
    # Decode above the cache, then repeat the pipeline 4 times
    ds1 = ds1.map(input_columns=["image"], operations=c_vision.Decode()).repeat(4)

    # Count the rows produced by the pipeline
    num_iter = sum(1 for _ in ds1.create_dict_iterator())

    logger.info("Number of data in ds1: {} ".format(num_iter))
    # 1 record for this shard repeated 4 times
    assert num_iter == 4
    logger.info("test_cache_nomap_basic6 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_basic7():
    """
    Cache over a TFRecord leaf (a non-mappable dataset) that requests global shuffle,
    followed by a map.

    A TFRecord dataset with global shuffle might want to inject a shuffle op over top of
    the reader, but since a cache is given it is expected to choose not to.

       Repeat
         |
     Map(decode)
         |
       cache
         |
      TFReader
    """

    logger.info("Test cache nomap basic 7")

    # The underlying dataset holds only 3 records.
    cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    shuffled = ds.TFRecordDataset(
        DATA_DIR,
        SCHEMA_DIR,
        columns_list=["image"],
        shuffle=ds.Shuffle.GLOBAL,
        cache=cache,
    )
    decode = c_vision.Decode()
    pipeline = shuffled.map(input_columns=["image"], operations=decode).repeat(4)

    # Count every row produced by the full pipeline.
    row_count = sum(1 for _ in pipeline.create_dict_iterator())

    logger.info("Number of data in ds1: {} ".format(row_count))
    # 3 records repeated 4 times.
    assert row_count == 12
    logger.info("test_cache_nomap_basic7 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_allowed_share1():
    """
    Two trees that are allowed to share one cache:

       Repeat     Shuffle
         |           |
       Cache       Cache
         |           |
      TFReader    TFReader
    """

    logger.info("Test cache nomap allowed share 1")

    ds.config.set_seed(1)
    # The underlying dataset holds only 3 records.
    shared_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # Both trees are built before either one is iterated.
    repeated = ds.TFRecordDataset(
        DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=shared_cache
    ).repeat(4)

    shuffled = ds.TFRecordDataset(
        DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=shared_cache
    ).shuffle(buffer_size=2)

    # First tree: 3 records repeated 4 times.
    row_count = sum(1 for _ in repeated.create_dict_iterator())
    assert row_count == 12
    logger.info("Number of data in ds1: {} ".format(row_count))

    # Second tree: the same 3 records, shuffled, a single pass.
    row_count = sum(1 for _ in shuffled.create_dict_iterator())
    assert row_count == 3
    logger.info("test_cache_nomap_allowed_share1 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_allowed_share2():
    """
    Two trees (each with a decode map under the cache) that are allowed to share one cache:

       Repeat         Shuffle
         |               |
       Cache           Cache
         |               |
     Map(decode)     Map(decode)
         |               |
      TFReader        TFReader
    """

    logger.info("Test cache nomap allowed share 2")

    ds.config.set_seed(1)
    # The underlying dataset holds only 3 records.
    shared_cache = ds.DatasetCache(session_id=2, size=0, spilling=True)
    decode = c_vision.Decode()

    # Both trees are built before either one is iterated; the cache is attached to the map.
    repeated = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    repeated = repeated.map(input_columns=["image"], operations=decode, cache=shared_cache)
    repeated = repeated.repeat(4)

    shuffled = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    shuffled = shuffled.map(input_columns=["image"], operations=decode, cache=shared_cache)
    shuffled = shuffled.shuffle(buffer_size=2)

    # First tree: 3 records repeated 4 times.
    row_count = sum(1 for _ in repeated.create_dict_iterator())
    logger.info("Number of data in ds1: {} ".format(row_count))
    assert row_count == 12

    # Second tree: the same 3 records, shuffled, a single pass.
    row_count = sum(1 for _ in shuffled.create_dict_iterator())
    assert row_count == 3
    logger.info("test_cache_nomap_allowed_share2 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_allowed_share3():
    """
    Two trees that differ only in shard id and are allowed to share one cache:

               Repeat                     Repeat
                 |                          |
               Cache                      Cache
                 |                          |
       TFReader(shard_id = 0)     TFReader(shard_id = 1)
    """

    logger.info("Test cache nomap allowed share 3")

    shared_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]

    # Both trees are built before either one is iterated.
    shard0 = ds.TFRecordDataset(
        tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=shared_cache
    ).repeat(4)

    shard1 = ds.TFRecordDataset(
        tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=shared_cache
    ).repeat(4)

    # Shard 0: num_samples=3, repeated 4 times.
    row_count = sum(1 for _ in shard0.create_dict_iterator())
    logger.info("Number of data in ds1: {} ".format(row_count))
    assert row_count == 12

    # Shard 1: same expectation as shard 0.
    row_count = sum(1 for _ in shard1.create_dict_iterator())
    assert row_count == 12
    logger.info("test_cache_nomap_allowed_share3 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_nomap_disallowed_share1():
    """
    It is not allowed to share the cache between the following two trees:

       Cache           Cache
         |               |
     Map(decode)     Map(rescale)
         |               |
      TFReader        TFReader

    The first tree is iterated normally.  Iterating the second tree, which re-uses the same
    cache under a different map operation, must raise a RuntimeError carrying the
    "different tree" message; the test fails if no such error is raised.
    """

    logger.info("Test cache nomap disallowed share1")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
    decode_op = c_vision.Decode()
    rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(input_columns=["image"], operations=rescale_op, cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    try:
        sum([1 for _ in ds2])
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Attempt to re-use a cache for a different tree!" in str(e)
    else:
        # Without this branch the test would pass silently when the second tree is
        # (incorrectly) allowed to share the cache.
        raise AssertionError("Expected a RuntimeError when re-using a cache for a different tree")

    logger.info("test_cache_nomap_disallowed_share1 Ended.\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Run every cache test in this file, in declaration order, when executed as a script.
    for test_case in (
            test_cache_nomap_basic1,
            test_cache_nomap_basic2,
            test_cache_nomap_basic3,
            test_cache_nomap_basic4,
            test_cache_nomap_basic5,
            test_cache_nomap_basic6,
            test_cache_nomap_basic7,
            test_cache_nomap_allowed_share1,
            test_cache_nomap_allowed_share2,
            test_cache_nomap_allowed_share3,
            test_cache_nomap_disallowed_share1,
    ):
        test_case()
|
|
@ -16,17 +16,16 @@ import mindspore.common.dtype as mstype
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
|
|
||||||
|
|
||||||
# just a basic test with parallel random data op
|
# just a basic test with parallel random data op
|
||||||
def test_randomdataset_basic1():
|
def test_randomdataset_basic1():
|
||||||
logger.info("Test randomdataset basic")
|
logger.info("Test randomdataset basic 1")
|
||||||
|
|
||||||
schema = ds.Schema()
|
schema = ds.Schema()
|
||||||
schema.add_column('image', de_type=mstype.uint8, shape=[2])
|
schema.add_column('image', de_type=mstype.uint8, shape=[2])
|
||||||
schema.add_column('label', de_type=mstype.uint8, shape=[1])
|
schema.add_column('label', de_type=mstype.uint8, shape=[1])
|
||||||
|
|
||||||
# apply dataset operations
|
# apply dataset operations
|
||||||
ds1 = ds.RandomDataset(schema=schema, num_samples=50, num_parallel_workers=4)
|
ds1 = ds.RandomDataset(schema=schema, total_rows=50, num_parallel_workers=4)
|
||||||
ds1 = ds1.repeat(4)
|
ds1 = ds1.repeat(4)
|
||||||
|
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
|
@ -36,8 +35,9 @@ def test_randomdataset_basic1():
|
||||||
logger.info("{} label: {}".format(num_iter, data["label"]))
|
logger.info("{} label: {}".format(num_iter, data["label"]))
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
|
|
||||||
logger.info("Number of data in ds1: ", num_iter)
|
logger.info("Number of data in ds1: {}".format(num_iter))
|
||||||
assert num_iter == 200
|
assert num_iter == 200
|
||||||
|
logger.info("Test randomdataset basic 1 complete")
|
||||||
|
|
||||||
|
|
||||||
# Another simple test
|
# Another simple test
|
||||||
|
@ -49,10 +49,8 @@ def test_randomdataset_basic2():
|
||||||
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
|
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
|
||||||
schema.add_column('label', de_type=mstype.uint8, shape=[1])
|
schema.add_column('label', de_type=mstype.uint8, shape=[1])
|
||||||
|
|
||||||
# Make up about 10 samples
|
# Make up 10 rows
|
||||||
ds1 = ds.RandomDataset(schema=schema, num_samples=10, num_parallel_workers=1)
|
ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=1)
|
||||||
|
|
||||||
# cache size allows for about 4 images since each image just a bit less than 1MB, after that we will have to spill
|
|
||||||
ds1 = ds1.repeat(4)
|
ds1 = ds1.repeat(4)
|
||||||
|
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
|
@ -62,11 +60,31 @@ def test_randomdataset_basic2():
|
||||||
logger.info("printing the label: {}".format(data["label"]))
|
logger.info("printing the label: {}".format(data["label"]))
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
|
|
||||||
logger.info("Number of data in ds1: ", num_iter)
|
logger.info("Number of data in ds1: {}".format(num_iter))
|
||||||
assert num_iter == 40
|
assert num_iter == 40
|
||||||
|
logger.info("Test randomdataset basic 2 complete")
|
||||||
|
|
||||||
|
|
||||||
|
# Another simple test
|
||||||
|
def test_randomdataset_basic3():
    """
    RandomDataset created without a schema: 10 rows, repeated twice, read via a tuple iterator.
    """
    logger.info("Test randomdataset basic 3")

    # No schema is passed, so the schema itself is randomly created; columns are named
    # "c0", "c1", "c2" etc.  A tuple iterator is used so that the column names are not
    # needed to iterate.
    pipeline = ds.RandomDataset(total_rows=10, num_parallel_workers=1).repeat(2)

    # Count every row produced by the pipeline.
    row_count = sum(1 for _ in pipeline.create_tuple_iterator())

    logger.info("Number of data in ds1: {}".format(row_count))
    # 10 rows repeated 2 times.
    assert row_count == 20
    logger.info("Test randomdataset basic 3 Complete")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_randomdataset_basic1()
|
test_randomdataset_basic1()
|
||||||
test_randomdataset_basic2()
|
test_randomdataset_basic2()
|
||||||
logger.info('test_randomdataset_basic Ended.\n')
|
test_randomdataset_basic3()
|
||||||
|
|
Loading…
Reference in New Issue