!2891 CacheOp phase 1

Merge pull request !2891 from Jamie/CacheOp_dev
This commit is contained in:
mindspore-ci-bot 2020-07-13 23:56:57 +08:00 committed by Gitee
commit eadcb341e1
82 changed files with 5868 additions and 374 deletions

View File

@ -47,6 +47,8 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/dataset/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR})
################## Include sub-modules ###############################
add_subdirectory(util)
add_subdirectory(core)
@ -55,7 +57,7 @@ add_subdirectory(engine)
add_subdirectory(api)
add_subdirectory(text)
######################################################################
add_dependencies(core utils)
add_dependencies(utils core)
add_dependencies(kernels-image core)
add_dependencies(kernels-data core)
add_dependencies(kernels core)
@ -89,6 +91,8 @@ set(submodules
$<TARGET_OBJECTS:engine-perf>
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine-cache-client>
$<TARGET_OBJECTS:engine-cache-server>
$<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels>
@ -106,6 +110,8 @@ else ()
add_library(_c_dataengine SHARED ${submodules})
endif ()
add_dependencies(_c_dataengine generated_engine_files)
set_target_properties(_c_dataengine PROPERTIES
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"

View File

@ -21,8 +21,10 @@
#include "common/utils.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
@ -34,6 +36,7 @@
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
@ -441,6 +444,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
MapOp::Builder map_builder;
std::vector<std::shared_ptr<TensorOp>> tensor_op_list;
std::vector<std::string> project_columns;
std::shared_ptr<CacheClient> cache_client = nullptr;
int num_workers = 0;
if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n");
@ -456,7 +461,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
} else if (key == "columns_order") {
project_columns = ToStringVector(value);
} else if (key == "num_parallel_workers") {
(void)map_builder.SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)map_builder.SetNumWorkers(num_workers);
} else if (key == "prefetch_size") {
(void)map_builder.SetOpConnectorSize(ToInt(value));
} else if (key == "operations") {
@ -477,6 +483,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
}
if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set.");
(void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else {
RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
}
@ -499,6 +507,15 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
*bottom = map_op;
}
// Additionally, add a cache if required. This will go over top of the project op if one
// was created, otherwise it goes over top of the map op
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op));
*top = cache_op;
*bottom = map_op;
}
return Status::OK();
}
@ -809,6 +826,9 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
std::shared_ptr<DatasetOp> *bottom) {
// Required arguments
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
if (!args["dataset_files"].is_none()) {
files_list = ToStringVector(args["dataset_files"]);
@ -828,7 +848,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
(void)builder->SetColumnsToLoad(columns_to_load);
@ -848,6 +869,11 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "shard_equal_rows") {
(void)builder->SetShardEqualRows(ToBool(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
}
}
}
@ -860,12 +886,27 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
}
(void)builder->SetDataSchema(std::move(schema));
}
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
// because TFReaderOp is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
if (sampler) {
(void)builder->SetSampler(std::move(sampler));
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
std::shared_ptr<TFReaderOp> tf_op;
RETURN_IF_NOT_OK(builder->Build(&tf_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
*top = tf_op;
if (shuffle_required) {
if (!cache_client && shuffle_required) {
const boolean estimate = true;
const int64_t workers = 8;
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
@ -882,6 +923,15 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
*bottom = tf_op;
}
// Add a cache op over this op if required and update the output subtree (top/bottom)
if (cache_client) {
// Note, it is not allowed to have both shuffle and cache
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op));
*top = cache_op;
*bottom = tf_op;
}
return Status::OK();
}
@ -906,6 +956,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
std::string err_msg = "Error: No dataset path specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<ImageFolderOp::Builder> builder = std::make_shared<ImageFolderOp::Builder>();
(void)builder->SetImageFolderDir(ToString(args["dataset_dir"]));
@ -915,7 +967,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@ -926,12 +979,27 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
(void)builder->SetClassIndex(ToStringMap(value));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
std::shared_ptr<ImageFolderOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<ImageFolderOp> if_op;
RETURN_IF_NOT_OK(builder->Build(&if_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(if_op));
*top = if_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a place holder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op));
*top = cache_op;
*bottom = if_op;
}
return Status::OK();
}
@ -1130,9 +1198,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
std::shared_ptr<DatasetOp> *bottom) {
// Required arguments
RandomDataOp::Builder builder;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
int num_workers = 0;
if (args["num_samples"].is_none()) {
std::string err_msg = "Error: num_samples is a required argument";
if (args["total_rows"].is_none()) {
std::string err_msg = "Error: total_rows is a required argument";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::vector<std::string> columns_to_load;
@ -1141,16 +1212,23 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (key == "num_parallel_workers") {
(void)builder.SetNumWorkers(ToInt(value));
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
} else if (key == "num_samples") {
// This is not sampling here. The random data op needs to know how much data to
// generate. It does not currently support sampling.
(void)builder.SetTotalRows(ToInt(value));
if (!value.is_none()) {
if (key == "num_parallel_workers") {
num_workers = ToInt(value);
(void)builder.SetNumWorkers(num_workers);
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
} else if (key == "total_rows") {
// This is not sampling here. The random data op needs to know how much data to generate.
(void)builder.SetTotalRows(ToInt(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
}
}
}
if (schema_exists) {
@ -1162,9 +1240,34 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
}
(void)builder.SetDataSchema(std::move(schema));
}
std::shared_ptr<RandomDataOp> op;
RETURN_IF_NOT_OK(builder.Build(&op));
*top = op;
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
// because RandomDataOp is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
if (sampler) {
(void)builder.SetSampler(std::move(sampler));
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
(void)builder.SetSampler(std::move(sampler));
}
std::shared_ptr<RandomDataOp> random_op = nullptr;
RETURN_IF_NOT_OK(builder.Build(&random_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(random_op));
*top = random_op;
// Add a cache op over this op if required and update the output subtree (top/bottom)
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op));
*top = cache_op;
*bottom = random_op;
}
return Status::OK();
}
@ -1425,6 +1528,31 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}
// Helper function to inject the cache operator over top of the current operation being built.
Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers,
std::shared_ptr<DatasetOp> input_op, std::shared_ptr<DatasetOp> *cache_op) {
std::shared_ptr<CacheOp> new_cache_op = nullptr;
CacheOp::Builder cache_builder;
// use the same number of workers as the leaf. We need some optimization here, the user does not
// give the cache op number of workers directly.
if (num_workers != 0) {
(void)cache_builder.SetNumWorkers(num_workers);
}
(void)cache_builder.SetClient(cache_client);
RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op));
RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op));
// We have now created:
//
// CacheOp
// |
// input_op
//
*cache_op = new_cache_op;
return Status::OK();
}
// Helper function to inject a shuffle operator over top of the current operation being built.
Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *shuffle_op) {

View File

@ -35,6 +35,8 @@ namespace mindspore {
namespace dataset {
using DsOpPtr = std::shared_ptr<DatasetOp>;
class CacheClient;
// enum for the dataset operator names
enum OpName {
kShuffle,
@ -181,6 +183,16 @@ class DEPipeline {
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
/// \brief Helper function to inject a cache operator over top of the current operation being built.
/// \param[in] cache_client The client to use for caching
/// \param[in] num_workers The number of workers to use in the cache op
/// \param[in] input_op The operator to build the cache on top of
/// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
/// the cache operator
/// \return Status return code
Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *cache_op);
/// \brief Helper function to inject a shuffle operator over top of the current operation being built.
/// \param[in] shuffle_size The size to use in the shuffle buffer
/// \param[in] input_op The operator to build shuffle on top of

View File

@ -35,6 +35,7 @@
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/concatenate_op.h"
@ -768,6 +769,11 @@ void bindInfoObjects(py::module *m) {
.def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
}
void bindCacheClient(py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(py::init<uint32_t, uint64_t, bool>());
}
void bindVocabObjects(py::module *m) {
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
.def(py::init<>())
@ -939,6 +945,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
bindSamplerOps(&m);
bindDatasetOps(&m);
bindInfoObjects(&m);
bindCacheClient(&m);
bindVocabObjects(&m);
bindGraphData(&m);
bindDependIcuTokenizerOps(&m);

View File

@ -2,6 +2,7 @@ add_subdirectory(datasetops)
add_subdirectory(opt)
add_subdirectory(gnn)
add_subdirectory(perf)
add_subdirectory(cache)
if (ENABLE_TDTQUE)
add_subdirectory(tdt)
endif ()
@ -17,7 +18,9 @@ add_library(engine OBJECT
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf)
else()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server)
else ()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server)
endif ()

View File

@ -0,0 +1,8 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-cache-client OBJECT
cache_client.cc
cache_request.cc)
add_library(engine-cache-server OBJECT
cache_service.cc
cache_server.cc)

View File

@ -0,0 +1,208 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iomanip>
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/bit.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
: server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
// print method for display cache details
void CacheClient::Print(std::ostream &out) const {
out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_
<< "\n Spilling: " << std::boolalpha << spill_;
}
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
CacheRowRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
if (row_id_from_server != nullptr) {
*row_id_from_server = rq.GetRowIdAfterCache();
}
return Status::OK();
}
Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
std::unique_ptr<DataBuffer> db_ptr = std::move(in);
auto num_rows = db_ptr->NumRows();
std::vector<TensorRow> all_rows;
if (num_rows > 0) {
all_rows.reserve(num_rows);
// Break down the DataBuffer into TensorRow. We will send the requests async
// and then do a final wait.
MemGuard<CacheRowRequest> rq_arr;
RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
CacheServer &cs = CacheServer::GetInstance();
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(cs.PushRequest(rq));
// We can't let row go out of scope. Otherwise it will free all the tensor memory.
// So park it in the vector. When this function go out of scope, its memory
// will be freed.
all_rows.push_back(std::move(row));
}
// Now we wait for the requests to be done.
for (auto i = 0; i < num_rows; ++i) {
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(rq->Wait());
}
}
return Status::OK();
}
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
BatchFetchRequest rq(server_connection_id_, row_id);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
RETURN_IF_NOT_OK(rq.RestoreRows(out));
return Status::OK();
}
Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
UniqueLock lck(&mux_);
// To create a cache, we identify ourself at the client by:
// - the shared session id
// - a crc for the tree nodes from the cache downward
// Pack these 2 into a single 64 bit request id
//
// Consider this example:
// tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
// tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch
// These are different trees in a single session, but the user wants to share the cache.
// This is not allowed because the data of these caches are different.
//
// Consider this example:
// tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
// tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch
// These are different trees in the same session, but the cached data is the same, so it is okay
// to allow the sharing of this cache between these pipelines.
// The CRC is computed by the tree prepare phase and passed to this function when creating the cache.
// If we already have a server_connection_id_, then it means this same cache client has already been used
// to create a cache and some other tree is trying to use the same cache.
// That is allowed, however the crc better match!
if (server_connection_id_) {
if (cache_crc_ != tree_crc) {
RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
}
// Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
// skip the build phase.
lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock.
CacheClient::ServiceStat stat{};
RETURN_IF_NOT_OK(GetStat(&stat));
if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
}
} else {
cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client
// Combine the session and crc. This will form our client cache identifier.
connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
// Now execute the cache create request using this identifier and other configs
BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
if (spill_) {
createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
}
if (generate_id) {
createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
}
CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
Status rc = rq.Wait();
if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
server_connection_id_ = rq.GetServerConnectionId();
if (rc.IsOk()) {
// The 1st guy creating the cache will get a cookie back.
// But this object may be shared among pipelines and we don't want
// overwrite it.
cookie_ = rq.cookie();
}
}
// We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
// CacheOp to bypass the build phase.
return rc;
}
return Status::OK();
}
Status CacheClient::PurgeCache() {
UniqueLock lck(&mux_);
PurgeCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
}
Status CacheClient::DestroyCache() {
UniqueLock lck(&mux_);
DestroyCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
}
Status CacheClient::GetStat(ServiceStat *stat) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(stat);
GetStatRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
stat->num_disk_cached = rq.GetNumDiskCached();
stat->num_mem_cached = rq.GetNumMemCached();
stat->min_row_id = rq.GetMinRowId();
stat->max_row_id = rq.GetMaxRowId();
stat->cache_service_state = rq.GetState();
return Status::OK();
}
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
SharedLock lck(&mux_);
CacheSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
return Status::OK();
}
Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(map);
FetchSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
*map = rq.GetColumnMap();
return Status::OK();
}
Status CacheClient::BuildPhaseDone() const {
SharedLock lck(&mux_);
BuildPhaseDoneRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,141 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_CLIENT_H_
#define DATASET_ENGINE_CACHE_CLIENT_H_
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/cache/cache_server.h"
#include "dataset/util/lock.h"
namespace mindspore {
namespace dataset {
/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
/// a CacheClient. Typical tasks including like creating a cache service, cache a data buffer, restore a previously
/// rows, etc.
class CacheClient {
public:
/// \brief Constructor
/// \param session_id A user assigned session id for the current pipeline
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
/// \brief Destructor
~CacheClient() = default;
/// \brief Getter function for returning the current session id
/// \return session id
uint64_t session_id() const { return session_id_; }
/// \brief Send a TensorRow to the cache server
/// \param[in] row
/// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
/// \return return code
Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;
/// \brief Send a DataBuffer to the cache server
/// \param in Unique pointer of the DataBuffer to be cached
/// \return return code
Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const;
/// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
/// any cache miss
/// \param row_id A vector of row id's
/// \param out A TensorTable of TensorRows.
/// \return return code
Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;
/// \brief Create a cache.
/// \param tree_crc A crc that was generated during tree prepare phase
/// \param generate_id Let the cache service generate row id
/// \return Status object
Status CreateCache(uint32_t tree_crc, bool generate_id);
/// \brief Purge a cache. Cache can be reused after reset.
/// \return Status object
Status PurgeCache();
/// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
/// \return Status object
Status DestroyCache();
/// \brief Get the statistics from a cache.
/// \param[in/out] Pointer to a pre-allocated ServiceStat object
/// \return Status object
struct ServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
};
Status GetStat(ServiceStat *);
/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
/// \return Status object
Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);
/// \brief Fetch the schema from the cache server
/// \param map Pointer to pre-allocated map object
/// \return Status object.
Status FetchSchema(std::unordered_map<std::string, int32_t> *map);
/// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
/// client that holds cookie can be allowed to make this request
/// \return Status object
Status BuildPhaseDone() const;
/// \brief A print method typically used for debugging
/// \param out The output stream to write output to
void Print(std::ostream &out) const;
/// \brief Stream output operator overload
/// \return the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
cc.Print(out);
return out;
}
/// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it.
/// \return Cookie
std::string cookie() const { return cookie_; }
private:
mutable RWLock mux_;
uint64_t cache_mem_sz_;
bool spill_;
// The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
// sharing of the cache.
uint32_t session_id_;
uint32_t cache_crc_;
// The server_connection_id_ is the actual id we use for operations after the cache is built
connection_id_type server_connection_id_;
// Some magic cookie returned from the cache server.
std::string cookie_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_CLIENT_H_

View File

@ -0,0 +1,223 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_request.h"
namespace mindspore {
namespace dataset {
Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
buffers_.reserve(row.size() + 1);
RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
buffers_.push_back(fbb_->GetBufferPointer());
for (const auto &ts : row) {
buffers_.push_back(ts->GetBuffer());
}
return Status::OK();
}
Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
std::vector<int64_t> tensor_sz;
v.reserve(row.size());
tensor_sz.reserve(row.size());
// We will go through each column in the row.
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
flatbuffers::Offset<TensorMetaMsg> ts_off;
RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off));
v.push_back(ts_off);
tensor_sz.push_back(ts_ptr->SizeInBytes());
}
auto column_off = fbb_->CreateVector(v);
auto data_sz_off = fbb_->CreateVector(tensor_sz);
TensorRowHeaderMsgBuilder row_builder(*fbb_);
row_builder.add_column(column_off);
row_builder.add_data_sz(data_sz_off);
// Pass the row_id even if it may not be known.
row_builder.add_row_id(row.getId());
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
auto out = row_builder.Finish();
fbb_->Finish(out);
// Now go back to fill in size_of_this in the flat buffer.
auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
auto success = msg->mutate_size_of_this(fbb_->GetSize());
if (!success) {
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
}
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
flatbuffers::Offset<TensorMetaMsg> *out_off) {
RETURN_UNEXPECTED_IF_NULL(out_off);
const Tensor *ts = ts_ptr.get();
auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
const auto ptr = ts->GetBuffer();
if (ptr == nullptr) {
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
}
auto src = ts->type().value();
TensorType dest;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch (src) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
default:
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
RETURN_STATUS_UNEXPECTED("Unknown type");
}
#undef CASE
TensorMetaMsgBuilder ts_builder(*fbb_);
ts_builder.add_dims(shape_off);
ts_builder.add_type(dest);
auto ts_off = ts_builder.Finish();
*out_off = ts_off;
return Status::OK();
}
Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(col_ts);
auto shape_in = col_ts->dims();
auto type_in = col_ts->type();
std::vector<dsize_t> v;
v.reserve(shape_in->size());
v.assign(shape_in->begin(), shape_in->end());
TensorShape shape(v);
DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break
switch (type_in) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
}
#undef CASE
DataType type(dest);
std::shared_ptr<Tensor> ts =
std::make_shared<Tensor>(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize());
// Next we restore the real data which can be embedded or stored separately.
if (ts->SizeInBytes() != data.GetSize()) {
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
<< "Dumping tensor\n"
<< *ts << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
*out = std::move(ts);
return Status::OK();
}
Status BatchFetchRequest::RestoreRows(TensorTable *out) {
RETURN_UNEXPECTED_IF_NULL(out);
auto num_elements = row_id_.size();
auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer());
TensorTable tbl;
tbl.reserve(num_elements);
ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes());
for (auto i = 0; i < num_elements; ++i) {
auto len = offset_array[i + 1] - offset_array[i];
TensorRow row;
row.setId(row_id_.at(i));
if (len > 0) {
ReadableSlice row_data(all, offset_array[i], len);
// Next we de-serialize flat buffer to get back each column
auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
auto msg_sz = msg->size_of_this();
// Start of the tensor data
auto ts_offset = msg_sz;
row.reserve(msg->column()->size());
for (auto k = 0; k < msg->column()->size(); ++k) {
auto col_ts = msg->column()->Get(k);
std::shared_ptr<Tensor> ts;
ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts));
row.push_back(ts);
ts_offset += data.GetSize();
}
}
tbl.push_back(std::move(row));
}
*out = std::move(tbl);
return Status::OK();
}
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
v.reserve(map.size());
for (auto &column : map) {
auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second);
v.push_back(c);
}
auto v_off = fbb_->CreateVector(v);
auto final_off = CreateSchemaMsg(*fbb_, v_off);
fbb_->Finish(final_off);
buf_ = fbb_->GetBufferPointer();
len_of_buf_ = fbb_->GetSize();
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
if (column_name_id_map_.empty()) {
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
auto v = map_msg->column();
for (auto i = 0; i < v->size(); ++i) {
auto col = map_msg->column()->Get(i);
column_name_id_map_.emplace(col->name()->str(), col->id());
}
}
return column_name_id_map_;
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,225 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_REQ_H_
#define DATASET_ENGINE_CACHE_REQ_H_
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/core/tensor_row.h"
#include "dataset/util/slice.h"
#include "dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
/// \brief CacheClient communicates with CacheServer using Requests.
class BaseRequest {
public:
// Request types
enum class RequestType : int16_t {
kCacheRow = 0,
kBatchFetchRows = 1,
kCreateCache = 2,
kPurgeCache = 3,
kDestroyCache = 4,
kGetStat = 5,
kCacheSchema = 6,
kFetchSchema = 7,
kBuildPhaseDone = 8,
// Add new request before it.
kRequestUnknown = 32767
};
// For kCreateCache
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };
friend class CacheServer;
/// \brief Base class of a cache server request
/// \param connection_id A combination of session id and crc that uniquely identifies a connection.
/// \param type Type of the request
explicit BaseRequest(connection_id_type connection_id, RequestType type)
: type_(type), connection_id_(connection_id) {}
virtual ~BaseRequest() = default;
/// \brief Wait for the completion of a request
/// \return Status returned from the cache server
Status Wait() {
RETURN_IF_NOT_OK(wp_.Wait());
return rc_;
}
/// \brief Getter function of the current connection id
/// \return Connection id
connection_id_type GetServerConnectionId() const { return connection_id_; }
private:
RequestType type_;
connection_id_type connection_id_;
Status rc_;
WaitPost wp_;
};
/// \brief Request to cache a single TensorRow
class CacheRowRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {}
~CacheRowRequest() = default;
/// \brief Serialize a TensorRow for streaming to the cache server
/// \param row TensorRow
/// \return Status object
Status SerializeCacheRowRequest(const TensorRow &row);
/// \brief Return the row id assigned to this row for non-mappable dataset
/// \return row id of the cached row
row_id_type GetRowIdAfterCache() { return row_id_from_server_; }
private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
row_id_type row_id_from_server_;
std::vector<const void *> buffers_;
std::string cookie_;
/// \brief Private function to serialize one TensorRow
/// \param row TensorRow
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row);
/// \brief Private function to serialize one Tensor
/// \param ts_ptr Tensor
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
};
/// \brief Request to fetch rows in batch
class BatchFetchRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id)
: BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {}
Status RestoreRows(TensorTable *out);
private:
std::vector<row_id_type> row_id_;
MemGuard<uint8_t> mem_;
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
};
/// \brief Request to create a cache for the current connection
class CreationCacheRequest : public BaseRequest {
public:
friend class CacheServer;
/// \brief Constructor
/// \param connection_id
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
/// \param flag Attributes of the cache.
explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone)
: BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {}
std::string cookie() const { return cookie_; }
private:
uint64_t cache_mem_sz;
CreateCacheFlag flag_;
std::string cookie_;
};
/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {}
};
/// \brief Request to destroy a cache
class DestroyCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit DestroyCacheRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kDestroyCache) {}
};
/// \brief Obtain the statistics of the current connection
class GetStatRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {}
row_id_type GetMinRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->min_row_id();
}
row_id_type GetMaxRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->max_row_id();
}
int64_t GetNumMemCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_mem_cached();
}
int64_t GetNumDiskCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_disk_cached();
}
uint8_t GetState() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->state();
}
private:
MemGuard<uint8_t> mem_;
};
/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {}
~CacheSchemaRequest() = default;
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
const void *GetBuffer() const { return buf_; }
private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
const void *buf_;
int64_t len_of_buf_;
};
/// \brief Request to fetch a schema
class FetchSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FetchSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kFetchSchema) {}
~FetchSchemaRequest() = default;
std::unordered_map<std::string, int32_t> GetColumnMap();
private:
MemGuard<uint8_t> mem_;
std::unordered_map<std::string, int32_t> column_name_id_map_;
};
/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
class BuildPhaseDoneRequest : public BaseRequest {
public:
friend class CacheServer;
BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}
private:
std::string cookie_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_SERVICE_H_

View File

@ -0,0 +1,252 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_server.h"
#include "dataset/engine/cache/cache_service.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/bit.h"
namespace mindspore {
namespace dataset {
Status CacheServer::DoServiceStart() {
if (!top_.empty()) {
Path spill(top_);
RETURN_IF_NOT_OK(spill.CreateDirectories());
MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
}
RETURN_IF_NOT_OK(vg_.ServiceStart());
cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024);
RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
auto f = std::bind(&CacheServer::ServerRequest, this);
// Spawn a a few threads to serve the request.
for (auto i = 0; i < num_workers_; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f));
}
return Status::OK();
}
Status CacheServer::DoServiceStop() {
Status rc;
Status rc2;
// First stop all the threads.
RETURN_IF_NOT_OK(vg_.ServiceStop());
// Clean up all the caches if any.
UniqueLock lck(&rwLock_);
auto it = all_caches_.begin();
while (it != all_caches_.end()) {
auto cs = std::move(it->second);
rc2 = cs->ServiceStop();
if (rc2.IsError()) {
rc = rc2;
}
++it;
}
return rc;
}
CacheService *CacheServer::GetService(connection_id_type id) const {
SharedLock lck(&rwLock_);
auto it = all_caches_.find(id);
if (it != all_caches_.end()) {
return it->second.get();
}
return nullptr;
}
Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz,
BaseRequest::CreateCacheFlag flag, std::string *out_cookie) {
// We can't do spilling unless this server is setup with a spill path in the first place
bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk;
bool generate_id =
(flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId;
if (spill && top_.empty()) {
RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
}
RETURN_UNEXPECTED_IF_NULL(out_cookie);
*out_cookie = "";
// Before creating the cache, first check if this is a request for a shared usage of an existing cache
// If two CreateService come in with identical connection_id, we need to serialize the create.
// The first create will be successful and be given a special cookie.
UniqueLock lck(&rwLock_);
auto end = all_caches_.end();
auto it = all_caches_.find(connection_id);
if (it == end) {
std::unique_ptr<CacheService> cs;
try {
cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
RETURN_IF_NOT_OK(cs->ServiceStart());
*out_cookie = cs->cookie();
all_caches_.emplace(connection_id, std::move(cs));
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
} else {
MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
// We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
// treat it as OK.
return Status(StatusCode::kDuplicateKey);
}
return Status::OK();
}
/// This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and save the result in the same request.
/// The sender will wait on the wait post in the request. Once the request
/// is fulfilled, the server thread will do a post signalling the request is
/// is processed.
/// \return
Status CacheServer::ServerRequest() {
TaskManager::FindMe()->Post();
// Loop forever until we are interrupted.
while (true) {
BaseRequest *base_rq = nullptr;
RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq));
auto cs = GetService(base_rq->connection_id_);
// Except for creating a new session, we expect cs is not null.
switch (base_rq->type_) {
case BaseRequest::RequestType::kCacheRow: {
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq);
// Only if the cookie matches, we can accept insert into this cache that has a build phase
if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) {
rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_);
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
break;
}
case BaseRequest::RequestType::kBatchFetchRows: {
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq);
rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_);
}
break;
}
case BaseRequest::RequestType::kCreateCache: {
// If the cache is already created we still need to run the creation so that we do sanity checks on the
// client id and return the cache id back to the user.
auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq);
rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_);
break;
}
case BaseRequest::RequestType::kPurgeCache: {
if (cs != nullptr) {
base_rq->rc_ = cs->Purge();
} else {
// it is already purged. Ignore it.
base_rq->rc_ = Status::OK();
}
break;
}
case BaseRequest::RequestType::kDestroyCache: {
if (cs != nullptr) {
// We need a strong lock to protect the map.
connection_id_type id = base_rq->connection_id_;
UniqueLock lck(&rwLock_);
// std::map will invoke the constructor of CacheService. So we don't need to do anything here.
auto n = all_caches_.erase(id);
if (n == 0) {
// It has been destroyed by another duplicate request.
MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
}
base_rq->rc_ = Status::OK();
} else {
// it is already destroyed. Ignore it.
base_rq->rc_ = Status::OK();
}
break;
}
case BaseRequest::RequestType::kGetStat: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<GetStatRequest *>(base_rq);
CacheService::ServiceStat svc_stat;
rq->rc_ = cs->GetStat(&svc_stat);
if (rq->rc_.IsOk()) {
flatbuffers::FlatBufferBuilder fbb;
ServiceStatMsgBuilder bld(fbb);
bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
bld.add_max_row_id(svc_stat.max_);
bld.add_min_row_id(svc_stat.min_);
bld.add_state(svc_stat.state_);
auto offset = bld.Finish();
fbb.Finish(offset);
rq->rc_ = rq->mem_.allocate(fbb.GetSize());
if (rq->rc_.IsOk()) {
WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize());
ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize());
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
}
}
}
break;
}
case BaseRequest::RequestType::kCacheSchema: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq);
rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_);
}
break;
}
case BaseRequest::RequestType::kFetchSchema: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq);
rq->rc_ = cs->FetchSchema(&rq->mem_);
}
break;
}
case BaseRequest::RequestType::kBuildPhaseDone: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq);
// We can only allow to switch phase is the cookie match.
if (rq->cookie_ == cs->cookie()) {
rq->rc_ = cs->BuildPhaseDone();
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
break;
}
default:
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type");
}
// Notify it is done, and move on to the next request.
base_rq->wp_.Set();
}
return Status::OK();
}
CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers)
: top_(spill_path), num_workers_(num_workers) {}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_SERVER_H_
#define DATASET_ENGINE_CACHE_SERVER_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include "dataset/engine/cache/cache_service.h"
#include "dataset/core/tensor.h"
#include "dataset/util/arena.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/lock.h"
#include "dataset/util/service.h"
#include "dataset/util/services.h"
#include "dataset/util/system_pool.h"
#include "dataset/util/queue.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
class BaseRequest;
/// \brief A server which provides CacheService services.
class CacheServer : public Service {
public:
friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
CacheServer(const CacheServer &) = delete;
CacheServer &operator=(const CacheServer &) = delete;
CacheServer(CacheServer &&) = delete;
CacheServer &operator=(CacheServer &) = delete;
static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
Status DoServiceStart() override;
Status DoServiceStop() override;
~CacheServer() { (void)ServiceStop(); }
/// \brief For the current demonstration, a cache client contacts cache server using a Queue.
/// \param rq
/// \return Status object
Status PushRequest(BaseRequest *rq) {
RETURN_UNEXPECTED_IF_NULL(rq);
RETURN_IF_NOT_OK(cache_q_->Add(rq));
return Status::OK();
}
private:
mutable RWLock rwLock_;
std::string top_;
cache_index all_caches_;
std::shared_ptr<Queue<BaseRequest *>> cache_q_;
TaskGroup vg_;
int32_t num_workers_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
/// \param num_workers Number of threads for handling requests.
explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3);
/// \brief Locate a cache service from connection id.
/// \return Pointer to cache service. Null if not found
CacheService *GetService(connection_id_type id) const;
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
/// a special unique cookie.
/// \param[in] connection_id This is from a Cache client.
/// \param[in] cache_mem_sz
/// \param[in] flag
/// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator
/// \return Status object
Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag,
std::string *out_cookie);
/// \brief Entry point for all server threads.
Status ServerRequest();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_CORE_CACHE_TENSOR_H_

View File

@ -0,0 +1,265 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_service.h"
#include "dataset/util/slice.h"
namespace mindspore {
namespace dataset {
CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
: root_(root),
cache_mem_sz_(mem_sz),
cp_(nullptr),
map_(nullptr),
next_id_(0),
generate_id_(generate_id),
schema_key_(-1),
st_(generate_id ? State::kBuildPhase : State::kNone) {}
CacheService::~CacheService() { (void)ServiceStop(); }
bool CacheService::UseArena() {
// If fixed size, use Arena instead of the pool from global context.
return (cache_mem_sz_ > 0);
}
Status CacheService::DoServiceStart() {
std::shared_ptr<MemoryPool> mp_;
if (UseArena()) {
// Create a fixed size arena based on the parameter.
std::shared_ptr<Arena> arena;
RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_));
mp_ = std::move(arena);
} else {
// Unlimited size. Simply use a system pool. Another choice is CircularPool.
mp_ = std::make_shared<SystemPool>();
}
// Put together a CachePool for backing up the Tensor
cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), root_);
RETURN_IF_NOT_OK(cp_->ServiceStart());
// Set up the B+ tree as well. But use the system pool instead.
map_ = std::make_shared<row_map>();
// Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name.
cookie_ = cp_->MyName();
return Status::OK();
}
Status CacheService::DoServiceStop() {
if (cp_ != nullptr) {
RETURN_IF_NOT_OK(cp_->ServiceStop());
}
return Status::OK();
}
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == State::kFetchPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
try {
// The first buffer is a flatbuffer which describes the rest of the buffers follow
auto fb = buf.front();
RETURN_UNEXPECTED_IF_NULL(fb);
auto msg = GetTensorRowHeaderMsg(fb);
// If the server side is designed to ignore incoming row id, we generate row id.
if (generate_id_) {
*row_id_generated = GetNextRowId();
// Some debug information on how many rows we have generated so far.
if ((*row_id_generated) % 1000 == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated;
}
} else {
if (msg->row_id() < 0) {
std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
RETURN_STATUS_UNEXPECTED(errMsg);
}
*row_id_generated = msg->row_id();
}
auto size_of_this = msg->size_of_this();
auto column_hdr = msg->column();
// Number of tensor buffer should match the number of columns plus one.
if (buf.size() != column_hdr->size() + 1) {
std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) +
" but get " + std::to_string(buf.size());
RETURN_STATUS_UNEXPECTED(errMsg);
}
// Next we store in either memory or on disk. Low level code will consolidate everything in one piece.
std::vector<ReadableSlice> all_data;
all_data.reserve(column_hdr->size() + 1);
all_data.emplace_back(fb, size_of_this);
for (auto i = 0; i < column_hdr->size(); ++i) {
all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
}
// Now we cache the flat buffer.
CachePool::key_type key;
RETURN_IF_NOT_OK(cp_->Insert(all_data, &key));
Status rc = map_->DoInsert(*row_id_generated, key);
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key";
} else {
RETURN_IF_NOT_OK(rc);
}
return Status::OK();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
// Then show any custom derived-internal stuff
out << "\nCache memory size: " << cs.cache_mem_sz_;
out << "\nSpill path: ";
if (cs.root_.empty()) {
out << "None";
} else {
out << cs.GetSpillPath();
}
return out;
}
Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); }
Status CacheService::Purge() {
// First we must lock exclusively. No one else can cache/restore anything.
UniqueLock rw(&rw_lock_);
RETURN_IF_NOT_OK(cp_->ServiceStop());
auto new_map = std::make_shared<row_map>();
map_.reset();
map_ = std::move(new_map);
next_id_ = 0;
RETURN_IF_NOT_OK(cp_->ServiceStart());
return Status::OK();
}
Status CacheService::GetStat(CacheService::ServiceStat *out) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(out);
if (st_ == State::kNone || st_ == State::kFetchPhase) {
out->stat_ = cp_->GetStat();
out->state_ = static_cast<ServiceStat::state_type>(st_);
auto it = map_->begin();
if (it != map_->end()) {
out->min_ = it.key();
auto end_it = map_->end();
--end_it;
out->max_ = end_it.key();
}
} else {
out->state_ = static_cast<ServiceStat::state_type>(st_);
}
return Status::OK();
}
Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
const auto num_elements = v.size();
int64_t mem_sz = (num_elements + 1) * sizeof(int64_t);
int64_t data_offset = mem_sz;
std::vector<int64_t> sz_v;
std::vector<CachePool::key_type> keys;
sz_v.reserve(num_elements);
keys.reserve(num_elements);
for (auto row_id : v) {
auto r = map_->Search(row_id);
if (r.second) {
auto &it = r.first;
CachePool::key_type key = it.value();
auto sz = cp_->GetSize(key);
if (sz == 0) {
std::string errMsg = "Key not found: ";
errMsg += std::to_string(key);
RETURN_STATUS_UNEXPECTED(errMsg);
}
keys.push_back(key);
sz_v.push_back(sz);
mem_sz += sz;
} else {
keys.push_back(-1);
sz_v.push_back(0);
}
}
MemGuard<uint8_t> mem;
RETURN_IF_NOT_OK(mem.allocate(mem_sz));
auto *offset_array = reinterpret_cast<int64_t *>(mem.GetMutablePointer());
offset_array[0] = data_offset;
WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes());
for (auto i = 0; i < num_elements; ++i) {
auto sz = sz_v.at(i);
offset_array[i + 1] = offset_array[i] + sz;
if (sz > 0) {
WritableSlice row_data(all, offset_array[i], sz);
auto key = keys.at(i);
size_t bytesRead = 0;
RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
if (bytesRead != sz) {
MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "."
<< " Internal key: " << key << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
}
}
*out = std::move(mem);
return Status::OK();
}
Status CacheService::CacheSchema(const void *buf, int64_t len) {
SharedLock rw(&rw_lock_);
if (st_ == State::kFetchPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
// This is a special request and we need to remember where we store it.
// In case we are calling the same function from multiple threads, only
// the first one is considered. Rest is ignored.
CachePool::key_type cur_key = schema_key_;
CachePool::key_type key;
if (cur_key < 0) {
RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key));
auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key);
MS_LOG(DEBUG) << "Caching Schema. Result = " << result;
} else {
MS_LOG(DEBUG) << "Caching Schema already done";
}
return Status::OK();
}
Status CacheService::FetchSchema(MemGuard<uint8_t> *out) const {
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
RETURN_UNEXPECTED_IF_NULL(out);
MemGuard<uint8_t> mem;
if (schema_key_ >= 0) {
auto len = cp_->GetSize(schema_key_);
RETURN_IF_NOT_OK(mem.allocate(len));
auto slice = WritableSlice(mem.GetMutablePointer(), len);
RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
*out = std::move(mem);
} else {
return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached");
}
return Status::OK();
}
Status CacheService::BuildPhaseDone() {
if (HasBuildPhase()) {
// Exclusive lock to switch phase
UniqueLock rw(&rw_lock_);
st_ = State::kFetchPhase;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,143 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_SERVICE_H_
#define DATASET_ENGINE_CACHE_SERVICE_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/arena.h"
#include "dataset/util/btree.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/service.h"
#include "dataset/util/services.h"
#include "dataset/util/system_pool.h"
namespace mindspore {
namespace dataset {
struct CacheStat;
/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is
/// created to support spilling
class CacheService : public Service {
public:
friend class CacheServer;
using row_map = BPlusTree<row_id_type, CachePool::key_type>;
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };
/// \brief Constructor
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
/// \param root Spill path. Empty string means no spilling
/// \param generate_id If the cache service should generate row id for buffer that is cached.
/// For non-mappable dataset, this should be set to true.
CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
~CacheService();
/// \brief For fixed size memory, we will create an Arena.
/// \return false if unlimited memory.
bool UseArena();
Status DoServiceStart() override;
Status DoServiceStop() override;
/// \brief Main function to cache a row which is in form a series of buffers.
/// The first buffer is a Google flatbuffer which describes the rest of the buffers followed.
/// \param[in] buf Vector of buffer
/// \param[out] row_id_generated The row id assigned to this row if any
/// \return Status object
Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const;
/// \brief Getter function
/// \return Spilling path
Path GetSpillPath() const;
/// \brief A structure returned from the cache server for statistics request.
class ServiceStat {
public:
using state_type = std::underlying_type<State>::type;
ServiceStat() : min_(0), max_(0), state_(0) {}
CachePool::CacheStat stat_{};
row_id_type min_;
row_id_type max_;
state_type state_;
};
/// \brief Statistics for the current service
/// \param[in/out] A pointer to a pre-allocated ServiceStat structure
/// \return Status Object
Status GetStat(ServiceStat *);
/// \brief Cache schema
/// \param buf A Google Flatbuffer that contains the schema
/// \param len size of the buffer
/// \return Status object
Status CacheSchema(const void *buf, int64_t len);
/// \brief Fetch schema
/// \param out A contiguous memory that contains the serialized form of schema.
/// \return Status object
Status FetchSchema(MemGuard<uint8_t> *out) const;
/// \brief Purge the content of a cache
/// \return Status object
Status Purge();
/// \brief Overload the << operator to print a cache service
/// \param out std::ostream
/// \param cs A cache service
/// \return std::ostream
friend std::ostream &operator<<(std::ostream &out, const CacheService &cs);
/// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient
/// is the creator
/// \return Cookie
std::string cookie() const { return cookie_; }
/// \brief If this cache service generates row id for buffer cached, it is divided into two phases, a build phase and
/// a read phase.
/// \return True if has two phases.
bool HasBuildPhase() const { return generate_id_; }
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
/// \return Status object
Status BuildPhaseDone();
private:
mutable RWLock rw_lock_;
std::string root_;
uint64_t cache_mem_sz_;
std::shared_ptr<CachePool> cp_;
std::shared_ptr<row_map> map_;
std::atomic<row_id_type> next_id_;
bool generate_id_;
std::atomic<CachePool::key_type> schema_key_;
std::string cookie_;
State st_;
/// \brief Private function to generate a row id
/// \return Row id assigned.
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_SERVICE_H_

View File

@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace mindspore.dataset;
/// Type of a Tensor
enum TensorType : byte {
DE_UNKNOWN = 0,
DE_BOOL = 1,
DE_INT8 = 2,
DE_UINT8 = 3,
DE_INT16 = 4,
DE_UINT16 = 5,
DE_INT32 = 6,
DE_UINT32 = 7,
DE_INT64 = 8,
DE_UINT64 = 9,
DE_FLOAT16 = 10,
DE_FLOAT32 = 11,
DE_FLOAT64 = 12,
DE_STRING = 13
}
/// The meta information of a Tensor
/// \note Only the type and shape are considered meta information. Tensor data is excluded.
table TensorMetaMsg {
dims:[int64] (required);
type:TensorType;
}
/// This is the first buffer that is sent to a Cache server when a TensorRow is serialized.
/// \param row_id is the row id of the TensorRow.
/// \param column The meta information of each Tensor in the row
/// \param size of this serialized buffer
/// \param size of each tensor data buffer that follows
table TensorRowHeaderMsg {
row_id:int64;
column:[TensorMetaMsg] (required);
size_of_this:int64;
data_sz:[int64] (required);
}
root_type TensorRowHeaderMsg;
/// A row of row id's
table TensorRowIds {
row_id:[int64] (required);
}
/// Statistics returned from each cache service
/// \note It must match CacheService::ServiceStat
table ServiceStatMsg {
num_mem_cached:int64;
num_disk_cached:int64;
min_row_id:int64;
max_row_id:int64;
state:int8;
}
/// Column description of each column in a schema
table ColumnNameMsg {
name:string;
id:int32;
}
/// Serialized form of a schema
table SchemaMsg {
column:[ColumnNameMsg];
}

View File

@ -24,10 +24,8 @@ namespace dataset {
// Description: This is the main constructor that is used for making a buffer
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
void DataBuffer::Print(std::ostream &out, // In: The output stream to print to
bool show_all) const { // In: T/F if it should show everything
// A method for debug printing of the buffer
void DataBuffer::Print(std::ostream &out, bool show_all) const {
out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n";
// If the column counts are set then it means that data has been set into
@ -46,11 +44,6 @@ void DataBuffer::Print(std::ostream &out, // In: The output stream to print
}
}
Status DataBuffer::Load() {
std::string err_msg = "Base class load called, but it does not have an implementation!";
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Remove me!! Callers should fetch rows via pop
Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const {
if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) {
@ -92,8 +85,5 @@ Status DataBuffer::SliceOff(int64_t number_of_rows) {
return Status::OK();
}
// Destructor
DataBuffer::~DataBuffer() {}
} // namespace dataset
} // namespace mindspore

View File

@ -29,11 +29,9 @@
namespace mindspore {
namespace dataset {
// The DataBuffer class is a base class that will represent the data for n values based
// on a unique row id for each row of data.
// There can be different types of DataBuffers to abstract over how the data is stored
// in memory and acquired from storage.
// Each buffer holds a range of consecutive row id's.
/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between
/// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format
/// where n TensorRows may consist of m tensors (columns).
class DataBuffer {
public:
// Buffer flags
@ -47,13 +45,13 @@ class DataBuffer {
// Description: This is the main constructor that is used for making a buffer
DataBuffer(int32_t id, BufferFlags flags);
// Destructor
virtual ~DataBuffer();
/// \brief default destructor
~DataBuffer() = default;
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
virtual void Print(std::ostream &out, // In: The output stream to print to
bool show_all) const; // In: T/F if it should show everything
/// \brief A method for debug printing of the buffer
/// \param[inout] out The stream to write to
/// \param[in] show_all A boolean to toggle between details and summary printing
void Print(std::ostream &out, bool show_all) const;
// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
@ -61,10 +59,6 @@ class DataBuffer {
return out;
}
// Name: load()
// Description: populates the DataBuffer with data based on it's id
virtual Status Load();
// Convenience getter functions for flag checking
bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }

View File

@ -17,7 +17,11 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
take_op.cc
shuffle_op.cc
zip_op.cc
concat_op.cc
concat_op.cc
cache_base_op.cc
cache_lookup_op.cc
cache_op.cc
cache_merge_op.cc
)
if (ENABLE_PYTHON)

View File

@ -0,0 +1,185 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_base_op.h"
#include <iomanip>
#include <iostream>
#include "dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
// A print method typically used for debugging
void CacheBase::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCache client:\n" << *cache_client_ << "\n\n";
}
}
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from it's previous execution and then initializes itself so that it can be executed
// again.
Status CacheBase::Reset() {
if (sampler_ != nullptr) {
RETURN_IF_NOT_OK(sampler_->ResetSampler());
}
// Wake up the workers to get them going again in a new epoch
MS_LOG(DEBUG) << Name() << " resetting.";
epoch_sync_.Set();
return Status::OK();
}
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, op_connector_size, sampler),
cache_client_(cache_client),
rows_per_buffer_(rows_per_buf),
// We can cause deadlock if this internal Connector size is too small.
keys_miss_(num_workers_, 1, 1024) {
io_block_queues_.Init(num_workers, op_connector_size);
}
// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
int64_t buf_cnt = 0;
int64_t wait_cnt = 0;
do {
epoch_sync_.Clear();
std::vector<row_id_type> keys;
int64_t row_cnt = 0;
keys.reserve(rows_per_buffer_);
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
while (!sampler_buffer->eoe()) {
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++row_cnt;
if (row_cnt % rows_per_buffer_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
}
}
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
}
// send the eoe
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
// If repeat but the not last repeat, wait for reset.
if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
RETURN_IF_NOT_OK(epoch_sync_.Wait());
} else {
// We can break out from the loop.
break;
}
} while (true);
// Flow the eof before exit
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
// Ask all the workers to quit.
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
return Status::OK();
}
Status CacheBase::FetchFromCache(int32_t worker_id) {
int64_t buffer_id = worker_id;
std::unique_ptr<IOBlock> blk;
do {
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
if (blk->eof()) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
} else if (blk->eoe()) {
if (AllowCacheMiss()) {
// This code path is for CacheLookupOp acting as a sampler. If we get a eoe from
// a sampler, send a eoe to physical leaf op as well.
std::vector<row_id_type> eoe;
eoe.push_back(eoe_row_id);
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe));
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
} else {
std::vector<int64_t> keys;
RETURN_IF_NOT_OK(blk->GetKeys(&keys));
if (keys.empty()) {
// empty key is a quit signal for workers
break;
}
std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
TensorTable ttbl;
RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl));
auto row_it = ttbl.begin();
std::vector<row_id_type> cache_miss;
cache_miss.reserve(keys.size());
for (auto row_id : keys) {
auto &row = *row_it;
if (row.empty()) {
if (AllowCacheMiss()) {
cache_miss.push_back(row_id);
} else {
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
que->push_back(std::move(row));
++row_it;
}
db->set_tensor_table(std::move(que));
if (AllowCacheMiss()) {
// Because of the way connector works, we push unconditionally even cache_miss can be empty.
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss));
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
buffer_id += num_workers_;
}
} while (true);
return Status::OK();
}
Status CacheBase::RegisterResources() {
RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
return Status::OK();
}
CacheBase::~CacheBase() {}
Status CacheBase::UpdateColumnMapFromCache() {
Status rc;
// Get the schema from the server. It may not be there yet. So tolerate the error.
if (column_name_id_map_.empty()) {
rc = cache_client_->FetchSchema(&column_name_id_map_);
if (rc == Status(StatusCode::kFileNotExist)) {
MS_LOG(DEBUG) << "Schema not in the server yet.";
rc = Status::OK();
}
}
return rc;
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,108 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/cache/cache_service.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/sampler/sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/util/queue.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
/// \see CacheOp
/// \see CacheLookupOp
class CacheBase : public ParallelOp {
public:
/// \brief Base class constructor
/// \param num_workers Number of parallel workers
/// \param op_connector_size Connector size
/// \param rows_per_buf Number of rows per buffer
/// \param cache_client CacheClient for communication to the CacheServer
/// \param sampler Sampler which is mandatory
CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
/// \brief Destructor
~CacheBase();
constexpr static int eoe_row_id = -1;
/// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
/// info from it's previous execution and then initializes itself so that it can be executed
/// again.
/// \return Status - The error code return
Status Reset() override;
/// \brief A print method typically used for debugging
/// \param out The output stream to write output to
/// \param show_all A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
/// \brief << Stream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param out reference to the output stream being overloaded
/// \param mo reference to the CacheOp to display
/// \return the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) {
mo.Print(out, false);
return out;
}
/// \brief Getter for the cache client
/// \return shared ptr to the cache client
std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
/// \brief Setter for the cache client
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
/// \brief Derived class must implement this method if a cache miss is treated as error
virtual bool AllowCacheMiss() = 0;
protected:
std::shared_ptr<CacheClient> cache_client_;
WaitPost epoch_sync_;
int32_t rows_per_buffer_;
Connector<std::vector<row_id_type>> keys_miss_;
/// \brief Common function to register resources for interrupt
/// \note Derived should override this function for extra resources to be registered
virtual Status RegisterResources();
/// \brief This function is called by main thread to send samples to the worker thread.
/// \note It is a non-virtual function
/// \return Status object
Status FetchSamplesToWorkers();
/// \brief This function is called by each worker to fetch rows from the cache server for a given set of
/// sample row id's
/// \return Status object
Status FetchFromCache(int32_t worker_id);
/// \brief Get the column map from cache server
Status UpdateColumnMapFromCache();
private:
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_

View File

@ -0,0 +1,130 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/execution_tree.h"
#include "utils/log_adapter.h"
#include "utils/system/crc32c.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
build_op_connector_size_ = cfg->op_connector_size();
}
// Check if the required parameters are set by the builder.
Status CacheLookupOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Cache client for CacheLookupOp is missing session id");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
build_cache_client_, build_sampler_);
return Status::OK();
}
Status CacheLookupOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"CacheLookupOp requires a sampler before it can be executed!");
}
RETURN_IF_NOT_OK(RegisterResources());
// Kick off the workers
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1)));
// required task group sync after launching workers
TaskManager::FindMe()->Post();
// We have to wait until the leaf op has handshake with us.
RETURN_IF_NOT_OK(leaf_op_wp_.Wait());
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
return Status::OK();
}
Status CacheLookupOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
return Status::OK();
}
Status CacheLookupOp::ResetSampler() { return Status::OK(); }
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
// We act like a sampler and as a dataset op. During handshake with leaf op,
// We must wait until the leaf op has indexed everything.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op));
// Now we notify the main thread handshake has finished.
leaf_op_wp_.Set();
return Status::OK();
}
Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
std::vector<row_id_type> cache_miss;
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
// Ignore the case we have no cache miss, we can't return empty samples.
while (cache_miss.empty()) {
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
}
// Special code for eoe
if (cache_miss.at(0) == eoe_row_id) {
*out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
std::shared_ptr<Tensor> sample_ts;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size()));
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
auto idPtr = sample_ts->begin<int64_t>();
for (auto i = 0; i < cache_miss.size(); ++i) {
*idPtr = cache_miss.at(i);
++idPtr;
}
TensorRow row;
row.push_back(sample_ts);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
return Status::OK();
}
Status CacheLookupOp::RegisterResources() {
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks()));
return Status::OK();
}
Status CacheLookupOp::ComputeColMap() {
// We don't know the column map at this point unless we contact the cache server
// to fetch the schema but the cache server may not have it at this point either.
// So we will just return OK and let MergeOp (our parent) to handle it.
return Status::OK();
}
// Visitor accept method for NodePass
Status CacheLookupOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CacheLookupOp>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,122 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset.
/// \note For non-mappable dataset, please see CacheOp
/// \see CacheOp
class CacheLookupOp : public CacheBase, public Sampler {
public:
class Builder {
public:
/// \brief Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// Default destructor
~Builder() = default;
/// Setter method.
/// \treturn Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheLookupOp object
/// \return Status
Status Build(std::shared_ptr<CacheLookupOp> *ptr);
private:
int32_t build_num_workers_;
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
// Check if the required parameters are set by the builder.
// \return Status The error code return
Status SanityCheck() const;
};
/// \brief Constructor
/// \note It takes the same argument as the base class.
/// \see CacheBase
CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
~CacheLookupOp() = default;
// As a parallel op, we override these two functions
Status operator()() override;
Status WorkerEntry(int32_t worker_id) override;
// As a sampler, we override the following functions
Status ResetSampler() override;
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
Status InitSampler() override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
void Print(std::ostream &out, bool show_all) const override;
bool AllowCacheMiss() override { return true; }
std::string Name() const override { return "CacheLookupOp"; }
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
protected:
Status ComputeColMap() override;
private:
WaitPost leaf_op_wp_;
Status RegisterResources() override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_

View File

@ -0,0 +1,301 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <functional>
#include <iomanip>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
CacheMergeOp::~CacheMergeOp() = default;
void CacheMergeOp::Print(std::ostream &out, bool show_all)
const { // Always show the id and name as first line regardless if this is summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <CacheMergeOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\n\n";
}
}
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
: ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {}
Status CacheMergeOp::operator()() {
// A queue of row id to let cleaner send cache miss rows to the cache server
// We don't want a small queue as this will block the parallel op workers.
// A row id is 8 byte integer. So bigger size doesn't consume a lot of memory.
io_que_ = std::make_unique<Queue<row_id_type>>(512);
RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1)));
// One dedicated thread to move TensorRow from the pool to the cache server
for (auto i = 0; i < num_cleaners_; ++i) {
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this)));
}
TaskManager::FindMe()->Post();
return Status::OK();
}
// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait
// until it shows up in the pool.
Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::shared_ptr<DatasetOp> cache_hit_stream = child_[kCacheHitChildIdx];
std::unique_ptr<DataBuffer> db_ptr;
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
while (!db_ptr->eof()) {
if (db_ptr->eoe()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
db_ptr.reset();
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
} else {
// See if there is any missing row
auto tbl = std::make_unique<TensorQTable>();
while (db_ptr->NumRows() > 0) {
TensorRow row;
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
if (row.empty()) {
auto row_id = row.getId();
TensorRowRequest *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
// Block until the row shows up in the pool.
RETURN_IF_NOT_OK(rq->Wait(&row));
}
tbl->push_back(std::move(row));
}
db_ptr->set_tensor_table(std::move(tbl));
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
}
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
return Status::OK();
}
Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
TaskManager::FindMe()->Post();
// We will simply pop TensorRow from the stream and insert them into the pool and
// wake up any worker that is awaiting on the missing TensorRow.
// If we see an eoe, ignore it. For eof, we exit.
std::shared_ptr<DatasetOp> cache_missing_stream = child_[kCacheMissChildIdx];
// Before we start, cache the schema at the server. Pick one of the workers
// do it. The schema should have been done at prepare time.
if (workerId == 0) {
RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
}
std::unique_ptr<DataBuffer> db_ptr;
RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
while (!db_ptr->eof()) {
if (db_ptr->eoe()) {
// Ignore it.
MS_LOG(DEBUG) << "Ignore eoe";
} else {
while (db_ptr->NumRows() > 0) {
TensorRow row;
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
row_id_type row_id = row.getId();
if (row_id < 0) {
std::string errMsg = "Expect positive row id: " + std::to_string(row_id);
RETURN_STATUS_UNEXPECTED(errMsg);
}
TensorRowRequest *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
rq->WakeUpAny(std::move(row));
// Let the cleaner to flush out this row (async) to the cache server.
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
}
}
RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
}
return Status::OK();
}
Status CacheMergeOp::Cleaner() {
TaskManager::FindMe()->Post();
while (true) {
row_id_type row_id;
RETURN_IF_NOT_OK(io_que_->PopFront(&row_id));
if (row_id < 0) {
break;
}
TensorRowRequest *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
if (rq->GetState() == TensorRowRequest::State::kClean) {
// If already flushed, move on to the next one.
continue;
}
TensorRow row;
RETURN_IF_NOT_OK(rq->Release(&row));
CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error");
Status rc = cache_client_->WriteRow(row);
// Bad rc should not bring down the pipeline
if (rc.IsError()) {
MS_LOG(WARNING) << "Cache not successful." << rc.ToString();
}
rq->SetState(TensorRowRequest::State::kClean);
}
return Status::OK();
}
Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) {
RETURN_UNEXPECTED_IF_NULL(out);
std::unique_lock<std::mutex> lck(mux_);
auto it = cache_miss_map_.find(row_id);
if (it != cache_miss_map_.end()) {
*out = it->second.GetMutablePointer();
} else {
// We will create a new one.
auto alloc = Services::GetAllocator<TensorRowRequest>();
auto r = cache_miss_map_.emplace(row_id, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>(alloc));
if (r.second) {
auto &mem = r.first->second;
RETURN_IF_NOT_OK(mem.allocate(1, row_id));
*out = mem.GetMutablePointer();
} else {
RETURN_STATUS_UNEXPECTED("Map insert fail.");
}
}
return Status::OK();
}
Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own
// specific logic
CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
// Get the computed check sum from all ops in the cache miss class
uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]);
// This is a mappable cache op so the id's need to be generated.
// Construct the cache
const bool generate_ids = false;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
// We are told the cache has been created already.
MS_LOG(INFO) << "Cache created already";
rc = Status::OK();
}
RETURN_IF_NOT_OK(rc);
return Status::OK();
}
Status CacheMergeOp::ComputeColMap() {
CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty");
if (column_name_id_map().empty()) {
column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map();
}
CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected");
return Status::OK();
}
Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) {
RETURN_UNEXPECTED_IF_NULL(out);
// Block until the missing row is in the pool.
RETURN_IF_NOT_OK(use_count_.P());
std::unique_lock<std::mutex> lck(dq_mux_);
CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
*out = std::move(row_.front());
row_.pop_front();
return Status::OK();
}
void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) {
std::unique_lock<std::mutex> lck(dq_mux_);
// Technically number of this row shows up in the cache miss stream is equal to the number
// of P() call. However the cleaner wants it too. So we need an extra copy.
if (GetState() == State::kEmpty) {
// We will do a deep copy
for (auto &ts : row) {
auto out_ts = std::make_shared<Tensor>(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes());
cleaner_copy_.push_back(out_ts);
}
cleaner_copy_.setId(row.getId());
// Change the state to dirty
SetState(State::kDirty);
}
row_.push_back(std::move(row));
// Bump up the use count by 1. This wake up any parallel worker which is waiting
// for this row.
use_count_.V();
}
Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) {
RETURN_UNEXPECTED_IF_NULL(out);
// We are not holding any mutex here because the cleaner isn't really touching the deque row_.
// In case we have multiple cleaners and they all see the copy, only one of them will
// get it.
auto expected = State::kDirty;
if (st_.compare_exchange_strong(expected, State::kClean)) {
*out = std::move(cleaner_copy_);
}
return Status::OK();
}
// Builder constructor. Creates the builder object.
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
build_op_connector_size_ = cfg->op_connector_size();
build_num_cleaners_ = 1;
}
// Check if the required parameters are set by the builder.
Status CacheMergeOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Cache client for CacheMergeOp is missing session id");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheMergeOp::Builder::Build(std::shared_ptr<CacheMergeOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheMergeOp>(build_num_workers_, build_op_connector_size_, build_num_cleaners_,
build_cache_client_, build_sampler_);
return Status::OK();
}
// Pre-Visitor accept method for NodePass
Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<CacheMergeOp>(), modified);
}
// Visitor accept method for NodePass
Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CacheMergeOp>(), modified);
}
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
// If we are in a repeat path, send the eoe up.
// Otherwise ignore it.
if (BitTest(op_ctrl_flags_, kDeOpRepeated)) {
return DatasetOp::EoeReceived(worker_id);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,196 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "dataset/core/tensor_row.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/util/queue.h"
#include "dataset/util/semaphore.h"
namespace mindspore {
namespace dataset {
/// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single
/// stream
class CacheMergeOp : public ParallelOp {
public:
// Some handshake structures among the main thread, cleaner threads and parallel op threads.
class TensorRowRequest {
public:
enum class State : uint8_t {
kEmpty = 0, // No row in the deque
kDirty = 1, // Cleaner hasn't flushed it to the cache server yet.
kClean = 2 // The row has been flushed already.
};
explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {}
~TensorRowRequest() = default;
State GetState() const { return st_; }
void SetState(State newState) { st_ = newState; }
Status Wait(TensorRow *out);
void WakeUpAny(TensorRow &&row);
Status Release(TensorRow *out);
private:
std::mutex dq_mux_;
std::atomic<State> st_;
Semaphore use_count_;
std::deque<TensorRow> row_;
TensorRow cleaner_copy_;
};
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
/// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
/// the arguments for constructing it. Use the builder by setting each argument
/// with the provided set methods, and then finally call the build method to execute
/// the actual construction.
class Builder {
public:
/// Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// Default destructor
~Builder() = default;
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief Setter method
/// \param num_cleaners
/// \return Builder setter method returns reference to the builder.
Builder &SetNumCleaner(int32_t num_cleaners) {
build_num_cleaners_ = num_cleaners;
return *this;
}
/// The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheMergeOp object
/// \return Status
Status Build(std::shared_ptr<CacheMergeOp> *ptr);
private:
int32_t build_num_workers_;
int32_t build_op_connector_size_;
int32_t build_num_cleaners_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
/// Check if the required parameters are set by the builder.
/// \return Status The error code return
Status SanityCheck() const;
};
/// \brief Constructor
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp
/// \param opConnector Size Connector size as a derived class of ParallelOp
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server
/// \param cache_client CacheClient to commmunicate with the Cache server
/// \param sampler as a derived class of ParallelOp
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
~CacheMergeOp();
void Print(std::ostream &out, bool show_all) const override;
friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
mo.Print(out, false);
return out;
}
/// \brief Master thread responsible to spawn all the necessary worker threads for the two streams and
/// the threads for the cleaners.
/// \return
Status operator()() override;
/// \brief Entry function for worker thread that fetch rows from CacheLookupOp
/// \param workerId
/// \return Status object
Status WorkerEntry(int32_t workerId) override;
Status PrepareNodePostAction() override;
/// \brief Entry function for worker thread that fetch rows from the cache miss stream
/// \param workerId
/// \return Status object
Status CacheMissWorkerEntry(int32_t workerId);
Status GetRq(row_id_type row_id, TensorRowRequest **);
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for eoe handling
/// \param worker_id
/// \return Status object
Status EoeReceived(int32_t worker_id) override;
protected:
Status ComputeColMap() override;
private:
std::mutex mux_;
std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_;
std::unique_ptr<Queue<row_id_type>> io_que_;
std::shared_ptr<CacheClient> cache_client_;
int32_t num_cleaners_;
/// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for
/// moving cache miss TensorRow into the CacheServer.
/// \return Status object
Status Cleaner();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_

View File

@ -0,0 +1,219 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_op.h"
#include <memory>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
build_op_connector_size_ = cfg->op_connector_size();
}
// Check if the required parameters are set by the builder.
Status CacheOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_,
build_sampler_);
RETURN_IF_NOT_OK((*ptr)->InitCache());
return Status::OK();
}
// Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
num_guys_in_(0),
phase_(Phase::kBuildPhase) {}
// Destructor
CacheOp::~CacheOp() = default;
// Private function for cache setup/init work just after construction
Status CacheOp::InitCache() { return Status::OK(); }
// This class functor will provide the master loop that drives the logic for performing the work
Status CacheOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"CacheOp requires a sampler before it can be executed!");
}
RETURN_IF_NOT_OK(RegisterResources());
// Kick off the workers
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
// required task group sync after launching workers
TaskManager::FindMe()->Post();
// Wait for the workers to finish caching the rows.
RETURN_IF_NOT_OK(WaitForCachingAllRows());
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
return Status::OK();
}
Status CacheOp::CacheAllRows(int32_t worker_id) {
// If the current phase is to fill the cache, do it then.
if (phase_ == Phase::kBuildPhase) {
// We will take the chance to cache the schema at the server.
// Just do it once and pick one worker to do it.
if (worker_id == 0) {
RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
}
MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id;
// SAVE mode loop
std::unique_ptr<DataBuffer> db_ptr;
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
while (!db_ptr->eof()) {
if (!db_ptr->eoe()) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
} else {
// In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
// as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the
// the eoe to indicate the end of the epoch, we should next expect to get the eof.
// Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch
// from again.
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
if (!db_ptr->eof()) {
RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child.");
}
}
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
}
}
// Let the main guy know we are done.
auto last_guy_in = num_guys_in_.fetch_add(1);
if ((last_guy_in + 1) == num_workers_) {
rows_cache_done_.Set();
} else {
// Let's do a sync up here.
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
}
return Status::OK();
}
Status CacheOp::WaitForCachingAllRows() {
// Wait for the workers to finish caching the rows.
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
// Move from build phase to fetch phase if we are the one to fill the cache
if (phase_ == Phase::kBuildPhase) {
RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
// Move to the next phase
phase_ = Phase::kFetchPhase;
}
// Get statistics from the server, and if we are not the one to create the cache,
// wait until the state changed from build phase to fetch base.
CacheClient::ServiceStat stat{};
bool BuildPhaseDone = true;
do {
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase);
if (!BuildPhaseDone) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
} while (!BuildPhaseDone);
const row_id_type min_key = stat.min_row_id;
const row_id_type max_key = stat.max_row_id;
num_rows_ = max_key - min_key + 1;
MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
// Now all rows are cached and we have done a sync point check up. Next phase is
// is pick up fetch input from sampler and pass up to the caller.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK();
}
Status CacheOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(CacheAllRows(worker_id));
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
return Status::OK();
}
Status CacheOp::RegisterResources() {
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks()));
return Status::OK();
}
// Base-class override for setting specific CacheOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; }
// Base-class override for special eoe handler.
// CacheOp must override this because it shall not perform default handling of eoe. Instead
// the CacheOp manages actions related to the end of the epoch.
Status CacheOp::EoeReceived(int32_t worker_id) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
// Base-class override for handling cases when an eof is received.
Status CacheOp::EofReceived(int32_t worker_id) {
// eofReceived is overloaded because we want to manually handle this eof.
// Specifically, the default behaviour is to pack it and flow it up to the next connection.
// In this case, we want a no-op behaviour so that we can perform correct action.
return Status::OK();
}
// Pre-Visitor accept method for NodePass
Status CacheOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<CacheOp>(), modified);
}
// Visitor accept method for NodePass
Status CacheOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CacheOp>(), modified);
}
// A public wrapper for creating the cache through the client
Status CacheOp::CreateCache(uint32_t cache_crc) {
// This is a non-mappable cache op so the id's need to be generated.
// Construct the cache
const bool generate_ids = true;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
// We are told the cache has been created already. So we skip the build phase.
phase_ = Phase::kFetchPhase;
rc = Status::OK();
}
RETURN_IF_NOT_OK(rc);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,168 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
#include <atomic>
#include <string>
#include <utility>
#include <memory>
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset.
/// \note For mappable dataset, please see CacheLookupOp.
/// \see CacheLookupOp
class CacheOp : public CacheBase, public RandomAccessOp {
public:
// This CacheOp is for non-mappable case where it is divided into two phases.
// The first phase is we cache all the rows from the child (and let the cache server
// assigns row id). No read access in the first phase. Once the cache is fully built,
// we switch to second phase and fetch requests from the sampler.
enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
/// \brief The nested builder class inside of the CacheOp is used to help manage all of
/// the arguments for constructing it. Use the builder by setting each argument
/// with the provided set methods, and then finally call the build method to execute
/// the actual construction.
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method
/// \param rows_per_buffer
/// \return Builder setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
rows_per_buffer_ = rows_per_buffer;
return *this;
}
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheOp object
/// \return Status
Status Build(std::shared_ptr<CacheOp> *ptr);
private:
int32_t build_num_workers_;
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
/// \brief Check if the required parameters are set by the builder.
/// \return Status The error code return
Status SanityCheck() const;
};
/// \brief Constructor of CacheOp
/// \note The builder class should be used to call it.
/// \param num_workers The number of worker threads.
/// \param op_connector_size The size of each queue in the connector.
CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
// Destructor
~CacheOp();
/// \brief Base-class override for setting specific CacheOp configurations. This code will be called
/// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t PrepareFlags() const override;
/// \brief Base-class override for special eoe handler.
/// CacheOp must override this because it shall not perform default handling of eoe. Instead
/// the CacheOp manages actions related to the end of the epoch.
/// \return Status - The error code return
Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id
/// \return Status - The error code return
Status EofReceived(int32_t worker_id) override;
Status operator()() override;
Status WorkerEntry(int32_t worker_id) override;
/// \brief Base-class override for handling cases if we allow cache miss
bool AllowCacheMiss() override { return false; }
/// \brief Base-class override for the name of this operator
std::string Name() const override { return "CacheOp"; }
/// \brief A public wrapper for creating the cache through the client
/// \param[in] cache_crc The crc that identifies the cache
/// \see cache_pass.cc
/// \return Status return code
Status CreateCache(uint32_t cache_crc);
private:
WaitPost rows_cache_done_;
std::atomic<int64_t> num_guys_in_;
Phase phase_;
/// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler.
/// \return Status object
Status WaitForCachingAllRows();
/// \brief For non-mappable dataset, there is a build phase where we cache all the rows.
/// \return Status object
Status CacheAllRows(int32_t worker_id);
Status RegisterResources() override;
/// \brief Private function for cache setup/init work just after construction
/// \return Status The error code return
Status InitCache();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_

View File

@ -61,46 +61,39 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const {
Status ConcatOp::operator()() {
// The children_num_ parameter needs to be put here
children_num_ = static_cast<int32_t>(child_.size());
TaskManager::FindMe()->Post();
std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
int eof_count = 0;
while (eof_count != children_num_) {
while (eof_count == 0) {
for (int i = 0; i < children_num_; i++) {
// 1. Throw the eof buffer when meet it
if (buf->eof() || buf->eoe()) {
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
// 1. Read the first buffer
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
if (buf->eof()) {
eof_count++;
continue;
}
// 2. Do verification as for column name, column data type and rank of column data
RETURN_IF_NOT_OK(Verify(i, buf));
if (!buf->eoe()) {
RETURN_IF_NOT_OK(Verify(i, buf));
}
// 3. Put the data into output_connector
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
}
// 4. Throw the eoe buffer when meet it
if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) {
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
}
// 5. Add eoe buffer after get buffer from all child
if (i == (children_num_ - 1)) {
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
}
if (buf->eof()) {
eof_count++;
}
}
// 4. Add eoe buffer after get buffer from all child
if (eof_count == 0) {
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
}
}
// 6. Add eof buffer in the end manually
CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
"Something went wrong, eof count does not match the number of children.");
// 5. Add eof buffer in the end manually
MS_LOG(DEBUG) << "Add the eof buffer manualy in the end.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
return Status::OK();
}
@ -126,12 +119,6 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
return Status::OK();
}
Status ConcatOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToEOEOpStack(shared_from_this());
return Status::OK();
}
// We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
Status ConcatOp::ComputeColMap() {
if (column_name_id_map_.empty()) {

View File

@ -75,12 +75,6 @@ class ConcatOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "ConcatOp"; }

View File

@ -153,16 +153,38 @@ Status DatasetOp::Remove() {
}
}
// Finally, clear "this" op's parent and child pointers since we have just
// disconnected it from the tree and invalidate it's fields.
child_.clear();
parent_.clear();
operator_id_ = kInvalidOperatorId;
tree_ = nullptr;
return Status::OK();
}
// Getter function to get a shared pointer to our childAdds a operator to become our child.
// Getter function to get a shared pointer to our child
std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
std::shared_ptr<DatasetOp> return_op = nullptr;
if (child_.empty()) {
return return_op;
}
MS_ASSERT(child_index < static_cast<int>(child_.size()));
// Return a shared pointer
return child_[child_index];
}
// Getter function to get the parent pointer
void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
if (parent_.empty()) {
// common case if this is a root node
*parent = nullptr;
} else {
MS_ASSERT(parent_index < static_cast<int>(parent_.size()));
*parent = parent_[parent_index];
}
}
// Creates the connector within this operator
void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
@ -264,19 +286,11 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
// During tree prepare phase, operators may have specific pre-operations to perform depending on
// their role.
Status DatasetOp::PrepareNodePreAction() {
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
return Status::OK();
}
Status DatasetOp::PrepareNodePreAction() { return Status::OK(); }
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status DatasetOp::PrepareNodePostAction() {
// If this op does not have any children and it is in a repeat path of the tree...
if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
// push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator
// above us will consume them.
tree_->AddToEOEOpStack(shared_from_this());
}
// Creating Connector object for each op.
// The consumer of the root node is assumed to be one thread.
// If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
@ -346,34 +360,13 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_this(), modified);
}
// A helper function with some common code that leaf nodes can use during
// prepare phase for checking if they need to assign a sampler to the cache.
Status DatasetOp::SaveSamplerForCache(bool random_access_op) {
// If we are a descendant under a cache op and we have a sampler, then save this sampler
// to a stack so that the cache can pick it up during it's processing above us.
if (sampler_) {
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
// use move semantic to set our sampler_ to null after the move. This is okay because a sampler is
// useless to a random data op. It was only being used as a temporary holding until the cache can
// be created
tree_->AddToSamplerStack(sampler_);
MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
} else if (!random_access_op) {
// A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
// This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
RETURN_STATUS_UNEXPECTED(
"Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree");
}
}
if (!random_access_op) {
// Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache
// we can remove it now from the base.
sampler_.reset();
}
// Getter for the sampler, and it also removes the sampler from the op
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
*sampler = sampler_; // It's okay if it sampler_ points to nullptr
sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler
return Status::OK();
}
uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
std::stringstream ss;
op->tree_->Print(ss, op);

View File

@ -45,10 +45,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
public:
static constexpr int32_t kInvalidOperatorId = -1;
// Flags that control operator runtime behaviours
// Operator control flags
enum OpControlFlags {
kDeOpNone = 0,
kDeOpRepeated = 1, // Operator is a leaf node in a repeat path
kDeOpRepeated = 1, // Operator is a node in a repeat path
kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop
};
@ -71,17 +71,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \param child - shared pointer to the child to remove.
Status RemoveChild(std::shared_ptr<DatasetOp> child);
/// \brief Removes this node from the tree and connects it's parent/child together.
/// \brief Removes this node from the tree and connects it's parent/child together
/// \return Status eerror code returned
Status Remove();
/// \brief Getter function to get a shared pointer to our child
/// \param child_index - An operator can have n children. Indicates choose which child to return.
/// \param[in] child_index An operator can have n children. Indicates which child to return.
/// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
std::shared_ptr<DatasetOp> child(int32_t child_index) const;
/// \brief Inserts a operator as the parent current op.
/// Inserted op will become the sole parent of the current op.
/// The existing parent of the current op will be transferred to the inserted op.
/// \brief Getter function to get the pointer to our parent
/// If there are no parents, it returns null regardless of the given index
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
void Parent(DatasetOp **parent, int32_t parent_index) const;
// Inserts a operator as the parent current op.
// Inserted op will become the sole parent of the current op.
// The existing parent of the current op will be transferred to the inserted op.
Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
/// \brief Creates the connector within this operator
@ -161,16 +167,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The error code return
virtual Status Reset();
/// \brief This calls the reset function on this subtree in pre-order
/// \return Status - The error code return
virtual Status ResetSubtree() {
RETURN_IF_NOT_OK(Reset());
for (const auto &c : child_) {
RETURN_IF_NOT_OK(c->ResetSubtree());
}
return Status::OK();
}
/// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
/// their role.
/// \notes Derived versions of this function should always call it's superclass version first
@ -296,7 +292,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Shared pointer to the sampler (may return nullptr)
std::shared_ptr<Sampler> sampler() { return sampler_; }
/// Computes a CRC value for the operator
/// \brief Getter for the sampler, and it also removes the sampler from the op
/// \param[out] sampler A pointer to the output sampler that was removed
/// \return Status error code
Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
// Computes a CRC value for the operator
static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
@ -307,17 +308,24 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
return std::static_pointer_cast<Derived>(shared_from_this());
}
protected:
/// Adds a parent operator to this operator
/// \notes External callers do not have access to this function.
/// \param parent - The parent node to add
void AddParent(DatasetOp *parent);
/// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
/// Removes a parent operator from this operator
/// \notes External callers do not have access to this function.
/// \param parent - The parent node to remove
/// \brief Checks if this is a leaf node (0 children)
/// \return boolean returns true if it's a leaf
bool IsLeaf() { return (child_.empty()); }
protected:
/// \brief Removes a parent operator from this operator
/// \notes External callers do not have access to this function
/// \param[in] parent The parent node to remove
void RemoveParent(const DatasetOp *parent);
/// \brief Adds a parent operator to this operator
/// \notes External callers do not have access to this function
/// \param[in] parent The parent node to add
void AddParent(DatasetOp *parent);
/// Compute the current op's column map using its child's column map.
/// Get called during the tree post-prepare phase in PrepareNodePostAction.
/// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
@ -325,12 +333,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return - Status
virtual Status ComputeColMap();
/// A helper function with some common code that leaf nodes can use during
/// pre/pare phase for checking if they need to assign a sampler to the cache.
/// \param random_access_op - indicate if this is a mappable random access leaf or not
/// \return - Status
Status SaveSamplerForCache(bool random_access_op);
std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler

View File

@ -77,26 +77,6 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
}
}
// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase when it is visiting this operator.
Status RepeatOp::PrepareNodePostAction() {
// Run any common code from super class first before adding our own specific logic
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack();
while (leaf_op != nullptr) {
// Track the leaf operators that are under this repeat op.
eoe_ops_.push_back(leaf_op);
leaf_op = tree_->PopFromEOEOpStack();
}
// Push ourselves to the stack in case one of our ascendants is repeat too.
tree_->AddToEOEOpStack(shared_from_this());
return Status::OK();
}
// Base-class override for setting specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; }
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
@ -130,7 +110,8 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
// Base-class override for handling cases when an eoe is received.
Status RepeatOp::EoeReceived(int32_t worker_id) {
repeat_count_++;
MS_LOG(DEBUG) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_
<< ") end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
// If we've reached the requested repeat count, then flag the eoe nodes
@ -149,8 +130,12 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
return Status::OK();
}
// base-class ResetSubtree
return (DatasetOp::ResetSubtree());
// Invoke a reset against the eoe nodes only.
for (auto &eoe_op : eoe_ops_) {
RETURN_IF_NOT_OK(eoe_op->Reset());
}
return Status::OK();
}
// Class functor operator () override.
@ -178,6 +163,18 @@ int32_t RepeatOp::num_consumers() const {
}
}
// Drive reset actions if needed
Status RepeatOp::Reset() {
// If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op.
// In that case, we now have to bounce the reset down to our own eoe ops.
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
for (auto &eoe_op : eoe_ops_) {
RETURN_IF_NOT_OK(eoe_op->Reset());
}
state_ = OpState::kDeOpRunning;
return Status::OK();
}
int32_t RepeatOp::num_producers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
@ -187,6 +184,12 @@ int32_t RepeatOp::num_producers() const {
}
}
// Pre-Visitor accept method for NodePass
Status RepeatOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<RepeatOp>(), modified);
}
// Visitor accept method for NodePass
Status RepeatOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor

View File

@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"
@ -82,14 +83,6 @@ class RepeatOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// Base-class override for setting specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t PrepareFlags() const override;
// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree post-prepare phase when it is visiting this operator.
Status PrepareNodePostAction() override;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
@ -110,6 +103,10 @@ class RepeatOp : public PipelineOp {
// @param worker_id - The worker id
Status EofReceived(int32_t worker_id) override;
/// \brief reset Op
/// \@return Status - The error code return
Status Reset() override;
// Base-class override. Return the number of workers in the first parent.
// @param workerId - The worker id
int32_t num_consumers() const override;
@ -118,16 +115,26 @@ class RepeatOp : public PipelineOp {
// @param workerId - The worker id
int32_t num_producers() const override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "RepeatOp"; }
/// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
/// \param[in] eoe_op The input leaf/eoe operator to add to the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
private:
int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats

View File

@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/kernels/image/image_utils.h"
namespace mindspore {
@ -408,6 +409,12 @@ Status CelebAOp::Reset() {
return Status::OK();
}
// Visitor accept method for NodePass
Status CelebAOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CelebAOp>(), modified);
}
Status CelebAOp::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {

View File

@ -169,6 +169,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @return Status - The error code return
Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const { return "CelebAOp"; }

View File

@ -26,6 +26,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@ -450,6 +451,12 @@ Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *
}
}
// Visitor accept method for NodePass
Status CifarOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CifarOp>(), modified);
}
Status CifarOp::ComputeColMap() {
// set the column name map (base class field)
if (column_name_id_map_.empty()) {

View File

@ -155,6 +155,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @return
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "CifarOp"; }

View File

@ -24,6 +24,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@ -624,6 +625,12 @@ Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file,
return Status::OK();
}
// Visitor accept method for NodePass
Status CocoOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CocoOp>(), modified);
}
Status CocoOp::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {

View File

@ -200,6 +200,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

View File

@ -26,6 +26,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@ -416,6 +417,12 @@ Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dic
return Status::OK();
}
// Visitor accept method for NodePass
Status ManifestOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<ManifestOp>(), modified);
}
Status ManifestOp::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {

View File

@ -172,6 +172,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "ManifestOp"; }

View File

@ -23,6 +23,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@ -428,6 +429,12 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
return Status::OK();
}
// Visitor accept method for NodePass
Status MnistOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<MnistOp>(), modified);
}
Status MnistOp::ComputeColMap() {
// set the column name map (base class field)
if (column_name_id_map_.empty()) {

View File

@ -152,6 +152,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @return
static Status CountTotalRows(const std::string &dir, int64_t *count);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "MnistOp"; }

View File

@ -22,6 +22,7 @@
#include "dataset/util/random.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@ -406,6 +407,12 @@ Status RandomDataOp::Reset() {
return Status::OK();
}
// Visitor accept method for NodePass
Status RandomDataOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<RandomDataOp>(), modified);
}
Status RandomDataOp::ComputeColMap() {
// Extract the column name mapping from the schema and save it in the class.
if (column_name_id_map_.empty()) {
@ -415,15 +422,5 @@ Status RandomDataOp::ComputeColMap() {
}
return Status::OK();
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status RandomDataOp::PrepareNodePostAction() {
// Run common code from super class before adding RandomDataOp specific handling
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
// Specific handling for this op, we need to do cache op work to assign the sampler to the cache.
RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -203,12 +203,6 @@ class RandomDataOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "RandomDataOp"; }
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override;
private:
/**
* The entry point code for when workers are launched
@ -266,6 +260,12 @@ class RandomDataOp : public ParallelOp {
return ++buffer_id_;
}
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
// Private function for computing the assignment of the column name map.
// @return - Status
Status ComputeColMap() override;

View File

@ -1019,31 +1019,28 @@ Status TFReaderOp::ComputeColMap() {
return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
// that this tf reader will produce the full set of data into the cache.
void TFReaderOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
total_rows_ = 0;
shuffle_files_ = false;
equal_rows_per_shard_ = false;
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status TFReaderOp::PrepareNodePostAction() {
// Run common code from super class before adding TFReaderOp specific handling
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
// Specific handling for this op, we need to do cache op work so assign the sampler to the cache
// TF is a special case because it can support file-based sharding/shuffling, or, if there
// is a cache, then it can also do row-based sampler using the sampler on the cache.
// Thus, pass true for random access op flag when saving the sampler. This is a special case,
// since usually a non-mappable dataset would pass false here.
RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true));
// Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into
// a simpler producer of all data (no shuffling or sharding or anything)
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
device_id_ = 0;
num_devices_ = 1;
total_rows_ = 0;
shuffle_files_ = false;
equal_rows_per_shard_ = false;
sampler_.reset(); // Normally SaveSampler code did this for us, but we passed in true above (See comment)
} else {
if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
// This sanity check had been delayed until now in the prepare loop.
// If we are not in a cache path, then we can validate the the file-based sharding config.
// If we are not in a cache path, then we can validate the file-based sharding config.
// If we are in a cache path, there is no file-based sharding so the check is not correct in that
// situation.
if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) {

View File

@ -246,6 +246,11 @@ class TFReaderOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return dataset_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
/// that this tf reader will produce the full set of data into the cache.
void MakeSimpleProducer();
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first

View File

@ -25,6 +25,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
using tinyxml2::XMLDocument;
using tinyxml2::XMLElement;
@ -449,6 +450,11 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t
return Status::OK();
}
// Visitor accept method for NodePass
Status VOCOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<VOCOp>(), modified);
}
Status VOCOp::ComputeColMap() {
// Set the column name map (base class field)

View File

@ -205,6 +205,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "VOCOp"; }

View File

@ -127,12 +127,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
return Status::OK();
}
Status TakeOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToEOEOpStack(shared_from_this());
return Status::OK();
}
// Visitor accept method for NodePass
Status TakeOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor

View File

@ -78,12 +78,6 @@ class TakeOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call it's superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.

View File

@ -21,6 +21,8 @@
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/opt/pre/removal_pass.h"
#include "dataset/engine/opt/pre/cache_transform_pass.h"
#include "dataset/engine/opt/post/repeat_pass.h"
#include "dataset/engine/perf/profiling.h"
#include "dataset/engine/perf/monitor.h"
@ -215,18 +217,33 @@ Status ExecutionTree::PrepareTreePreAction() {
bool modified = false;
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
MS_LOG(INFO) << "Running pre pass";
pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass()));
MS_LOG(INFO) << "Running pre pass loops.";
pre_actions.push_back(std::make_unique<RemovalPass>());
pre_actions.push_back(std::make_unique<CacheTransformPass>());
// Apply pre action passes
for (auto &pass : pre_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
}
MS_LOG(INFO) << "Pre passes complete.";
return Status::OK();
}
Status ExecutionTree::PrepareTreePostAction() {
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
bool modified = false;
std::vector<std::unique_ptr<Pass>> post_actions;
// Construct pre actions
MS_LOG(INFO) << "Running post pass loops.";
post_actions.push_back(std::make_unique<RepeatPass>());
// Apply post action passes
for (auto &pass : post_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
}
MS_LOG(INFO) << "Post passes complete.";
return Status::OK();
}
@ -280,31 +297,5 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op)
return Status::OK();
}
// Adds an operator to the eoe operator stack during prepare phase.
void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
// Pops an operator from the eoe operator stack during prepare phase.
std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr;
if (!eoe_stack_.empty()) {
top_op = eoe_stack_.top();
eoe_stack_.pop();
}
return top_op;
}
// Adds a sampler to the sampler stack during prepare phase.
void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); }
// Pops an operator from the sampler stack during prepare phase.
std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() {
std::shared_ptr<Sampler> top_sampler = nullptr;
if (!sampler_stack_.empty()) {
top_sampler = sampler_stack_.top();
sampler_stack_.pop();
}
return top_sampler;
}
} // namespace dataset
} // namespace mindspore

View File

@ -200,24 +200,6 @@ class ExecutionTree {
// @return Status - The error code return
Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
/// Adds an operator to the eoe operator stack during prepare phase.
/// \param op - The dataset op to work add to eoe stack
/// \return Status - The error code return
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
/// Pops an operator from the eoe operator stack during prepare phase.
/// \return shared_ptr to the popped operator
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
/// Adds a sampler to the sampler stack during prepare phase.
/// \param samplerop - The dataset op to work add to eoe stack
/// \return Status - The error code return
void AddToSamplerStack(std::shared_ptr<Sampler> sampler);
/// Pops an operator from the sampler stack during prepare phase.
/// \return shared_ptr to the popped operator
std::shared_ptr<Sampler> PopFromSamplerStack();
// Return the pointer to the TaskGroup
// @return raw pointer to the TaskGroup
TaskGroup *AllTasks() const { return tg_.get(); }
@ -248,8 +230,6 @@ class ExecutionTree {
TreeState tree_state_; // Tracking the current tree state
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A stack used during prepare phase
std::stack<std::shared_ptr<Sampler>> sampler_stack_; // A stack used during prepare phase
};
} // namespace dataset
} // namespace mindspore

View File

@ -2,6 +2,9 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-opt OBJECT
pass.cc
post/repeat_pass.cc
pre/cache_pass.cc
pre/cache_transform_pass.cc
pre/removal_nodes.cc
pre/removal_pass.cc
util/printer_pass.cc

View File

@ -16,6 +16,9 @@
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"
@ -24,8 +27,15 @@
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#ifdef ENABLE_PYTHON
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
@ -145,6 +155,11 @@ Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
}
#endif
Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@ -164,5 +179,70 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified)
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -47,6 +47,10 @@ class FilterOp;
class GeneratorOp;
#endif
class RandomDataOp;
class RepeatOp;
class TakeOp;
class ZipOp;
@ -55,6 +59,24 @@ class DeviceQueueOp;
class ImageFolderOp;
class CacheOp;
class MnistOp;
class ManifestOp;
class CifarOp;
class VOCOp;
class CocoOp;
class CelebAOp;
class CacheMergeOp;
class CacheLookupOp;
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
@ -138,14 +160,42 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
#endif
virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);

View File

@ -0,0 +1,161 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "dataset/engine/opt/post/repeat_pass.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
namespace mindspore {
namespace dataset {
RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {}
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// If we are already repeated, then this is a nested repeat.
if (is_repeated_) {
nested_repeats_++;
}
is_repeated_ = true;
return Status::OK();
}
// Identifies the subtree below this node as being in a cache merge path
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Turn on the flag that we're under a merge op
is_merge_ = true;
return Status::OK();
}
// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack();
}
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
// and add it to the list of eoe/leaf ops for the repeat, removing it from the save area.
if (is_merge_ && cache_lookup_) {
cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
node->AddToEoeList(std::move(cache_lookup_));
}
// If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
// A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
if (nested_repeats_ > 0) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node);
nested_repeats_--;
}
// If we are not nested, or we were the top-most repeat, now we clear the flag
if (nested_repeats_ == 0) {
is_repeated_ = false;
}
return Status::OK();
}
// CacheOp removes previous leaf ops and replaces them with itself
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
// if we are a cache within a repeat path of the tree, then there will be
// eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the
// repeat or epoch ctrl operators can work with them for repeat activity during runtime.
// However, since a cache is present:
// - unflag those ops as being repeated ops
// - remove them from the eoe op stack so that repeat op above in the tree won't know about them
// - add ourself (the cache op), as an eoe op
// We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead
// the repeating behaviours shall be invoked against the cache op.
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat);
leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated);
leaf_op = PopFromEOEOpStack();
}
AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node));
}
return Status::OK();
}
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
// for use with a controlling repeat above it.
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
// If we are in a repeat path, then set our repeated flag
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
// if we are a leaf node then save ourself in a stack for the repeat operator above us
if (node->IsLeaf()) {
AddToEOEOpStack(node);
}
}
return Status::OK();
}
// Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Setting the flag is needed since we didn't call the base class DatasetOp version
if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated);
is_merge_ = false;
cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed
return Status::OK();
}
// Saves the lookup up in case it needs to be referenced by a repeat
Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
if (!node->IsLeaf()) {
// By definition, the CacheLookup must be a leaf op. Make that clear here.
RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
}
// If we are in a repeat path already, then there must be a repeat above the merge op
// In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node);
} else {
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
}
return Status::OK();
}
// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
// Pops an operator from the eoe operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr;
if (!eoe_stack_.empty()) {
top_op = eoe_stack_.top();
eoe_stack_.pop();
}
return top_op;
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
#include <memory>
#include <stack>
#include <utility>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
/// \class RepeatPass repeat_pass.h
/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references
/// to the eoe-producing (typically leaf) nodes underneath it.
class RepeatPass : public NodePass {
public:
/// \brief Constructor
RepeatPass();
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being in a cache merge path
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief CacheOp removes previous leaf ops and replaces them with itself
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Turns of the tracking for operations under merge op
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
/// \brief Saves the lookup up in case it needs to be referenced by a repeat
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
/// for use with a controlling repeat above it.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
private:
/// \brief Adds an operator to the eoe operator stack save area
/// \param op - The dataset op to work add to eoe stack
/// \return Status - The error code return
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
/// \brief Pops an operator from the eoe operator stack save area
/// \return shared_ptr to the popped operator
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
bool is_repeated_; // T/F if we are processing under a repeat
bool is_merge_; // T/F if we are processing under a cache merge op
int32_t nested_repeats_; // A counter for nested repeats
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A save area for leaf/eoe ops
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_

View File

@ -0,0 +1,181 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "dataset/engine/opt/pre/cache_pass.h"
#include "dataset/engine/opt/pre/cache_transform_pass.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
namespace mindspore {
namespace dataset {
// Constructor
CachePass::CachePass(CacheTransformPass *transform_pass)
: transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {}
// Identifies the subtree below this node as a cached descendant tree.
Status CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
}
is_caching_ = true;
return Status::OK();
}
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// transformation
Status CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
is_caching_ = false; // We a no longer in a cache subtree. clear the flag.
if (leaf_op_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node);
} else {
// If there was no leaf_op set, then this is a non-mappable scenario.
if (sampler_) {
// Grab the sampler that was saved from the leaf and plug it into the cache op
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
} else {
// We're a cache op but no sampler was saved from leaf, so create a default sampler
int64_t num_samples = 0;
int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index);
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
}
// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
}
return Status::OK();
}
// Common code for mappable leaf setup.
Status CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}
// If we are a leaf in the caching path, then save this leaf.
if (is_caching_) {
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
leaf_op_ = std::move(leaf_op);
}
return Status::OK();
}
// Common code for non mappable leaf setup.
Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}
// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf
// as save it for use by cache op in ascendant tree.
if (is_caching_) {
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
} else {
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
std::shared_ptr<Sampler> sampler_from_leaf;
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
}
return Status::OK();
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
if (is_caching_) {
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache tranform identifications
Status CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,138 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
#include <memory>
#include <string>
#include <utility>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class CacheTransformPass;
/// \class CachePass cache_pass.h
/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache
/// transformation. It works in conjunction with the CacheTransformPass
class CachePass : public NodePass {
public:
/// \brief Constructor
/// \param[in] transform_pass Raw pointer back to controlling tree pass
explicit CachePass(CacheTransformPass *transform_pass);
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
/// transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
bool is_caching_;
std::shared_ptr<DatasetOp> leaf_op_;
std::shared_ptr<Sampler> sampler_;
CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_

View File

@ -0,0 +1,108 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "dataset/engine/opt/pre/cache_pass.h"
#include "dataset/engine/opt/pre/cache_transform_pass.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/datasetops/cache_op.h"
namespace mindspore {
namespace dataset {
// constructor
CacheTransformPass::CacheTransformPass() {}
// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
// use to execute a transform.
std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this);
RETURN_IF_NOT_OK(cache_pass->Run(tree, modified));
// Then, execute the transform for each pair
for (auto cache_pair : cache_pairs_) {
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client());
}
MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
return Status::OK();
}
// Helper function to execute the cache transformation.
Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op,
std::shared_ptr<CacheClient> cache_client) {
// Get local pointers the child/parent of the cache op. It's possible that the parent is null if the cache was
// the root node. It is also possible that cache_child == leaf_op
std::shared_ptr<DatasetOp> cache_child = cache_op->child(0);
DatasetOp *cache_parent = nullptr;
cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent
// Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later.
std::shared_ptr<Sampler> leaf_sampler = leaf_op->sampler();
// Construct the merge op with defaults
std::shared_ptr<CacheMergeOp> merge_op;
CacheMergeOp::Builder merge_builder;
RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op));
RETURN_IF_NOT_OK(tree->AssociateNode(merge_op));
// Construct the cache lookup op with defaults
std::shared_ptr<CacheLookupOp> cache_lookup_op;
CacheLookupOp::Builder lookup_builder;
RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op));
RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op));
// Overwrite the old sampler in this leaf op to become the lookup op
leaf_op->SetSampler(cache_lookup_op);
// If the cache had a parent, then go into that parent to remove the cache from it's child list and then
// replace it with the merge op.
if (cache_parent != nullptr) {
RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op));
RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op));
} else {
// If we didn't have a parent, then the merge op is the root node
RETURN_IF_NOT_OK(tree->AssignRoot(merge_op));
}
// Set the cache op to no longer be a parent over it's child. This will fully disconnect the old cache op.
// We maintain a local pointer to the old child though.
RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child));
// Connect the merge op
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op)));
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child)));
// At this point, the cache op has already had it's children and parents taken away. Calling remove
// on it at this point will not do any node hookups, and instead set internal fields to invalid.
RETURN_IF_NOT_OK(cache_op->Remove());
return Status::OK();
}
// Assigns the leaf and cache operators that are involved in a cache transformation
void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<CacheOp> cache_op) {
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
#include <memory>
#include <utility>
#include <vector>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
class CacheClient;
/// \class CacheTransformPass cache_transform_pass.h
/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
/// operations
class CacheTransformPass : public TreePass {
public:
/// \brief Constructor
CacheTransformPass();
/// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
/// \param[in] leaf_op The leaf operator involved in the cache transform
/// \param[in] cache_op The cache operator involved in the cache transform
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);
private:
/// \brief Helper function to execute the cache transformation.
///
/// Input:
/// Sampler
/// |
/// LeafOp --> OtherOps --> CacheOp
///
/// Transformed:
/// Sampler --> CacheLookupOp ---------------->
/// | |
/// | MergeOp
/// | |
/// LeafOp --> OtherOps -->
///
/// \param[in] leaf_op The leaf node in the transform
/// \param[in] cache_op The cache op in the transform (will get removed)
/// \param[in] cache_client The cache client
/// \return Status The error code return
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
// The two operators that work together to establish the cache transform
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_

View File

@ -24,12 +24,28 @@ namespace dataset {
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
// Identifies the subtree below this node as a cached descendant tree.
Status RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
return Status::OK();
}
// Resets the tracking of the cache within the tree
Status RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
is_caching_ = false;
return Status::OK();
}
// Perform ShuffleOp removal check.
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
*modified = false;
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if (is_caching_) {
MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
if (removal_pass_) {
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
} else {

View File

@ -34,6 +34,18 @@ class RemovalNodes : public NodePass {
/// \param[in] removal_pass Raw pointer back to controlling tree pass
explicit RemovalNodes(RemovalPass *removal_pass);
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all

View File

@ -28,6 +28,7 @@ RemovalPass::RemovalPass() {}
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: removal pass started.";
// Create the removal node pass which can identify which nodes need to be removed.
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
@ -36,6 +37,7 @@ Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
for (auto node : removal_nodes_) {
node->Remove();
}
MS_LOG(INFO) << "Pre pass: removal pass complete.";
return Status::OK();
}

View File

@ -87,8 +87,9 @@ class Allocator {
std::shared_ptr<MemoryPool> pool_;
};
/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will
/// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator.
/// Default to std::allocator
/// be released when the object goes out of scope
/// \tparam T The type of object to be allocated
/// \tparam C Allocator. Default to std::allocator
template <typename T, typename C = std::allocator<T>>
class MemGuard {
public:
@ -168,7 +169,7 @@ class MemGuard {
private:
allocator alloc_;
std::unique_ptr<T[], std::function<void(T *)>> ptr_;
std::unique_ptr<T[]> ptr_;
size_t n_;
};
} // namespace dataset

View File

@ -98,11 +98,6 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t
} catch (std::bad_alloc &e) {
if (sm_ != nullptr) {
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
// We have an assumption 0 is not a valid key from the design of AutoIndexObj.
// Make sure it is not 0.
if (bl.storage_key == 0) {
RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected");
}
} else {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}

View File

@ -22,11 +22,11 @@
#include <stdlib.h>
#endif
#include <unistd.h>
#include "dataset/engine/cache/cache_server.h"
#include "dataset/util/circular_pool.h"
#include "dataset/util/random.h"
#include "dataset/util/task_manager.h"
#define SLOT_TASK_MGR 0
namespace mindspore {
namespace dataset {
std::unique_ptr<Services> Services::instance_ = nullptr;
@ -61,15 +61,25 @@ std::string Services::GetUniqueID() {
TaskManager &Services::getTaskMgrInstance() {
Services &sm = GetInstance();
return *(static_cast<TaskManager *>(sm.sa_[SLOT_TASK_MGR]));
return *(static_cast<TaskManager *>(sm.sa_[kSlotTaskMgr_]));
}
CacheServer &Services::getCacheServer() {
Services &sm = GetInstance();
return *(static_cast<CacheServer *>(sm.sa_[kSlotCacheMgr_]));
}
Status Services::CreateAllInstances() {
// In order, TaskMgr, BufferMgr
Status rc;
sa_[SLOT_TASK_MGR] = new (&rc, pool_) TaskManager();
sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager();
RETURN_IF_NOT_OK(rc);
rc = sa_[SLOT_TASK_MGR]->ServiceStart();
rc = sa_[kSlotTaskMgr_]->ServiceStart();
RETURN_IF_NOT_OK(rc);
// TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers
sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3);
RETURN_IF_NOT_OK(rc);
rc = sa_[kSlotCacheMgr_]->ServiceStart();
return rc;
}
@ -83,8 +93,14 @@ Services::Services() : pool_(nullptr), sa_{nullptr} {
Services::~Services() noexcept {
try {
// In reverse order
TaskManager *tm = static_cast<TaskManager *>(sa_[SLOT_TASK_MGR]);
if (tm) {
CacheServer *cs = static_cast<CacheServer *>(sa_[kSlotCacheMgr_]);
if (cs != nullptr) {
(void)cs->ServiceStop();
cs->~CacheServer();
pool_->Deallocate(cs);
}
TaskManager *tm = static_cast<TaskManager *>(sa_[kSlotTaskMgr_]);
if (tm != nullptr) {
(void)tm->ServiceStop();
tm->~TaskManager();
pool_->Deallocate(tm);

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace dataset {
class TaskManager;
class CacheServer;
class Services {
public:
static Status CreateInstance() {
@ -61,6 +61,8 @@ class Services {
static TaskManager &getTaskMgrInstance();
static CacheServer &getCacheServer();
std::shared_ptr<MemoryPool> GetServiceMemPool() { return pool_; }
#if !defined(_WIN32) && !defined(_WIN64)
@ -87,7 +89,9 @@ class Services {
// We use pointers here instead of unique_ptr because we
// want to have ultimate control on the order of
// construction and destruction.
static constexpr int kNumServices_ = 1;
static constexpr int kSlotTaskMgr_ = 0;
static constexpr int kSlotCacheMgr_ = 1;
static constexpr int kNumServices_ = 2;
Service *sa_[kNumServices_];
Services();

View File

@ -24,6 +24,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
from .engine.cache_client import DatasetCache
from .engine.serializer_deserializer import serialize, deserialize, show
from .engine.graphdata import GraphData

View File

@ -0,0 +1,49 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cache client
"""
import copy
from mindspore._c_dataengine import CacheClient
class DatasetCache:
"""
A client to interface with tensor caching service
"""
def __init__(self, session_id=None, size=None, spilling=False):
if session_id is None:
raise RuntimeError("Session generation is not implemented yet. session id required")
self.size = size if size is not None else 0
if size < 0:
raise ValueError("cache size should be 0 or positive integer value but got: size={}".format(size))
if not isinstance(spilling, bool):
raise ValueError(
"spilling argument for cache should be a boolean value but got: spilling={}".format(spilling))
self.session_id = session_id
self.spilling = spilling
self.cache_client = CacheClient(session_id, size, spilling)
def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
cls = self.__class__
new_cache = cls.__new__(cls)
memodict[id(self)] = new_cache
new_cache.session_id = copy.deepcopy(self.session_id, memodict)
new_cache.spilling = copy.deepcopy(self.spilling, memodict)
new_cache.size = copy.deepcopy(self.size, memodict)
new_cache.cache_client = self.cache_client
return new_cache

View File

@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
@ -386,7 +386,7 @@ class Dataset:
@check_map
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False):
num_parallel_workers=None, python_multiprocessing=False, cache=None):
"""
Apply each operation in operations to this dataset.
@ -427,6 +427,7 @@ class Dataset:
parallel (default=None, the value from the config will be used).
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
option could be beneficial if the python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
Returns:
MapDataset, dataset after mapping operation.
@ -541,7 +542,7 @@ class Dataset:
>>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
"""
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
python_multiprocessing)
python_multiprocessing, cache)
@check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=1):
@ -1868,13 +1869,14 @@ class MapDataset(DatasetOp):
in parallel (default=None).
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
option could be beneficial if the python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
Raises:
ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
"""
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False):
num_parallel_workers=None, python_multiprocessing=False, cache=None):
super().__init__(num_parallel_workers)
self.children.append(input_dataset)
if input_columns is not None and not isinstance(input_columns, list):
@ -1886,6 +1888,7 @@ class MapDataset(DatasetOp):
if output_columns is not None and not isinstance(output_columns, list):
output_columns = [output_columns]
self.output_columns = output_columns
self.cache = cache
self.columns_order = columns_order
if self.input_columns and self.output_columns \
@ -1904,6 +1907,7 @@ class MapDataset(DatasetOp):
args["operations"] = self.operations
args["output_columns"] = self.output_columns
args["columns_order"] = self.columns_order
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -1929,6 +1933,7 @@ class MapDataset(DatasetOp):
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
new_op.cache = copy.deepcopy(self.cache, memodict)
new_op.operations = self.operations
return new_op
@ -2346,7 +2351,7 @@ class RangeDataset(MappableDataset):
return False
def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
"""
Create sampler based on user input.
@ -2356,7 +2361,11 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
shuffle (bool): Shuffle.
num_shards (int): Number of shard for sharding.
shard_id (int): Shard ID.
non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
"""
if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
return None
if input_sampler is not None:
# If the user provided a sampler, then it doesn't matter what the other args are because
# we are being asked specifically to use the given sampler.
@ -2369,7 +2378,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
samplers.WeightedRandomSampler, samplers.Sampler)) and
(num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)):
(any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
raise ValueError(
'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle))
@ -2458,6 +2467,7 @@ class ImageFolderDatasetV2(MappableDataset):
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -2482,7 +2492,7 @@ class ImageFolderDatasetV2(MappableDataset):
@check_imagefolderdatasetv2
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, extensions=None, class_indexing=None,
decode=False, num_shards=None, shard_id=None):
decode=False, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
@ -2494,6 +2504,7 @@ class ImageFolderDatasetV2(MappableDataset):
self.decode = decode
self.num_shards = num_shards
self.shard_id = shard_id
self.cache = cache
def get_args(self):
args = super().get_args()
@ -2506,6 +2517,7 @@ class ImageFolderDatasetV2(MappableDataset):
args["decode"] = self.decode
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["cache"] = self.cache.cache_client if self.cache is not None else None
return args
def get_dataset_size(self):
@ -3251,6 +3263,7 @@ class TFRecordDataset(SourceDataset):
argument should be specified only when num_shards is also specified.
shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number
of rows of each shard may be not equal.
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
Examples:
>>> import mindspore.dataset as ds
>>> import mindspore.common.dtype as mstype
@ -3268,7 +3281,7 @@ class TFRecordDataset(SourceDataset):
@check_tfrecorddataset
def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False):
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None):
super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
@ -3280,6 +3293,7 @@ class TFRecordDataset(SourceDataset):
self.schema = schema
self.columns_list = columns_list
self.num_samples = num_samples
self.cache = cache
if schema_obj is not None and num_samples is None:
self.num_samples = schema_obj.num_rows
@ -3295,6 +3309,14 @@ class TFRecordDataset(SourceDataset):
else:
self.shuffle_level = shuffle
self.shuffle_files = True
# The TF record dataset does not directly support a sampler. It has provided sampling arguments
# (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
# the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
sampler_shuffle = self.shuffle_files
sampler = None
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
non_mappable=True)
self.shard_equal_rows = shard_equal_rows
def get_args(self):
@ -3318,6 +3340,8 @@ class TFRecordDataset(SourceDataset):
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["shard_equal_rows"] = self.shard_equal_rows
args["cache"] = self.cache.cache_client if self.cache is not None else None
args["sampler"] = self.sampler
return args
def get_dataset_size(self, estimate=False):
@ -3803,43 +3827,61 @@ class RandomDataset(SourceDataset):
A source dataset that generates random data.
Args:
num_samples (int): number of samples to generate.
total_rows (int): number of rows for the dataset to generate (default=None, number of rows is random)
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
If the schema is not provided, the random dataset generates a random schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_samples (int): number of samples to draw from the total. (default=None, which means all rows)
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
(default=None, expected order behavior shown in the table).
num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
"""
def __init__(self, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None):
@check_random_dataset
def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
cache=None, shuffle=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
schema_obj = None
if (schema is not None) and (not isinstance(schema, Schema)):
schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
self.schema = schema
self.columns_list = columns_list
if schema_obj is not None and num_samples is None:
self.num_samples = schema_obj.num_rows
elif num_samples is None:
self.num_samples = 0
sampler = None
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True)
self.num_samples = num_samples
self.cache = cache
if schema_obj is not None and total_rows is None:
self.total_rows = schema_obj.num_rows
elif total_rows is None:
self.total_rows = 0
else:
self.num_samples = num_samples
self.total_rows = total_rows
self.num_shards = num_shards
self.shard_id = shard_id
self.shuffle_level = shuffle
def get_args(self):
args = super().get_args()
if self.schema is not None:
if isinstance(self.schema, Schema):
self.schema.datasetType = 'Random'
if self.num_samples is not None:
self.schema.num_rows = self.num_samples
if self.total_rows is not None:
self.schema.num_rows = self.total_rows
args["schema_json_string"] = self.schema.to_json()
else:
args["schema_file_path"] = self.schema
args["schema"] = self.schema
if self.columns_list is not None:
args["columns_list"] = self.columns_list
if self.num_samples is not None:
args["num_samples"] = self.num_samples
args["columns_list"] = self.columns_list
args["num_samples"] = self.num_samples
args["total_rows"] = self.total_rows
args["cache"] = self.cache.cache_client if self.cache is not None else None
args["sampler"] = self.sampler
return args
def get_dataset_size(self):
@ -3849,18 +3891,29 @@ class RandomDataset(SourceDataset):
Return:
Number, number of batches.
"""
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self.num_samples
return rows_per_shard
return min(rows_from_sampler, self.num_samples)
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
return True
if self.shuffle_level is None:
return True
return self.shuffle_level or self.sampler.is_shuffled()
def is_sharded(self):
return False
if self.num_shards is not None:
return self.num_shards > 1
return self.sampler.is_sharded()
class Schema:

View File

@ -173,7 +173,9 @@ def traverse(node):
# num_samples, shard_id, num_shards, shuffle
# These arguments get moved into the sampler itself, so they are no longer needed to
# be set at the dataset level.
if 'sampler' in node_args.keys():
# TF Record is a special case because it uses both the dataset and sampler arguments
# which is not decided until later during tree preparation phase.
if node_repr['op_type'] != 'TFRecordDataset' and 'sampler' in node_args.keys():
if 'num_samples' in node_repr.keys():
node_repr['num_samples'] = None
if 'shuffle' in node_repr.keys():

View File

@ -29,10 +29,11 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
from . import datasets
from . import samplers
from . import cache_client
def check_imagefolderdatasetv2(method):
"""A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
"""A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -58,7 +59,7 @@ def check_imagefolderdatasetv2(method):
def check_mnist_cifar_dataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -81,7 +82,7 @@ def check_mnist_cifar_dataset(method):
def check_manifestdataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -108,7 +109,7 @@ def check_manifestdataset(method):
def check_tfrecorddataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -134,7 +135,7 @@ def check_tfrecorddataset(method):
def check_vocdataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(VOCDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(VOCDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -175,7 +176,7 @@ def check_vocdataset(method):
def check_cocodataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(CocoDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(CocoDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -211,7 +212,7 @@ def check_cocodataset(method):
def check_celebadataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(CelebADataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(CelebADataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -247,7 +248,7 @@ def check_celebadataset(method):
def check_minddataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(MindDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -279,7 +280,7 @@ def check_minddataset(method):
def check_generatordataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -344,6 +345,27 @@ def check_generatordataset(method):
return new_method
def check_random_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows']
nreq_param_bool = ['shuffle']
nreq_param_list = ['columns_list']
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
validate_dataset_param_value(nreq_param_list, param_dict, list)
check_sampler_shuffle_shard_options(param_dict)
return method(self, *args, **kwargs)
return new_method
def check_pad_info(key, val):
"""check the key and value pair of pad_info in batch"""
@ -506,7 +528,7 @@ def check_map(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing], _ = \
[input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \
parse_user_args(method, *args, **kwargs)
nreq_param_columns = ['input_columns', 'output_columns']
@ -516,6 +538,8 @@ def check_map(method):
if num_parallel_workers is not None:
check_num_parallel_workers(num_parallel_workers)
type_check(python_multiprocessing, (bool,), "python_multiprocessing")
if cache is not None:
type_check(cache, (cache_client.DatasetCache,), "cache")
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
if param is not None:
@ -720,7 +744,7 @@ def check_add_column(method):
def check_cluedataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -750,7 +774,7 @@ def check_cluedataset(method):
def check_textfiledataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -823,7 +847,7 @@ def check_gnn_graphdata(method):
def check_gnn_get_all_nodes(method):
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
"""A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -836,7 +860,7 @@ def check_gnn_get_all_nodes(method):
def check_gnn_get_all_edges(method):
"""A wrapper that wrap a parameter checker to the GNN `get_all_edges` function."""
"""A wrapper that wraps a parameter checker to the GNN `get_all_edges` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -849,7 +873,7 @@ def check_gnn_get_all_edges(method):
def check_gnn_get_nodes_from_edges(method):
"""A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function."""
"""A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -862,7 +886,7 @@ def check_gnn_get_nodes_from_edges(method):
def check_gnn_get_all_neighbors(method):
"""A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function."""
"""A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -877,7 +901,7 @@ def check_gnn_get_all_neighbors(method):
def check_gnn_get_sampled_neighbors(method):
"""A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function."""
"""A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -905,7 +929,7 @@ def check_gnn_get_sampled_neighbors(method):
def check_gnn_get_neg_sampled_neighbors(method):
"""A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
"""A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -921,7 +945,7 @@ def check_gnn_get_neg_sampled_neighbors(method):
def check_gnn_random_walk(method):
"""A wrapper that wrap a parameter checker to the GNN `random_walk` function."""
"""A wrapper that wraps a parameter checker to the GNN `random_walk` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -968,7 +992,7 @@ def check_aligned_list(param, param_name, member_type):
def check_gnn_get_node_feature(method):
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
"""A wrapper that wraps a parameter checker to the GNN `get_node_feature` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1012,7 +1036,7 @@ def check_gnn_get_edge_feature(method):
def check_numpyslicesdataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset)."""
"""A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):

View File

@ -39,7 +39,7 @@ def check_unique_list_of_words(words, arg_name):
def check_lookup(method):
"""A wrapper that wrap a parameter checker to the original function."""
"""A wrapper that wraps a parameter checker to the original function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -56,7 +56,7 @@ def check_lookup(method):
def check_from_file(method):
"""A wrapper that wrap a parameter checker to the original function."""
"""A wrapper that wraps a parameter checker to the original function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -74,7 +74,7 @@ def check_from_file(method):
def check_from_list(method):
"""A wrapper that wrap a parameter checker to the original function."""
"""A wrapper that wraps a parameter checker to the original function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -97,7 +97,7 @@ def check_from_list(method):
def check_from_dict(method):
"""A wrapper that wrap a parameter checker to the original function."""
"""A wrapper that wraps a parameter checker to the original function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -285,7 +285,7 @@ def check_bert_tokenizer(method):
def check_from_dataset(method):
"""A wrapper that wrap a parameter checker to the original function."""
"""A wrapper that wraps a parameter checker to the original function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -328,7 +328,7 @@ def check_from_dataset(method):
def check_ngram(method):
"""A wrapper that wrap a parameter checker to the original function."""
"""A wrapper that wraps a parameter checker to the original function."""
@wraps(method)
def new_method(self, *args, **kwargs):

View File

@ -114,7 +114,7 @@ def check_erasing_value(value):
def check_crop(method):
"""A wrapper that wrap a parameter checker to the original function(crop operation)."""
"""A wrapper that wraps a parameter checker to the original function(crop operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -127,7 +127,7 @@ def check_crop(method):
def check_resize_interpolation(method):
"""A wrapper that wrap a parameter checker to the original function(resize interpolation operation)."""
"""A wrapper that wraps a parameter checker to the original function(resize interpolation operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -142,7 +142,7 @@ def check_resize_interpolation(method):
def check_resize(method):
"""A wrapper that wrap a parameter checker to the original function(resize operation)."""
"""A wrapper that wraps a parameter checker to the original function(resize operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -155,7 +155,7 @@ def check_resize(method):
def check_random_resize_crop(method):
"""A wrapper that wrap a parameter checker to the original function(random resize crop operation)."""
"""A wrapper that wraps a parameter checker to the original function(random resize crop operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -178,7 +178,7 @@ def check_random_resize_crop(method):
def check_prob(method):
"""A wrapper that wrap a parameter checker(check the probability) to the original function."""
"""A wrapper that wraps a parameter checker(check the probability) to the original function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -192,7 +192,7 @@ def check_prob(method):
def check_normalize_c(method):
"""A wrapper that wrap a parameter checker to the original function(normalize operation written in C++)."""
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in C++)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -205,7 +205,7 @@ def check_normalize_c(method):
def check_normalize_py(method):
"""A wrapper that wrap a parameter checker to the original function(normalize operation written in Python)."""
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in Python)."""
@wraps(method)
def new_method(self, *args, **kwargs):

View File

@ -738,7 +738,7 @@ TEST_F(MindDataTestPipeline, TestProjectMap) {
EXPECT_TRUE(ds != nullptr);
// Create a Project operation on ds
std::vector<std::string> column_project = {"label"};
std::vector<std::string> column_project = {"image"};
ds = ds->Project(column_project);
EXPECT_TRUE(ds != nullptr);

View File

@ -0,0 +1,579 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "dataset/core/client.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "dataset/util/storage_container.h" // lint !e322
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/data_schema.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::dataset::CacheClient;
using mindspore::dataset::TaskGroup;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestCacheOp : public UT::DatasetOpTesting {
public:
void SetUp() override {
DatasetOpTesting::SetUp();
GlobalInit();
}
};
TEST_F(MindDataTestCacheOp, TestCacheServer) {
Status rc;
CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true
// cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
rc = myClient.CreateCache(1, true);
EXPECT_TRUE(rc.IsOk());
std::cout << myClient << std::endl;
// Create a schema using the C api's
int32_t rank = 0; // not used
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
// 2 columns. First column is an "image" 640,480,3
TensorShape c1Shape({640, 480, 3});
ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
rank, // not used
&c1Shape);
// Column 2 will just be a scalar label number
TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
testSchema->AddColumn(c1);
testSchema->AddColumn(c2);
std::unordered_map<std::string, int32_t> map;
rc = testSchema->GetColumnNameMap(&map);
EXPECT_TRUE(rc.IsOk());
// Test the CacheSchema api
rc = myClient.CacheSchema(map);
EXPECT_TRUE(rc.IsOk());
// Create a tensor, take a snapshot and restore it back, and compare.
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
t->SetItemAt<uint64_t>({0, 0}, 1);
t->SetItemAt<uint64_t>({0, 1}, 2);
t->SetItemAt<uint64_t>({0, 2}, 3);
t->SetItemAt<uint64_t>({1, 0}, 4);
t->SetItemAt<uint64_t>({1, 1}, 5);
t->SetItemAt<uint64_t>({1, 2}, 6);
std::cout << *t << std::endl;
TensorTable tbl;
TensorRow row;
row.push_back(t);
int64_t row_id;
rc = myClient.WriteRow(row, &row_id);
EXPECT_TRUE(rc.IsOk());
// Switch off build phase.
rc = myClient.BuildPhaseDone();
EXPECT_TRUE(rc.IsOk());
// Now restore from cache.
row.clear();
rc = myClient.GetRows({row_id}, &tbl);
row = tbl.front();
EXPECT_TRUE(rc.IsOk());
auto r = row.front();
std::cout << *r << std::endl;
// Compare
bool cmp = (*t == *r);
EXPECT_TRUE(cmp);
// Get back the schema and verify
std::unordered_map<std::string, int32_t> map_out;
rc = myClient.FetchSchema(&map_out);
EXPECT_TRUE(rc.IsOk());
cmp = (map_out == map);
EXPECT_TRUE(cmp);
// Test Purge and Destroy
rc = myClient.PurgeCache();
EXPECT_TRUE(rc.IsOk());
rc = myClient.DestroyCache();
EXPECT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) {
// Clear the rc of the master thread if any
(void)TaskManager::GetMasterThreadRc();
TaskGroup vg;
Status rc;
CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true
// cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated.
rc = myClient.CreateCache(1, true);
EXPECT_TRUE(rc.IsOk());
std::cout << myClient << std::endl;
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
t->SetItemAt<uint64_t>({0, 0}, 1);
t->SetItemAt<uint64_t>({0, 1}, 2);
t->SetItemAt<uint64_t>({0, 2}, 3);
t->SetItemAt<uint64_t>({1, 0}, 4);
t->SetItemAt<uint64_t>({1, 1}, 5);
t->SetItemAt<uint64_t>({1, 2}, 6);
TensorTable tbl;
TensorRow row;
row.push_back(t);
// Cache tensor row t 5000 times using 10 threads.
for (auto k = 0; k < 10; ++k) {
Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status {
TaskManager::FindMe()->Post();
for (auto i = 0; i < 500; i++) {
RETURN_IF_NOT_OK(myClient.WriteRow(row));
}
return Status::OK();
});
EXPECT_TRUE(vg_rc.IsOk());
}
ASSERT_TRUE(vg.join_all().IsOk());
ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk());
rc = myClient.BuildPhaseDone();
ASSERT_TRUE(rc.IsOk());
// Get statistics from the server.
CacheClient::ServiceStat stat{};
rc = myClient.GetStat(&stat);
ASSERT_TRUE(rc.IsOk());
std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached
<< "\n";
// Expect there are 5000 rows there.
EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1);
// Get them all back using row id and compare with tensor t.
for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) {
tbl.clear();
row.clear();
rc = myClient.GetRows({i}, &tbl);
EXPECT_TRUE(rc.IsOk());
row = tbl.front();
auto r = row.front();
bool cmp = (*t == *r);
EXPECT_TRUE(cmp);
}
rc = myClient.DestroyCache();
EXPECT_TRUE(rc.IsOk());
}
// Simple test with a repeated cache op over random data producer
//
// RepeatOp
// |
// CacheOp
// |
// RandomDataOp
//
TEST_F(MindDataTestCacheOp, TestRandomDataCache1) {
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestRandomDataCache1";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
// Create a schema using the C api's
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
// 2 columns. First column is an "image" 640,480,3
TensorShape c1Shape({640, 480, 3});
ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
rank, // not used
&c1Shape);
// Column 2 will just be a scalar label number
TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
testSchema->AddColumn(c1);
testSchema->AddColumn(c2);
// RandomDataOp
std::shared_ptr<RandomDataOp> myRandomDataOp;
rc = RandomDataOp::Builder()
.SetRowsPerBuffer(4)
.SetNumWorkers(4)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(50) // 50 samples for now
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
// CacheOp
// size of 0, spilling is true
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
std::shared_ptr<CacheOp> myCacheOp;
int64_t num_samples = 0;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
rc = CacheOp::Builder()
.SetNumWorkers(5)
.SetClient(myClient)
.SetRowsPerBuffer(4)
.SetSampler(std::move(seq_sampler))
.Build(&myCacheOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp);
EXPECT_TRUE(rc.IsOk());
// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
// Assign tree relations and root
rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
// quick check to see what tree looks like
std::ostringstream ss;
ss << *myTree; // some funny const error if I try to write directly to ms log stream
MS_LOG(INFO) << "Here's the tree:\n" << ss.str();
std::cout << *myClient << std::endl;
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
// Don't display these rows, just count them
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 200);
rc = myClient->DestroyCache();
EXPECT_TRUE(rc.IsOk());
}
//// Simple test with a repeated cache op over random data producer.
//// This one will exceed memory and require a spill.
////
//// RepeatOp
//// |
//// CacheOp
//// |
//// RandomDataOp
////
TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) {
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
// Create a schema using the C api's
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
// 2 columns. First column is an "image" 640,480,3
TensorShape c1Shape({640, 480, 3});
ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
rank, // not used
&c1Shape);
// Column 2 will just be a scalar label number
TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
testSchema->AddColumn(c1);
testSchema->AddColumn(c2);
// RandomDataOp
std::shared_ptr<RandomDataOp> myRandomDataOp;
rc = RandomDataOp::Builder()
.SetRowsPerBuffer(2)
.SetNumWorkers(4)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(10)
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
// CacheOp
int64_t num_samples = 0;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true);
std::shared_ptr<CacheOp> myCacheOp;
rc = CacheOp::Builder()
.SetNumWorkers(4)
.SetClient(myClient)
.SetRowsPerBuffer(3)
.SetSampler(std::move(seq_sampler))
.Build(&myCacheOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp);
EXPECT_TRUE(rc.IsOk());
// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
// Assign tree relations and root
rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
std::cout << *myClient << std::endl;
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
// Don't display these rows, just count them
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 40);
rc = myClient->DestroyCache();
EXPECT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
Status rc;
int64_t num_samples = 0;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
std::shared_ptr<CacheMergeOp> myMergeOp;
rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build(
&myMergeOp);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<CacheLookupOp> myLookupOp;
rc = CacheLookupOp::Builder()
.SetNumWorkers(3)
.SetOpConnectorSize(3)
.SetClient(myClient)
.SetSampler(seq_sampler)
.Build(&myLookupOp);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<ImageFolderOp> so;
ImageFolderOp::Builder builder;
builder.SetSampler(myLookupOp)
.SetOpConnectorSize(3)
.SetNumWorkers(3)
.SetRowsPerBuffer(2)
.SetExtensions({".jpg", ".JPEG"})
.SetRecursive(true)
.SetImageFolderDir(datasets_root_path_ + "/testPK/data");
rc = builder.Build(&so);
EXPECT_TRUE(rc.IsOk());
// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
auto myTree = std::make_shared<ExecutionTree>();
rc = myTree->AssociateNode(so);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myLookupOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myMergeOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myRepeatOp->AddChild(myMergeOp);
EXPECT_TRUE(rc.IsOk());
rc = myMergeOp->AddChild(myLookupOp);
EXPECT_TRUE(rc.IsOk());
rc = myMergeOp->AddChild(so);
EXPECT_TRUE(rc.IsOk());
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
if (rc.IsError()) {
std::cout << rc << std::endl;
break;
}
rowCount++;
}
ASSERT_EQ(rowCount, 176);
std::cout << "Row count : " << rowCount << std::endl;
rc = myClient->DestroyCache();
EXPECT_TRUE(rc.IsOk());
}
//// Simple test with a repeated cache op over random data producer.
//// The difference in this one is that you do not add the sampler to the cache op directly.
//// Instead, the sampler is added as part of the leaf op construction. Then, the prepare
//// phase will pull this up from the leaf and into the cache.
//// It removes the sampler from the leaf op, which doesn't make sense there anyway for
//// the RandomDataOp which doesn't support sampling without a cache.
////
//// RepeatOp
//// |
//// CacheOp
//// |
//// RandomDataOp
////
TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) {
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestCacheInheritSampler";
int64_t num_samples = 0;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
// Create a schema using the C api's
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
// 2 columns. First column is an "image" 640,480,3
TensorShape c1Shape({640, 480, 3});
ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible,
rank, // not used
&c1Shape);
// Column 2 will just be a scalar label number
TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape);
testSchema->AddColumn(c1);
testSchema->AddColumn(c2);
// RandomDataOp
std::shared_ptr<RandomDataOp> myRandomDataOp;
rc = RandomDataOp::Builder()
.SetRowsPerBuffer(2)
.SetNumWorkers(4)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(10)
.SetSampler(std::move(seq_sampler))
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
// CacheOp
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 4, true);
std::shared_ptr<CacheOp> myCacheOp;
rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp);
EXPECT_TRUE(rc.IsOk());
// RepeatOp
uint32_t numRepeats = 4;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
// Assign tree relations and root
rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk());
rc = myCacheOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
std::cout << *myClient << std::endl;
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
// Don't display these rows, just count them
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 40);
rc = myClient->DestroyCache();
EXPECT_TRUE(rc.IsOk());
}

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with mappable datasets
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
from util import save_and_check_md5
DATA_DIR = "../data/dataset/testImageNetData/train/"
GENERATE_GOLDEN = False
def test_cache_map_basic1():
"""
Test mappable leaf with cache op right over the leaf
Repeat
|
Map(decode)
|
Cache
|
ImageFolder
"""
logger.info("Test cache map basic 1")
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op)
ds1 = ds1.repeat(4)
filename = "cache_map_01_result.npz"
save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
logger.info("test_cache_map_basic1 Ended.\n")
def test_cache_map_basic2():
"""
Test mappable leaf with the cache op later in the tree above the map(decode)
Repeat
|
Cache
|
Map(decode)
|
ImageFolder
"""
logger.info("Test cache map basic 2")
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
ds1 = ds1.repeat(4)
filename = "cache_map_02_result.npz"
save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
logger.info("test_cache_map_basic2 Ended.\n")
def test_cache_map_basic3():
"""
Test a repeat under mappable cache
Cache
|
Map(decode)
|
Repeat
|
ImageFolder
"""
logger.info("Test cache basic 3")
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR)
decode_op = c_vision.Decode()
ds1 = ds1.repeat(4)
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 8
logger.info('test_cache_basic3 Ended.\n')
def test_cache_map_failure1():
"""
Test nested cache (failure)
Repeat
|
Cache
|
Map(decode)
|
Cache
|
ImageFolder
"""
logger.info("Test cache failure 1")
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
ds1 = ds1.repeat(4)
try:
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Nested cache operations is not supported!" in str(e)
assert num_iter == 0
logger.info('test_cache_failure1 Ended.\n')
if __name__ == '__main__':
test_cache_map_basic1()
test_cache_map_basic2()
test_cache_map_basic3()
test_cache_map_failure1()

View File

@ -0,0 +1,429 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with non-mappable datasets
"""
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
GENERATE_GOLDEN = False
def test_cache_nomap_basic1():
"""
A random dataset (a non mappable dataset) with a cache over it just after the leaf
"""
logger.info("Test cache nomap basic 1")
schema = ds.Schema()
schema.add_column('image', de_type=mstype.uint8,
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# create a cache. arbitrary session_id for now
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# User-created sampler here
ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
ds1 = ds1.repeat(4)
num_iter = 0
for data in ds1.create_dict_iterator():
logger.info("printing the label: {}".format(data["label"]))
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 40
logger.info("test_cache_nomap_basic1 Ended.\n")
def test_cache_nomap_basic2():
"""
A random dataset (a non mappable dataset) with a cache over it just after the leaf
"""
logger.info("Test cache nomap basic 2")
schema = ds.Schema()
schema.add_column('image', de_type=mstype.uint8,
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# create a cache. arbitrary session_id for now
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# sampler arg not given directly, however any of these args will auto-generate an appropriate sampler:
# num_samples, shuffle, num_shards, shard_id
# In this case, the presence of num_samples chooses a sampler.
ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache)
ds1 = ds1.repeat(2)
num_iter = 0
for data in ds1.create_dict_iterator():
logger.info("printing the label: {}".format(data["label"]))
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 40
logger.info("test_cache_nomap_basic2 Ended.\n")
def test_cache_nomap_basic3():
"""
A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf
Repeat
|
Map(decode)
|
Cache
|
TFReader
"""
logger.info("Test cache nomap basic 3")
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op)
ds1 = ds1.repeat(4)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12
logger.info("test_cache_nomap_basic3 Ended.\n")
def test_cache_nomap_basic4():
"""
A TF reader dataset (a non mappable dataset) with a map decode and cache after it
Since a global shuffle is used for the tf reader, it will inject a shuffle op over the tf.
But, if there's a cache later, that shuffle becomes invalid and should be removed.
Repeat
|
Cache
|
Map(decode)
|
TFReader
"""
logger.info("Test cache nomap basic 4")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# With shuffle not being set, TF defaults to a "global" shuffle when there is no cache
# in the picture. This causes a shuffle-injection over the TF. For clarify, this test will
# explicitly give the global option, even though it's the default in python.
# But, when caching is added in the ascendent tree above TF, we do global shuffling
# through the sampler over the cache, not by the shuffle op. In that case, tree prepare
# will remove the shuffle op that got injected by the initial tree creation.
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
ds1 = ds1.repeat(4)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12
logger.info("test_cache_nomap_basic4 Ended.\n")
def test_cache_nomap_basic5():
"""
A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf
Same as test 3, but this one does not have shuffle arg, causing tf to default to global
shuffle which attempts to inject a shuffle operator. However, since there is a cache
we do not need global shuffle, so the shuffle will not be built. It ends up being
identical to test basic 3, however we arrive at the same tree in different codepaths
(if there was no cache, then the shuffle IS built)
Repeat
|
Map(decode)
|
Cache
|
TFReader
"""
logger.info("Test cache nomap basic 5")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op)
ds1 = ds1.repeat(4)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12
logger.info("test_cache_nomap_basic5 Ended.\n")
def test_cache_nomap_basic6():
"""
A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf
In this one, the tf dataset will be given sharding configuration, however since a cache is
used, the tree prepare should undo the sharding configuration and instead, a distributed
sampler will be chosen with the same shard config.
Repeat
|
Map(decode)
|
Cache
|
TFReader
"""
logger.info("Test cache nomap basic 6")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# With only 3 records shard into 3, we expect only 1 record returned for this shard
# However, the sharding will be done by the sampler, not by the tf record leaf node
# In this case, it is a row-based sharding, not the file-based sharding that would happen if
# there was not any cache.
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op)
ds1 = ds1.repeat(4)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 4
logger.info("test_cache_nomap_basic6 Ended.\n")
def test_cache_nomap_basic7():
"""
A TF reader dataset (a non mappable dataset) that uses global shuffle, and is cached followed by
map.
In this one, the tf dataset with global shuffle might want to inject a shuffle op over top of the
tf reader, but since a cache is given, it will choose not to.
Repeat
|
Map(decode)
|
cache
|
TFReader
"""
logger.info("Test cache nomap basic 7")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(input_columns=["image"], operations=decode_op)
ds1 = ds1.repeat(4)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12
logger.info("test_cache_nomap_basic7 Ended.\n")
def test_cache_nomap_allowed_share1():
"""
It is allowed to share the cache between the following two trees:
Repeat Shuffle
| |
Cache Cache
| |
TFReader TFReader
"""
logger.info("Test cache nomap allowed share 1")
ds.config.set_seed(1)
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
ds1 = ds1.repeat(4)
ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
ds2 = ds2.shuffle(buffer_size=2)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
assert num_iter == 12
logger.info("Number of data in ds1: {} ".format(num_iter))
num_iter = 0
for _ in ds2.create_dict_iterator():
num_iter += 1
assert num_iter == 3
logger.info("test_cache_nomap_allowed_share1 Ended.\n")
def test_cache_nomap_allowed_share2():
"""
It is allowed to share the cache between the following two trees (with map decode):
Repeat Shuffle
| |
Cache Cache
| |
Map(decode) Map(decode)
| |
TFReader TFReader
"""
logger.info("Test cache nomap allowed share 2")
ds.config.set_seed(1)
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True)
decode_op = c_vision.Decode()
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
ds1 = ds1.repeat(4)
ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=some_cache)
ds2 = ds2.shuffle(buffer_size=2)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12
num_iter = 0
for _ in ds2.create_dict_iterator():
num_iter += 1
assert num_iter == 3
logger.info("test_cache_nomap_allowed_share2 Ended.\n")
def test_cache_nomap_allowed_share3():
"""
It is allowed to share the cache between the following two trees (different shard ids):
Repeat Repeat
| |
Cache Cache
| |
TFReader(shard_id = 0) TFReader(shard_id = 1)
"""
logger.info("Test cache nomap allowed share 3")
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
ds1 = ds1.repeat(4)
ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache)
ds2 = ds2.repeat(4)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 12
num_iter = 0
for _ in ds2.create_dict_iterator():
num_iter += 1
assert num_iter == 12
logger.info("test_cache_nomap_allowed_share3 Ended.\n")
def test_cache_nomap_disallowed_share1():
"""
It is not allowed to share the cache between the following two trees:
Cache Cache
| |
Map(decode) Map(rescale)
| |
TFReader TFReader
"""
logger.info("Test cache nomap disallowed share1")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
decode_op = c_vision.Decode()
rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
ds2 = ds2.map(input_columns=["image"], operations=rescale_op, cache=some_cache)
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == 3
try:
sum([1 for _ in ds2])
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Attempt to re-use a cache for a different tree!" in str(e)
logger.info("test_cache_nomap_disallowed_share1 Ended.\n")
if __name__ == '__main__':
test_cache_nomap_basic1()
test_cache_nomap_basic2()
test_cache_nomap_basic3()
test_cache_nomap_basic4()
test_cache_nomap_basic5()
test_cache_nomap_basic6()
test_cache_nomap_basic7()
test_cache_nomap_allowed_share1()
test_cache_nomap_allowed_share2()
test_cache_nomap_allowed_share3()
test_cache_nomap_disallowed_share1()

View File

@ -16,17 +16,16 @@ import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore import log as logger
# just a basic test with parallel random data op
def test_randomdataset_basic1():
logger.info("Test randomdataset basic")
logger.info("Test randomdataset basic 1")
schema = ds.Schema()
schema.add_column('image', de_type=mstype.uint8, shape=[2])
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# apply dataset operations
ds1 = ds.RandomDataset(schema=schema, num_samples=50, num_parallel_workers=4)
ds1 = ds.RandomDataset(schema=schema, total_rows=50, num_parallel_workers=4)
ds1 = ds1.repeat(4)
num_iter = 0
@ -36,8 +35,9 @@ def test_randomdataset_basic1():
logger.info("{} label: {}".format(num_iter, data["label"]))
num_iter += 1
logger.info("Number of data in ds1: ", num_iter)
logger.info("Number of data in ds1: {}".format(num_iter))
assert num_iter == 200
logger.info("Test randomdataset basic 1 complete")
# Another simple test
@ -49,10 +49,8 @@ def test_randomdataset_basic2():
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# Make up about 10 samples
ds1 = ds.RandomDataset(schema=schema, num_samples=10, num_parallel_workers=1)
# cache size allows for about 4 images since each image just a bit less than 1MB, after that we will have to spill
# Make up 10 rows
ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=1)
ds1 = ds1.repeat(4)
num_iter = 0
@ -62,11 +60,31 @@ def test_randomdataset_basic2():
logger.info("printing the label: {}".format(data["label"]))
num_iter += 1
logger.info("Number of data in ds1: ", num_iter)
logger.info("Number of data in ds1: {}".format(num_iter))
assert num_iter == 40
logger.info("Test randomdataset basic 2 complete")
# Another simple test
def test_randomdataset_basic3():
logger.info("Test randomdataset basic 3")
# Make up 10 samples, but here even the schema is randomly created
# The columns are named like this "c0", "c1", "c2" etc
# But, we will use a tuple iterator instead of dict iterator so the column names
# are not needed to iterate
ds1 = ds.RandomDataset(total_rows=10, num_parallel_workers=1)
ds1 = ds1.repeat(2)
num_iter = 0
for _ in ds1.create_tuple_iterator():
num_iter += 1
logger.info("Number of data in ds1: {}".format(num_iter))
assert num_iter == 20
logger.info("Test randomdataset basic 3 Complete")
if __name__ == '__main__':
test_randomdataset_basic1()
test_randomdataset_basic2()
logger.info('test_randomdataset_basic Ended.\n')
test_randomdataset_basic3()