From 983827ec5cea151274585333d735b3f94afd3cbe Mon Sep 17 00:00:00 2001 From: Lixia Chen Date: Thu, 20 Aug 2020 12:27:39 -0400 Subject: [PATCH] Rebase up to 88ded11f5971f547adb7e1baecd3330c17fc5330 --- cmake/package.cmake | 10 +- .../bindings/dataset/engine/cache/bindings.cc | 25 +- .../minddata/dataset/core/config_manager.cc | 13 +- .../minddata/dataset/core/config_manager.h | 18 + .../ccsrc/minddata/dataset/core/constants.h | 2 + .../ccsrc/minddata/dataset/core/tensor_row.h | 8 + .../dataset/engine/cache/CMakeLists.txt | 4 +- .../dataset/engine/cache/cache_admin.cc | 11 +- .../dataset/engine/cache/cache_admin_arg.cc | 185 ++- .../dataset/engine/cache/cache_admin_arg.h | 16 +- .../dataset/engine/cache/cache_arena.cc | 45 +- .../dataset/engine/cache/cache_arena.h | 4 +- .../dataset/engine/cache/cache_client.cc | 108 +- .../dataset/engine/cache/cache_client.h | 82 +- .../dataset/engine/cache/cache_common.h | 28 +- .../dataset/engine/cache/cache_grpc_client.cc | 106 +- .../dataset/engine/cache/cache_grpc_client.h | 28 +- .../dataset/engine/cache/cache_grpc_server.cc | 50 +- .../dataset/engine/cache/cache_grpc_server.h | 4 +- .../dataset/engine/cache/cache_ipc.cc | 163 +++ .../minddata/dataset/engine/cache/cache_ipc.h | 207 ++++ .../dataset/engine/cache/cache_main.cc | 216 +++- .../dataset/engine/cache/cache_request.cc | 22 + .../dataset/engine/cache/cache_request.h | 81 +- .../dataset/engine/cache/cache_server.cc | 339 +++++- .../dataset/engine/cache/cache_server.h | 129 +- .../dataset/engine/cache/cache_service.cc | 191 +-- .../dataset/engine/cache/cache_service.h | 32 +- .../dataset/engine/cache/de_tensor.fbs | 10 + .../engine/datasetops/cache_base_op.cc | 225 ++-- .../dataset/engine/datasetops/cache_base_op.h | 11 +- .../engine/datasetops/cache_lookup_op.cc | 4 +- .../engine/datasetops/cache_merge_op.cc | 47 +- .../engine/datasetops/cache_merge_op.h | 1 + .../dataset/engine/datasetops/cache_op.cc | 5 +- .../dataset/engine/datasetops/concat_op.cc | 7 + .../dataset/engine/datasetops/concat_op.h | 6 + .../dataset/engine/datasetops/dataset_op.cc | 1 + .../dataset/engine/datasetops/filter_op.cc | 6 + .../dataset/engine/datasetops/filter_op.h | 6 + .../engine/datasetops/map_op/map_op.cc | 6 + .../dataset/engine/datasetops/map_op/map_op.h | 14 +- .../dataset/engine/datasetops/parallel_op.cc | 2 +- .../dataset/engine/datasetops/pipeline_op.cc | 4 +- .../dataset/engine/datasetops/zip_op.cc | 6 + .../dataset/engine/datasetops/zip_op.h | 14 +- .../minddata/dataset/engine/execution_tree.cc | 2 + .../dataset/engine/opt/CMakeLists.txt | 1 + .../ccsrc/minddata/dataset/engine/opt/pass.cc | 142 ++- .../ccsrc/minddata/dataset/engine/opt/pass.h | 106 +- .../dataset/engine/opt/post/repeat_pass.cc | 8 +- .../engine/opt/pre/cache_error_pass.cc | 79 ++ .../dataset/engine/opt/pre/cache_error_pass.h | 76 ++ .../engine/opt/pre/cache_transform_pass.cc | 45 +- .../engine/opt/pre/epoch_injection_pass.cc | 7 - .../engine/opt/pre/epoch_injection_pass.h | 7 - .../dataset/kernels/data/random_apply_op.cc | 1 + .../dataset/kernels/data/random_choice_op.cc | 1 + .../dataset/kernels/image/random_affine_op.cc | 1 + .../kernels/image/random_color_adjust_op.cc | 1 + .../dataset/kernels/image/random_color_op.cc | 4 +- .../image/random_crop_and_resize_op.cc | 1 + .../dataset/kernels/image/random_crop_op.cc | 1 + .../kernels/image/random_horizontal_flip_op.h | 1 + .../random_horizontal_flip_with_bbox_op.h | 1 + .../kernels/image/random_posterize_op.cc | 1 + .../dataset/kernels/image/random_resize_op.h | 1 + 
.../image/random_resize_with_bbox_op.h | 1 + .../kernels/image/random_rotation_op.cc | 1 + .../image/random_select_subpolicy_op.cc | 1 + .../kernels/image/random_sharpness_op.cc | 1 + .../kernels/image/random_solarize_op.h | 5 +- .../kernels/image/random_vertical_flip_op.h | 1 + .../image/random_vertical_flip_with_bbox_op.h | 1 + .../minddata/dataset/kernels/tensor_op.h | 7 + .../ccsrc/minddata/dataset/util/allocator.h | 12 +- .../ccsrc/minddata/dataset/util/arena.cc | 9 +- mindspore/ccsrc/minddata/dataset/util/arena.h | 3 +- mindspore/ccsrc/minddata/dataset/util/btree.h | 36 +- .../minddata/dataset/util/btree_impl.tpp | 38 +- .../minddata/dataset/util/btree_iterator.tpp | 22 +- .../ccsrc/minddata/dataset/util/buddy.cc | 2 +- .../ccsrc/minddata/dataset/util/cache_pool.cc | 94 +- .../ccsrc/minddata/dataset/util/cache_pool.h | 23 +- .../minddata/dataset/util/circular_pool.cc | 13 +- .../ccsrc/minddata/dataset/util/queue_map.h | 44 +- .../ccsrc/minddata/dataset/util/semaphore.cc | 5 +- .../ccsrc/minddata/dataset/util/semaphore.h | 2 +- .../ccsrc/minddata/dataset/util/status.cc | 6 + .../ccsrc/minddata/dataset/util/status.h | 4 + .../dataset/util/storage_container.cc | 6 +- .../minddata/dataset/util/storage_manager.cc | 8 +- mindspore/dataset/engine/cache_client.py | 15 +- mindspore/dataset/engine/validators.py | 7 +- setup.py | 19 + tests/ut/cpp/dataset/btree_test.cc | 55 +- tests/ut/cpp/dataset/cache_op_test.cc | 79 +- tests/ut/cpp/dataset/memory_pool_test.cc | 36 +- tests/ut/python/cachetests/cachetest.sh | 48 + tests/ut/python/cachetests/cachetest_args.sh | 207 ++++ tests/ut/python/cachetests/cachetest_cpp.sh | 72 ++ tests/ut/python/cachetests/cachetest_lib.sh | 336 +++++ tests/ut/python/cachetests/cachetest_py.sh | 378 ++++++ tests/ut/python/conftest.py | 12 + tests/ut/python/dataset/test_cache_map.py | 1082 ++++++++++++++++- tests/ut/python/dataset/test_cache_nomap.py | 1053 +++++++++++++++- 106 files changed, 5764 insertions(+), 969 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h create mode 100755 tests/ut/python/cachetests/cachetest.sh create mode 100755 tests/ut/python/cachetests/cachetest_args.sh create mode 100755 tests/ut/python/cachetests/cachetest_cpp.sh create mode 100755 tests/ut/python/cachetests/cachetest_lib.sh create mode 100755 tests/ut/python/cachetests/cachetest_py.sh diff --git a/cmake/package.cmake b/cmake/package.cmake index 570ef377cd6..df166ad4387 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -36,6 +36,7 @@ include(CPack) set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries") set(INSTALL_PY_DIR ".") set(INSTALL_BASE_DIR ".") +set(INSTALL_BIN_DIR "bin") if (CMAKE_SYSTEM_NAME MATCHES "Windows") set(INSTALL_LIB_DIR ".") @@ -78,7 +79,14 @@ if (ENABLE_MINDDATA) DESTINATION ${INSTALL_BASE_DIR} COMPONENT mindspore ) - + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + install( + TARGETS cache_admin cache_server + OPTIONAL + DESTINATION ${INSTALL_BIN_DIR} + COMPONENT mindspore + ) + endif() file(GLOB_RECURSE OPENCV_LIB_LIST ${opencv_LIBPATH}/libopencv_core* ${opencv_LIBPATH}/libopencv_imgcodecs* diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc 
b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc index 0506c9e47d0..ca29792a615 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/engine/cache/cache_client.h" @@ -22,17 +23,19 @@ namespace dataset { PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { (void)py::class_>(*m, "CacheClient") - .def( - py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional hostname, - std::optional port, int32_t prefetch_sz) { - std::shared_ptr cc; - CacheClient::Builder builder; - builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz); - if (hostname) builder.SetHostname(hostname.value()); - if (port) builder.SetPort(port.value()); - THROW_IF_ERROR(builder.Build(&cc)); - return cc; - })) + .def(py::init([](session_id_type id, uint64_t mem_sz, bool spill, + std::optional hostname, std::optional port, + std::optional num_connections, std::optional prefetch_sz) { + std::shared_ptr cc; + CacheClient::Builder builder; + builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill); + if (hostname) builder.SetHostname(hostname.value()); + if (port) builder.SetPort(port.value()); + if (num_connections) builder.SetNumConnections(num_connections.value()); + if (prefetch_sz) builder.SetPrefetchSize(prefetch_sz.value()); + THROW_IF_ERROR(builder.Build(&cc)); + return cc; + })) .def("GetStat", [](CacheClient &cc) { CacheServiceStat stat{}; THROW_IF_ERROR(cc.GetStat(&stat)); diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc index 67adfc22a1d..44901b81199 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "mindspore/core/utils/log_adapter.h" #include "minddata/dataset/util/system_pool.h" @@ -33,7 +34,9 @@ ConfigManager::ConfigManager() monitor_sampling_interval_(kCfgMonitorSamplingInterval), callback_timout_(kCfgCallbackTimeout), cache_host_(kCfgDefaultCacheHost), - cache_port_(kCfgDefaultCachePort) { + cache_port_(kCfgDefaultCachePort), + num_connections_(kDftNumConnections), + prefetch_size_(kDftPrefetchSize) { auto env_cache_host = std::getenv("MS_CACHE_HOST"); auto env_cache_port = std::getenv("MS_CACHE_PORT"); if (env_cache_host != nullptr) { @@ -71,6 +74,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) { set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); set_cache_host(j.value("cacheHost", cache_host_)); set_cache_port(j.value("cachePort", cache_port_)); + set_num_connections(j.value("numConnections", num_connections_)); + set_prefetch_size(j.value("prefetchSize", prefetch_size_)); return Status::OK(); } @@ -120,8 +125,12 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } -void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; } +void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = std::move(cache_host); } void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; 
} + +void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; } + +void ConfigManager::set_prefetch_size(int32_t prefetch_size) { prefetch_size_ = prefetch_size; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.h b/mindspore/ccsrc/minddata/dataset/core/config_manager.h index ff35826708b..22cc38e68de 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.h +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.h @@ -97,6 +97,14 @@ class ConfigManager { // @return The port of cache server int32_t cache_port() const { return cache_port_; } + /// getter function + /// \return Number of tcp/ip connection + int32_t num_connections() const { return num_connections_; } + + /// getter function + /// \return Prefetch size + int32_t prefetch_size() const { return prefetch_size_; } + // setter function // @param rows_per_buffer - The setting to apply to the config void set_rows_per_buffer(int32_t rows_per_buffer); @@ -121,6 +129,14 @@ class ConfigManager { // @param cache_port - The port of cache server void set_cache_port(int32_t cache_port); + /// setter function + /// \param num_connections + void set_num_connections(int32_t num_connections); + + /// setter function + /// \param prefetch_size + void set_prefetch_size(int32_t prefetch_size); + uint32_t seed() const; // setter function @@ -153,6 +169,8 @@ class ConfigManager { uint32_t callback_timout_; std::string cache_host_; int32_t cache_port_; + int32_t num_connections_; + int32_t prefetch_size_; // Private helper function that takes a nlohmann json format and populates the settings // @param j - The json nlohmann json info diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h index dc966c3844b..87fb9b455ec 100644 --- a/mindspore/ccsrc/minddata/dataset/core/constants.h +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -71,6 +71,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10; constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds constexpr int32_t kCfgDefaultCachePort = 50052; constexpr char kCfgDefaultCacheHost[] = "127.0.0.1"; +constexpr int32_t kDftPrefetchSize = 20; +constexpr int32_t kDftNumConnections = 12; // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) constexpr uint8_t kCVInvalidType = 255; diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_row.h b/mindspore/ccsrc/minddata/dataset/core/tensor_row.h index 613c2560175..0b3b81183c4 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor_row.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_row.h @@ -79,6 +79,14 @@ class TensorRow { const vector_type &getRow() const { return row_; } + int64_t SizeInBytes() const { + size_t sz = 0; + for (auto &it : row_) { + sz += it->SizeInBytes(); + } + return sz; + } + // Wrapper functions to support vector operations void emplace_back(value_type t) { row_.emplace_back(t); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt index 7015d04716f..802fbf3779d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt @@ -12,7 +12,9 @@ add_library(engine-cache-client OBJECT if (ENABLE_CACHE) ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto) - target_sources(engine-cache-client PUBLIC 
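As a reading aid for the ConfigManager changes above: a minimal sketch of how the new numConnections/prefetchSize knobs would be set programmatically before a CacheClient::Builder is constructed. The GlobalContext accessor and setters are taken from the surrounding code; the header paths and the wrapper function are assumptions, so treat the snippet as illustrative rather than part of the patch.

#include <memory>
#include "minddata/dataset/core/config_manager.h"   // assumed header path
#include "minddata/dataset/core/global_context.h"   // assumed header path

void TuneCacheDefaults() {
  using mindspore::dataset::GlobalContext;
  std::shared_ptr<mindspore::dataset::ConfigManager> cfg = GlobalContext::config_manager();
  cfg->set_num_connections(12);  // matches kDftNumConnections in constants.h
  cfg->set_prefetch_size(20);    // matches kDftPrefetchSize in constants.h
  // CacheClient::Builder() reads both values back in its constructor.
}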
${CACHE_GRPC_SRCS} cache_grpc_client.cc) + target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} + cache_grpc_client.cc + cache_ipc.cc) add_library(engine-cache-server OBJECT ${CACHE_GRPC_SRCS} diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc index 18f5acd5112..92c2380981f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc @@ -37,12 +37,17 @@ int main(int argc, char **argv) { warningMsg += "WARNING:\n"; warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research"; warningMsg += " purposes at this time.\n"; - warningMsg += "This command is currently disabled. Quitting.\n"; + auto env_enable_cache = std::getenv("MS_ENABLE_CACHE"); + if (env_enable_cache == nullptr || strcmp(env_enable_cache, "TRUE") != 0) { + // temporary disable cache feature in the current release + warningMsg += "This command is currently disabled. Quitting.\n"; + std::cerr << warningMsg << std::endl; + return 0; + } + warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n"; // A warning message until the code is mature enough. std::cerr << warningMsg << std::endl; - // temporary disable cache feature in the current release - return 0; if (argc == 1) { args.Help(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc index 4075216a7e2..4fef953b2dc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc @@ -19,9 +19,11 @@ #include #include #include +#include #include #include #include +#include #include "minddata/dataset/engine/cache/cache_request.h" #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/util/path.h" @@ -39,6 +41,7 @@ CacheAdminArgHandler::CacheAdminArgHandler() num_workers_(kDefaultNumWorkers), shm_mem_sz_(kDefaultSharedMemorySizeInGB), log_level_(kDefaultLogLevel), + memory_cap_ratio_(kMemoryCapRatio), hostname_(kCfgDefaultCacheHost), spill_dir_(kDefaultSpillDir), command_id_(CommandId::kCmdUnknown) { @@ -62,6 +65,9 @@ CacheAdminArgHandler::CacheAdminArgHandler() arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize; arg_map_["-l"] = ArgValue::kArgLogLevel; arg_map_["--minloglevel"] = ArgValue::kArgLogLevel; + arg_map_["-r"] = ArgValue::kArgMemoryCapRatio; + arg_map_["--memory_cap_ratio"] = ArgValue::kArgMemoryCapRatio; + arg_map_["--list_sessions"] = ArgValue::kArgListSessions; // Initialize argument tracker with false values for (int16_t i = 0; i < static_cast(ArgValue::kArgNumArgs); ++i) { ArgValue currAV = static_cast(i); @@ -69,6 +75,8 @@ CacheAdminArgHandler::CacheAdminArgHandler() } } +CacheAdminArgHandler::~CacheAdminArgHandler() = default; + Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, CommandId command_id) { // Detect if the user tried to provide this argument more than once @@ -102,7 +110,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std return Status(StatusCode::kSyntaxError, err_msg); } - // Now, attempt to convert the value into it's string format for output + // Now, attempt to convert the value into it's numeric format for output try { *out_arg = std::stoul(value_as_string); } 
catch (const std::exception &e) { @@ -140,7 +148,13 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, // If there is no argument to get, such as the --start command, then out_arg will be a nullptr. if (out_arg != nullptr) { // Fetch the argument from the arg stream into a string - *arg_stream >> *out_arg; + if (arg_stream->rdbuf()->in_avail() != 0) { + *arg_stream >> *out_arg; + } else { + std::string err_msg = option + " option requires an argument field. Syntax: " + option + " "; + return Status(StatusCode::kSyntaxError, err_msg); + } + if (out_arg->empty()) { std::string err_msg = option + " option requires an argument field. Syntax: " + option + " "; return Status(StatusCode::kSyntaxError, err_msg); @@ -150,12 +164,62 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, return Status::OK(); } +Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream, + CommandId command_id) { + // Detect if the user tried to provide this argument more than once + ArgValue selected_arg = arg_map_[option]; + if (used_args_[selected_arg]) { + std::string err_msg = "The " + option + " argument was given more than once."; + return Status(StatusCode::kSyntaxError, err_msg); + } + + // Flag that this arg is used now + used_args_[selected_arg] = true; + + // Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument. + // Other options are actual commands, for example "--start". + // If this option is also a command, make sure there has not been multiple commands given before assigning it. + if (command_id != CommandId::kCmdUnknown) { + if (command_id_ != CommandId::kCmdUnknown) { + std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option; + return Status(StatusCode::kSyntaxError, err_msg); + } else { + command_id_ = command_id; + } + } + + std::string value_as_string; + + // Fetch the argument from the arg stream into a string + *arg_stream >> value_as_string; + if (value_as_string.empty()) { + std::string err_msg = option + " option requires an argument field. Syntax: " + option + " "; + return Status(StatusCode::kSyntaxError, err_msg); + } + + // Now, attempt to convert the value into it's string format for output + try { + *out_arg = std::stof(value_as_string, nullptr); + } catch (const std::exception &e) { + std::string err_msg = "Invalid numeric value: " + value_as_string; + return Status(StatusCode::kSyntaxError, err_msg); + } + + return Status::OK(); +} + Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) { std::string tok; while (*arg_stream >> tok) { switch (arg_map_[tok]) { case ArgValue::kArgHost: { RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream)); + // Temporary sanity check. We only support localhost for now + if (hostname_ != std::string(kCfgDefaultCacheHost)) { + std::string err_msg = + "Invalid host interface: " + hostname_ + ". 
Current limitation, only 127.0.0.1 can be used."; + return Status(StatusCode::kSyntaxError, err_msg); + } break; } case ArgValue::kArgPort: { @@ -203,6 +267,14 @@ Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) { RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream)); break; } + case ArgValue::kArgMemoryCapRatio: { + RETURN_IF_NOT_OK(AssignArg(tok, &memory_cap_ratio_, arg_stream)); + break; + } + case ArgValue::kArgListSessions: { + RETURN_IF_NOT_OK(AssignArg(tok, static_cast(nullptr), arg_stream, CommandId::kCmdListSessions)); + break; + } default: { // Save space delimited trailing arguments trailing_args_ += (" " + tok); @@ -232,9 +304,12 @@ Status CacheAdminArgHandler::Validate() { } // Additional checks here - if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be positive value."); + if (num_workers_ < 1 || num_workers_ > 100) + return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100."); if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3)."); - // port range check? + if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) + return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1"); + if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kSyntaxError, "Port must be in range (1025..65535)."); return Status::OK(); } @@ -245,12 +320,9 @@ Status CacheAdminArgHandler::RunCommand() { Help(); break; } - case CommandId::kCmdStart: { - RETURN_IF_NOT_OK(StartServer()); - break; - } + case CommandId::kCmdStart: case CommandId::kCmdStop: { - RETURN_IF_NOT_OK(StopServer()); + RETURN_IF_NOT_OK(StartStopServer(command_id_)); break; } case CommandId::kCmdGenerateSession: { @@ -259,7 +331,7 @@ Status CacheAdminArgHandler::RunCommand() { auto rq = std::make_shared(); RETURN_IF_NOT_OK(comm.HandleRequest(rq)); RETURN_IF_NOT_OK(rq->Wait()); - std::cout << rq->GetSessionId() << std::endl; + std::cout << "Session: " << rq->GetSessionId() << std::endl; break; } case CommandId::kCmdDestroySession: { @@ -273,6 +345,39 @@ Status CacheAdminArgHandler::RunCommand() { std::cout << "Drop session successful" << std::endl; break; } + case CommandId::kCmdListSessions: { + CacheClientGreeter comm(hostname_, port_, 1); + RETURN_IF_NOT_OK(comm.ServiceStart()); + auto rq = std::make_shared(); + RETURN_IF_NOT_OK(comm.HandleRequest(rq)); + RETURN_IF_NOT_OK(rq->Wait()); + std::vector session_info = rq->GetSessionCacheInfo(); + if (!session_info.empty()) { + std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached" + << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl; + for (auto curr_session : session_info) { + std::string cache_id; + std::string stat_mem_cached; + std::string stat_disk_cached; + std::string stat_avg_cached; + int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF); + cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc); + stat_mem_cached = + (curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached); + stat_disk_cached = + (curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached); + stat_avg_cached = + (curr_session.stats.avg_cache_sz == 0) ? 
"n/a" : std::to_string(curr_session.stats.avg_cache_sz); + + std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12) + << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached + << std::endl; + } + } else { + std::cout << "No active sessions." << std::endl; + } + break; + } default: { RETURN_STATUS_UNEXPECTED("Invalid cache admin command id."); break; @@ -282,7 +387,7 @@ Status CacheAdminArgHandler::RunCommand() { return Status::OK(); } -Status CacheAdminArgHandler::StartServer() { +Status CacheAdminArgHandler::StartStopServer(CommandId command_id) { // There currently does not exist any "install path" or method to identify which path the installed binaries will // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the // cache_admin binary (this process). @@ -324,7 +429,10 @@ Status CacheAdminArgHandler::StartServer() { close(fd[1]); dup2(fd[0], 0); close(fd[0]); - wait(nullptr); + int status; + if (waitpid(pid, &status, 0) == -1) { + RETURN_STATUS_UNEXPECTED("waitpid fails. errno = " + std::to_string(errno)); + } std::string msg; const int32_t buf_sz = 1024; msg.resize(buf_sz); @@ -335,6 +443,13 @@ Status CacheAdminArgHandler::StartServer() { } msg.resize(n); std::cout << msg << std::endl; + if (WIFEXITED(status)) { + auto exit_status = WEXITSTATUS(status); + if (exit_status) { + std::string errMsg = "Child exit status " + std::to_string(exit_status); + return Status(StatusCode::kUnexpectedError, errMsg); + } + } return Status::OK(); } else { // Child here ... @@ -350,19 +465,29 @@ Status CacheAdminArgHandler::StartServer() { std::string shared_memory_string = std::to_string(shm_mem_sz_); std::string minloglevel_string = std::to_string(log_level_); std::string daemonize_string = "true"; + std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_); - char *argv[8]; - argv[0] = cache_server_binary.data(); // First arg is usually the binary name - argv[1] = spill_dir_.data(); - argv[2] = workers_string.data(); - argv[3] = port_string.data(); - argv[4] = shared_memory_string.data(); - argv[5] = minloglevel_string.data(); - argv[6] = daemonize_string.data(); - argv[7] = nullptr; + char *argv[9]; + if (command_id == CommandId::kCmdStart) { + argv[0] = cache_server_binary.data(); + argv[1] = spill_dir_.data(); + argv[2] = workers_string.data(); + argv[3] = port_string.data(); + argv[4] = shared_memory_string.data(); + argv[5] = minloglevel_string.data(); + argv[6] = daemonize_string.data(); + argv[7] = memory_cap_ratio_string.data(); + argv[8] = nullptr; + } else { + // We are doing a --stop. Change the name to '-' and we also need the port number. + // The rest we don't need. + argv[0] = std::string("-").data(); + argv[1] = port_string.data(); + argv[2] = nullptr; + } // Now exec the binary - execv(argv[0], argv); + execv(cache_server_binary.data(), argv); // If the exec was successful, this line will never be reached due to process image being replaced. // ..unless exec failed. std::string err_msg = "Failed to exec cache server: " + cache_server_binary; @@ -371,16 +496,6 @@ Status CacheAdminArgHandler::StartServer() { } } -Status CacheAdminArgHandler::StopServer() { - CacheClientGreeter comm(hostname_, port_, 1); - RETURN_IF_NOT_OK(comm.ServiceStart()); - auto rq = std::make_shared(); - RETURN_IF_NOT_OK(comm.HandleRequest(rq)); - // We will ignore the rc because if the shutdown is successful, the server will not reply back. 
- (void)rq->Wait(); - return Status::OK(); -} - void CacheAdminArgHandler::Help() { std::cerr << "Syntax:\n"; std::cerr << " cache_admin [--start | --stop]\n"; @@ -390,8 +505,12 @@ void CacheAdminArgHandler::Help() { std::cerr << " [ [-d | --destroy_session] ]\n"; std::cerr << " [ [-w | --workers] ]\n"; std::cerr << " [ [-s | --spilldir] ]\n"; - std::cerr << " [ [-m | --shared_memory_size] ]\n"; std::cerr << " [ [-l | --minloglevel] ]\n"; + std::cerr << " [ --list_sessions ]\n"; + // Do not expose these option to the user via help or documentation, but the options do exist to aid with + // development and tuning. + // std::cerr << " [ [-m | --shared_memory_size] ]\n"; + // std::cerr << " [ [-r | --memory_cap_ratio] ]\n"; std::cerr << " [--help]" << std::endl; } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h index 587ea15eb08..5a78ebf0c7e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h @@ -32,6 +32,7 @@ class CacheAdminArgHandler { static constexpr int32_t kDefaultNumWorkers = 32; static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; static constexpr int32_t kDefaultLogLevel = 1; + static constexpr float kMemoryCapRatio = 0.8; static const char kServerBinary[]; static const char kDefaultSpillDir[]; @@ -42,12 +43,13 @@ class CacheAdminArgHandler { kCmdStop = 2, kCmdGenerateSession = 3, kCmdDestroySession = 4, + kCmdListSessions = 5, kCmdUnknown = 32767 }; CacheAdminArgHandler(); - ~CacheAdminArgHandler() = default; + virtual ~CacheAdminArgHandler(); Status ParseArgStream(std::stringstream *arg_stream); @@ -70,12 +72,12 @@ class CacheAdminArgHandler { kArgNumWorkers = 9, kArgSharedMemorySize = 10, kArgLogLevel = 11, - kArgNumArgs = 12 // Must be the last position to provide a count + kArgMemoryCapRatio = 12, + kArgListSessions = 13, + kArgNumArgs = 14 // Must be the last position to provide a count }; - Status StartServer(); - - Status StopServer(); + Status StartStopServer(CommandId); Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, CommandId command_id = CommandId::kCmdUnknown); @@ -83,6 +85,9 @@ class CacheAdminArgHandler { Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream, CommandId command_id = CommandId::kCmdUnknown); + Status AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream, + CommandId command_id = CommandId::kCmdUnknown); + Status Validate(); CommandId command_id_; @@ -90,6 +95,7 @@ class CacheAdminArgHandler { int32_t num_workers_; int32_t shm_mem_sz_; int32_t log_level_; + float memory_cap_ratio_; session_id_type session_id_; std::string hostname_; std::string spill_dir_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc index 76bc0b4ecd8..a40362031bf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc @@ -17,27 +17,19 @@ #include "minddata/dataset/util/path.h" namespace mindspore { namespace dataset { -CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) - : ptr_(nullptr), val_in_GB_(val_in_GB), port_(port), shmid_(-1) {} +CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) { + // We create the 
shared memory and we will destroy it. All other client just detach only. + shm_.RemoveResourcesOnExit(); +} CachedSharedMemoryArena::~CachedSharedMemoryArena() { -#if CACHE_LOCAL_CLIENT - if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast(-1)) { - shmdt(this->ptr_); - } - this->ptr_ = nullptr; - if (shmid_ != -1) { - shmctl(shmid_, IPC_RMID, nullptr); - // Also remove the path we use to generate ftok. - Path p(PortToUnixSocketPath(port_)); - (void)p.Remove(); - } -#endif + // Also remove the path we use to generate ftok. + Path p(PortToUnixSocketPath(port_)); + (void)p.Remove(); } Status CachedSharedMemoryArena::CreateArena(std::unique_ptr *out, int32_t port, size_t val_in_GB) { RETURN_UNEXPECTED_IF_NULL(out); -#if CACHE_LOCAL_CLIENT auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB); if (ba == nullptr) { return Status(StatusCode::kOutOfMemory); @@ -46,26 +38,13 @@ Status CachedSharedMemoryArena::CreateArena(std::unique_ptrshm_.SetPublicKey(shm_key); // Value is in GB. Convert into bytes. int64_t sz = val_in_GB * 1073741824L; - ba->shmid_ = shmget(shm_key, sz, IPC_CREAT | IPC_EXCL | access_mode); - if (ba->shmid_) { - ba->ptr_ = shmat(ba->shmid_, nullptr, 0); - if (ba->ptr_ == reinterpret_cast(-1)) { - RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); - } - ba->impl_ = std::make_unique(ba->ptr_, sz); - } else { - RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno)); - } -#endif + RETURN_IF_NOT_OK(ba->shm_.Create(sz)); + ba->impl_ = std::make_unique(ba->shm_.SharedMemoryBaseAddr(), sz); return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h index d0f41588da0..61e430aa1fc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h @@ -21,6 +21,7 @@ #include #include "minddata/dataset/util/arena.h" #include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/engine/cache/cache_ipc.h" namespace mindspore { namespace dataset { /// This is a derived class of Arena but resides in shared memory @@ -73,10 +74,9 @@ class CachedSharedMemoryArena : public MemoryPool { private: mutable std::mutex mux_; - void *ptr_; int32_t val_in_GB_; int32_t port_; - int shmid_; + SharedMemory shm_; std::unique_ptr impl_; /// Private constructor. Not to be called directly. CachedSharedMemoryArena(int32_t port, size_t val_in_GB); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc index b237a0f294f..077304c1c9c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc @@ -24,26 +24,26 @@ namespace mindspore { namespace dataset { CacheClient::Builder::Builder() - : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) { + : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_connections_(0), prefetch_size_(0) { std::shared_ptr cfg = GlobalContext::config_manager(); hostname_ = cfg->cache_host(); port_ = cfg->cache_port(); - num_workers_ = cfg->num_parallel_workers(); - prefetch_size_ = 20; // rows_per_buf is too small (1 by default). 
+ num_connections_ = cfg->num_connections(); // number of async tcp/ip connections + prefetch_size_ = cfg->prefetch_size(); // prefetch size } Status CacheClient::Builder::Build(std::shared_ptr *out) { RETURN_UNEXPECTED_IF_NULL(out); RETURN_IF_NOT_OK(SanityCheck()); - *out = - std::make_shared(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_); + *out = std::make_shared(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_connections_, + prefetch_size_); return Status::OK(); } Status CacheClient::Builder::SanityCheck() { CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); - CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive"); CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive"); @@ -55,26 +55,32 @@ Status CacheClient::Builder::SanityCheck() { // Constructor CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, - int32_t port, int32_t num_workers, int32_t prefetch_size) + int32_t port, int32_t num_connections, int32_t prefetch_size) : server_connection_id_(0), cache_mem_sz_(cache_mem_sz), spill_(spill), local_bypass_(false), hostname_(std::move(hostname)), port_(port), - num_workers_(num_workers), - prefetch_size_(prefetch_size) { + num_connections_(num_connections), + prefetch_size_(prefetch_size), + fetch_all_keys_(true) { cinfo_.set_session_id(session_id); - comm_ = std::make_shared(hostname_, port_, num_workers_); + comm_ = std::make_shared(hostname_, port_, num_connections_); +} + +CacheClient::~CacheClient() { + cache_miss_keys_wp_.Set(); + (void)comm_->ServiceStop(); } // print method for display cache details void CacheClient::Print(std::ostream &out) const { out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc() - << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz() - << "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname() - << "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers() - << "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha + << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << GetCacheMemSz() + << "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << GetHostname() + << "\n Port: " << GetPort() << "\n Number of rpc workers: " << GetNumConnections() + << "\n Prefetch size: " << GetPrefetchSize() << "\n Local client support: " << std::boolalpha << SupportLocalClient(); } @@ -199,14 +205,6 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { return Status::OK(); } -Status CacheClient::PurgeCache() { - UniqueLock lck(&mux_); - auto rq = std::make_shared(server_connection_id_); - RETURN_IF_NOT_OK(PushRequest(rq)); - RETURN_IF_NOT_OK(rq->Wait()); - return Status::OK(); -} - Status CacheClient::DestroyCache() { UniqueLock lck(&mux_); auto rq = std::make_shared(server_connection_id_); @@ -253,5 +251,71 @@ Status CacheClient::BuildPhaseDone() const { } Status CacheClient::PushRequest(std::shared_ptr rq) const { return 
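For reference, the renamed builder knobs above would be driven as in the sketch below. Every setter named here appears elsewhere in this patch; only the wrapper function and the chosen values are made up for illustration.

#include <memory>
#include "minddata/dataset/engine/cache/cache_client.h"

mindspore::dataset::Status MakeClient(mindspore::dataset::session_id_type id,
                                      std::shared_ptr<mindspore::dataset::CacheClient> *out) {
  mindspore::dataset::CacheClient::Builder builder;
  builder.SetSessionId(id)
      .SetCacheMemSz(0)       // 0 means unlimited, per SanityCheck above
      .SetSpill(true)
      .SetNumConnections(12)  // renamed from SetNumWorkers in this patch
      .SetPrefetchSize(20);
  return builder.Build(out);
}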
comm_->HandleRequest(std::move(rq)); } + +void CacheClient::ServerRunningOutOfResources() { + bool expected = true; + if (fetch_all_keys_.compare_exchange_strong(expected, false)) { + Status rc; + // Server runs out of memory or disk space to cache any more rows. + // First of all, we will turn off the locking. + auto toggle_write_mode_rq = std::make_shared(server_connection_id_, false); + rc = PushRequest(toggle_write_mode_rq); + if (rc.IsError()) { + return; + } + // Wait until we can toggle the state of the server to non-locking + rc = toggle_write_mode_rq->Wait(); + if (rc.IsError()) { + return; + } + // Now we get a list of all the keys not cached at the server so + // we can filter out at the prefetch level. + auto cache_miss_rq = std::make_shared(server_connection_id_); + rc = PushRequest(cache_miss_rq); + if (rc.IsError()) { + return; + } + rc = cache_miss_rq->Wait(); + if (rc.IsError()) { + return; + } + // We will get back a vector of row id between [min,max] that are absent in the server. + auto &row_id_buf = cache_miss_rq->reply_.result(); + auto p = flatbuffers::GetRoot(row_id_buf.data()); + std::vector row_ids; + auto sz = p->row_id()->size(); + row_ids.reserve(sz); + for (auto i = 0; i < sz; ++i) { + row_ids.push_back(p->row_id()->Get(i)); + } + cache_miss_keys_ = std::make_unique(row_ids); + // We are all set. + cache_miss_keys_wp_.Set(); + } +} + +CacheClient::CacheMissKeys::CacheMissKeys(const std::vector &v) { + auto it = v.begin(); + min_ = *it; + ++it; + max_ = *it; + ++it; + while (it != v.end()) { + gap_.insert(*it); + ++it; + } + MS_LOG(WARNING) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size(); +} + +bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) { + if (key > max_ || key < min_) { + return true; + } else if (key == min_ || key == max_) { + return false; + } else { + auto it = gap_.find(key); + return it != gap_.end(); + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h index 7fc7a47816d..eec0ed7dfd8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h @@ -16,8 +16,13 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ +#include #include +#include #include +#include +#include +#include #include #include #include @@ -31,6 +36,8 @@ #endif #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/cond_var.h" +#include "minddata/dataset/util/queue_map.h" namespace mindspore { namespace dataset { @@ -89,10 +96,10 @@ class CacheClient { } /// Setter function to set number of async rpc workers - /// \param num_workers + /// \param num_connections /// \return Builder object itself - Builder &SetNumWorkers(int32_t num_workers) { - num_workers_ = num_workers; + Builder &SetNumConnections(int32_t num_connections) { + num_connections_ = num_connections; return *this; } @@ -105,13 +112,13 @@ class CacheClient { } /// Getter functions - session_id_type getSessionId() const { return session_id_; } - uint64_t getCacheMemSz() const { return cache_mem_sz_; } + session_id_type GetSessionId() const { return session_id_; } + uint64_t GetCacheMemSz() const { return cache_mem_sz_; } bool isSpill() const { return spill_; } const std::string &getHostname() const 
{ return hostname_; } - int32_t getPort() const { return port_; } - int32_t getNumWorkers() const { return num_workers_; } - int32_t getPrefetchSize() const { return prefetch_size_; } + int32_t GetPort() const { return port_; } + int32_t GetNumConnections() const { return num_connections_; } + int32_t GetPrefetchSize() const { return prefetch_size_; } Status SanityCheck(); @@ -123,7 +130,7 @@ class CacheClient { bool spill_; std::string hostname_; int32_t port_; - int32_t num_workers_; + int32_t num_connections_; int32_t prefetch_size_; }; @@ -132,10 +139,10 @@ class CacheClient { /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited /// \param spill Spill to disk if out of memory CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port, - int32_t num_workers, int32_t prefetch_size); + int32_t num_connections, int32_t prefetch_size); /// \brief Destructor - ~CacheClient() { (void)comm_->ServiceStop(); } + ~CacheClient(); /// \brief Send a TensorRow to the cache server /// \param[in] row @@ -161,10 +168,6 @@ class CacheClient { /// \return Status object Status CreateCache(uint32_t tree_crc, bool generate_id); - /// \brief Purge a cache. Cache can be reused after reset. - /// \return Status object - Status PurgeCache(); - /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused. /// \return Status object Status DestroyCache(); @@ -218,12 +221,31 @@ class CacheClient { /// Getter functions session_id_type session_id() const { return cinfo_.session_id(); } - uint64_t getCacheMemSz() const { return cache_mem_sz_; } + uint64_t GetCacheMemSz() const { return cache_mem_sz_; } bool isSpill() const { return spill_; } - const std::string &getHostname() const { return hostname_; } - int32_t getPort() const { return port_; } - int32_t getNumWorkers() const { return num_workers_; } - int32_t getPrefetchSize() const { return prefetch_size_; } + const std::string &GetHostname() const { return hostname_; } + int32_t GetPort() const { return port_; } + int32_t GetNumConnections() const { return num_connections_; } + int32_t GetPrefetchSize() const { return prefetch_size_; } + + /// MergeOp will notify us when the server can't cache any more rows. + /// We will stop any attempt to fetch any rows that are most likely + /// not present at the server. + void ServerRunningOutOfResources(); + + /// \brief Check if a row is 100% cache miss at the server by checking the local information + /// \param key row id to be test + /// \return true if not at the server + bool KeyIsCacheMiss(row_id_type key) { + if (cache_miss_keys_) { + // Make sure it is fully built even though the pointer is not null + Status rc = cache_miss_keys_wp_.Wait(); + if (rc.IsOk()) { + return cache_miss_keys_->KeyIsCacheMiss(key); + } + } + return false; + } private: mutable RWLock mux_; @@ -240,9 +262,27 @@ class CacheClient { bool local_bypass_; std::string hostname_; int32_t port_; - int32_t num_workers_; + int32_t num_connections_; int32_t prefetch_size_; mutable std::shared_ptr comm_; + std::atomic fetch_all_keys_; + WaitPost cache_miss_keys_wp_; + /// A structure shared by all the prefetchers to know what keys are missing at the server. + class CacheMissKeys { + public: + explicit CacheMissKeys(const std::vector &v); + ~CacheMissKeys() = default; + /// This checks if a key is missing. 
+ /// \param key + /// \return true if definitely a key miss + bool KeyIsCacheMiss(row_id_type key); + + private: + row_id_type min_; + row_id_type max_; + std::set gap_; + }; + std::unique_ptr cache_miss_keys_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h index 947cd83fdad..22761d099f7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h @@ -25,13 +25,6 @@ #define CACHE_LOCAL_CLIENT 1 #endif -#ifdef CACHE_LOCAL_CLIENT -#include -#include -#include -#else -typedef int key_t; -#endif #ifdef ENABLE_CACHE #include #endif @@ -54,6 +47,8 @@ constexpr static uint32_t kLocalClientSupport = 1; /// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is /// inline in the protobuf. This also implies kLocalClientSupport is also true. constexpr static uint32_t kDataIsInSharedMemory = 2; +/// \brief Size of each message used in message queue. +constexpr static int32_t kSharedMessageSize = 2048; /// \brief Convert a Status object into a protobuf /// \param rc[in] Status object @@ -62,29 +57,10 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) { reply->set_rc(static_cast(rc.get_code())); reply->set_msg(rc.ToString()); } - /// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number /// \param port /// \return unix socket url inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); } - -/// \brief Generate a shared memory key using the tcp/ip port. -/// \note It must be called after the cache server generates the unix socket or ftok will fail. -/// \note Caller must check the return value. -1 means ftok failed. -/// \param[in] port -/// \param[out] err. If not null and ftok fails, this will contain the value of errno -/// \return key -inline key_t PortToFtok(int port, int *err) { - key_t shmkey = -1; -#ifdef CACHE_LOCAL_CLIENT - const std::string unix_path = PortToUnixSocketPath(port); - shmkey = ftok(unix_path.data(), 'a'); - if (err != nullptr && shmkey == (key_t)-1) { - *err = errno; - } -#endif - return shmkey; -} } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc index 33151201eac..2f5cbe6d632 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc @@ -17,34 +17,10 @@ #include namespace mindspore { namespace dataset { -Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, - std::unique_ptr &&tag) { - // If there is anything extra we need to do before we send. - RETURN_IF_NOT_OK(tag->base_rq_->Prepare()); - // One minute timeout - auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); - tag->ctx_.set_deadline(deadline); - tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq); - tag->rpc_->StartCall(); - // Last step is we release the ownership and transfer it to the completion queue. 
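To make the miss-key bookkeeping above easier to follow: the server's reply is interpreted as [min, max, gap...], and a key is a definite miss when it falls outside that range or inside the gap set. A standalone sketch of the same logic follows; the class and member names are illustrative, and the input vector is assumed to hold at least the two end points.

#include <cstdint>
#include <set>
#include <vector>

class MissKeys {
 public:
  explicit MissKeys(const std::vector<int64_t> &v)
      : min_(v[0]), max_(v[1]), gap_(v.begin() + 2, v.end()) {}

  bool IsMiss(int64_t key) const {
    if (key < min_ || key > max_) return true;     // outside the fetched range: never cached
    if (key == min_ || key == max_) return false;  // end points are known to be cached
    return gap_.count(key) != 0;                   // holes inside the range
  }

 private:
  int64_t min_;
  int64_t max_;
  std::set<int64_t> gap_;
};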
- // The memory will be released by WorkerEntry or by the destructor when we drain the queue - auto ccReqTag = tag.release(); - ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, - ccReqTag); // inject this object into the completion queue - return Status::OK(); -} +CacheClientGreeter::~CacheClientGreeter() { (void)ServiceStop(); } -CacheClientGreeter::~CacheClientGreeter() { - (void)ServiceStop(); - // Detach from shared memory if any - if (shmat_addr_ != nullptr) { - shmdt(shmat_addr_); - shmat_addr_ = nullptr; - } -} - -CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) - : num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) { +CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections) + : num_connections_(num_connections), request_cnt_(0) { grpc::ChannelArguments args; // We need to bump up the message size to unlimited. The default receiving // message limit is 4MB which is not big enough. @@ -68,21 +44,11 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) { *local_bypass = false; #if CACHE_LOCAL_CLIENT - int err; - shm_key_ = PortToFtok(port, &err); - if (shm_key_ == (key_t)-1) { - std::string errMsg = "Ftok failed with errno " + std::to_string(err); - RETURN_STATUS_UNEXPECTED(errMsg); - } + SharedMemory::shm_key_t shm_key; + RETURN_IF_NOT_OK(PortToFtok(port, &shm_key)); // Attach to the shared memory - shm_id_ = shmget(shm_key_, 0, 0); - if (shm_id_ == -1) { - RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno)); - } - shmat_addr_ = shmat(shm_id_, nullptr, 0); - if (shmat_addr_ == reinterpret_cast(-1)) { - RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); - } + mem_.SetPublicKey(shm_key); + RETURN_IF_NOT_OK(mem_.Attach()); *local_bypass = true; #endif return Status::OK(); @@ -90,7 +56,7 @@ Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass Status CacheClientGreeter::DoServiceStart() { RETURN_IF_NOT_OK(vg_.ServiceStart()); - RETURN_IF_NOT_OK(DispatchWorkers(num_workers_)); + RETURN_IF_NOT_OK(DispatchWorkers(num_connections_)); return Status::OK(); } @@ -100,19 +66,40 @@ Status CacheClientGreeter::DoServiceStop() { // Shutdown the TaskGroup. vg_.interrupt_all(); vg_.join_all(Task::WaitFlag::kNonBlocking); - // Drain the queue - bool success; - void *tag; - while (cq_.Next(&tag, &success)) { - auto r = reinterpret_cast(tag); - delete r; + // Drain the queue. We know how many requests we send out + while (!req_.empty()) { + bool success; + void *tag; + while (cq_.Next(&tag, &success)) { + auto r = reinterpret_cast(tag); + req_.erase(r->seqNo_); + } } return Status::OK(); } Status CacheClientGreeter::HandleRequest(std::shared_ptr rq) { - auto tag = std::make_unique(std::move(rq)); - return tag->MakeCall(stub_.get(), &cq_, std::move(tag)); + // If there is anything extra we need to do before we send. + RETURN_IF_NOT_OK(rq->Prepare()); + auto seqNo = request_cnt_.fetch_add(1); + auto tag = std::make_unique(std::move(rq), seqNo); + // One minute timeout + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); + tag->ctx_.set_deadline(deadline); + tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_); + tag->rpc_->StartCall(); + auto ccReqTag = tag.get(); + // Insert it into the map. 
+ { + std::unique_lock lck(mux_); + auto r = req_.emplace(seqNo, std::move(tag)); + if (!r.second) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__); + } + } + // Last step is to tag the request. + ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, ccReqTag); + return Status::OK(); } Status CacheClientGreeter::WorkerEntry() { @@ -129,15 +116,26 @@ Status CacheClientGreeter::WorkerEntry() { auto &rc = rq->rc_; if (!rc.ok()) { auto error_code = rq->rc_.error_code(); - std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); - Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + std::string err_msg; + if (error_code == grpc::StatusCode::UNAVAILABLE) { + err_msg = + "Cache server is unreachable. Make sure the server is running. GRPC Code" + std::to_string(error_code); + } else { + err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); + } + Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg); Status2CacheReply(remote_rc, &rq->base_rq_->reply_); } // Notify the waiting thread. rq->Notify(); } - // We can now free the memory - delete rq; + { + // We can now free the memory + std::unique_lock lck(mux_); + auto seqNo = rq->seqNo_; + auto n = req_.erase(seqNo); + CHECK_FAIL_RETURN_UNEXPECTED(n == 1, "Sequence " + std::to_string(seqNo) + " not found"); + } } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) { // If we are interrupted, exit. Otherwise wait again. RETURN_IF_INTERRUPTED(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h index 8fbd265bc30..369369eda56 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h @@ -16,10 +16,14 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ +#include +#include #include +#include #include #include #include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/engine/cache/cache_ipc.h" #include "minddata/dataset/util/service.h" #include "minddata/dataset/util/task_manager.h" namespace mindspore { @@ -34,16 +38,10 @@ namespace dataset { class CacheClientRequestTag { public: friend class CacheClientGreeter; - explicit CacheClientRequestTag(std::shared_ptr rq) : base_rq_(std::move(rq)) {} + explicit CacheClientRequestTag(std::shared_ptr rq, int64_t seqNo) + : base_rq_(std::move(rq)), seqNo_(seqNo) {} ~CacheClientRequestTag() = default; - /// \brief Make a RPC call - /// \param stub from CacheClientGreeter - /// \param cq from CacheClientGreeter - /// \return Status object - static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, - std::unique_ptr &&tag); - /// \brief Notify the client that a result has come back from the server void Notify() { base_rq_->wp_.Set(); } @@ -52,6 +50,7 @@ class CacheClientRequestTag { grpc::Status rc_; grpc::ClientContext ctx_; std::unique_ptr> rpc_; + int64_t seqNo_; }; /// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC @@ -60,7 +59,7 @@ class CacheClientGreeter : public Service { friend class CacheClient; public: - explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers); + explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t 
num_connections); ~CacheClientGreeter(); /// Override base Service class @@ -85,17 +84,18 @@ class CacheClientGreeter : public Service { /// \brief This returns where we attach to the shared memory. /// \return Base address of the shared memory. - const void *SharedMemoryBaseAddr() const { return shmat_addr_; } + const void *SharedMemoryBaseAddr() const { return mem_.SharedMemoryBaseAddr(); } private: std::shared_ptr channel_; std::unique_ptr stub_; grpc::CompletionQueue cq_; TaskGroup vg_; - int32_t num_workers_; - key_t shm_key_; - int32_t shm_id_; - void *shmat_addr_; + int32_t num_connections_; + std::atomic request_cnt_; + mutable std::mutex mux_; + std::map> req_; + SharedMemory mem_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc index 43c0bcfe5a7..485d1a38060 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc @@ -47,53 +47,10 @@ void CacheServerGreeterImpl::Shutdown() { CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); } -Status CacheServerGreeterImpl::IpcResourceCleanup() { -#if CACHE_LOCAL_CLIENT - int err; - auto shm_key = PortToFtok(port_, &err); - // We are expecting the unix path doesn't exist. - if (shm_key == (key_t)-1) { - return Status::OK(); - } - // Attach to the shared memory - auto shm_id = shmget(shm_key, 0, 0); - if (shm_id == -1) { - return Status::OK(); - } - struct shmid_ds ds {}; - auto inx = shmctl(shm_id, IPC_STAT, &ds); - if (inx == -1) { - std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id); - errMsg += "\nPlesae remove it manually using ipcrm -m command"; - RETURN_STATUS_UNEXPECTED(errMsg); - } - if (ds.shm_nattch == 0) { - // Stale shared memory from last time. - // Remove both the memory and the socket path - inx = shmctl(shm_id, IPC_RMID, nullptr); - if (inx == -1) { - std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id); - errMsg += ". Errno :" + std::to_string(errno); - errMsg += "\nPlesae remove it manually using ipcrm -m command"; - RETURN_STATUS_UNEXPECTED(errMsg); - } - Path p(unix_socket_); - (void)p.Remove(); - } else { - // Server is already up. - MS_LOG(ERROR) << "Cache server is already up and running"; - // We return a duplicate error. The main() will intercept - // and output a proper message - return Status(StatusCode::kDuplicateKey); - } -#endif - return Status::OK(); -} - Status CacheServerGreeterImpl::Run() { // To listen on all interfaces, use 0.0.0.0 - // Use 127.0.0.1 if just locally on the same machine. - std::string host("0.0.0.0"); // listen on all interfaces. + // Future, allow the user to choose listening interface. For now, default to localhost + std::string host("127.0.0.1"); std::string server_address = host + ":" + std::to_string(port_); grpc::ServerBuilder builder; // Default message size for gRPC is 4MB. 
Increase it to 2g-1 @@ -101,9 +58,6 @@ Status CacheServerGreeterImpl::Run() { int port_tcpip = 0; #if CACHE_LOCAL_CLIENT int port_local = 0; - // Check if we need to do clean up on the shared memory if the server - // came down unexpectedly like SEGV - RETURN_IF_NOT_OK(IpcResourceCleanup()); // We also optimize on local clients on the same machine using unix socket builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local); #endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h index ac3e648bf36..b3d1cc5f702 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h @@ -41,7 +41,7 @@ class CacheServerRequest : public BaseRequest { st_(STATE::CREATE), responder_(&ctx_) {} - ~CacheServerRequest() = default; + ~CacheServerRequest() override = default; /// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this /// functor will translate each protobuf into some form understood by by CacheService class. @@ -87,8 +87,6 @@ class CacheServerGreeterImpl final { void Shutdown(); - Status IpcResourceCleanup(); - private: int32_t port_; size_t shm_pool_sz_in_gb_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.cc new file mode 100644 index 00000000000..adb5f6450b2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/cache/cache_ipc.h" +#include + +namespace mindspore { +namespace dataset { +Status PortToFtok(int port, SharedMemory::shm_key_t *out) { + RETURN_UNEXPECTED_IF_NULL(out); + key_t shmkey = -1; + const std::string unix_path = PortToUnixSocketPath(port); + shmkey = ftok(unix_path.data(), 'a'); + if (shmkey == (key_t)-1) { + std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno); + return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg); + } + *out = shmkey; + return Status::OK(); +} + +SharedMessage::~SharedMessage() { + // Only remove the queue if we are asked to. + if (remove_ipc_on_exit_ && msg_qid_ != -1) { + // Remove the message que and never mind about the return code. + (void)msgctl(msg_qid_, IPC_RMID, nullptr); + msg_qid_ = -1; + } +} + +Status SharedMessage::Create() { + CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ == -1, "Message queue already created"); + auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; + msg_qid_ = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); + if (msg_qid_ == -1) { + std::string errMsg = "Unable to create a message queue. 
Errno = " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + return Status::OK(); +} + +Status SharedMessage::SendStatus(const Status &rc) { + CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id"); + StatusMsgBuf msg{ + 1, + }; + msg.body.status.err_code = static_cast(rc.get_code()); + auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size()); + CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err)); + msg.body.status.err_msg[rc.ToString().size()] = '\0'; + err = msgsnd(msg_qid_, reinterpret_cast(&msg), sizeof(msg.body.status), IPC_NOWAIT); + if (err == -1) { + std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + return Status::OK(); +} + +Status SharedMessage::ReceiveStatus(Status *rc) { + RETURN_UNEXPECTED_IF_NULL(rc); + CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id"); + struct StatusMsgBuf msg {}; + auto err = msgrcv(msg_qid_, reinterpret_cast(&msg), sizeof(msg.body.status), 0, MSG_NOERROR); + if (err == -1) { + std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + + Status rc_recv(static_cast(msg.body.status.err_code), msg.body.status.err_msg); + *rc = std::move(rc_recv); + return Status::OK(); +} + +SharedMemory::~SharedMemory() { + if (shmat_addr_) { + (void)Detach(); + } + if (remove_ipc_on_exit_ && shm_id_ != -1) { + // Remove the shared memory and never mind about the return code. + Status rc = Destroy(); + if (rc.IsError()) { + MS_LOG(ERROR) << rc.ToString(); + } + } + shm_id_ = -1; + shmat_addr_ = nullptr; +} + +Status SharedMemory::Create(int64_t sz) { + auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; + shm_id_ = shmget(shm_key_, sz, IPC_CREAT | IPC_EXCL | access_mode); + if (shm_id_ == -1) { + RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno)); + } else { + shmat_addr_ = shmat(shm_id_, nullptr, 0); + if (shmat_addr_ == reinterpret_cast(-1)) { + RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); + } + } + return Status::OK(); +} + +Status SharedMemory::Attach() { + shm_id_ = shmget(shm_key_, 0, 0); + if (shm_id_ == -1) { + RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno)); + } + shmat_addr_ = shmat(shm_id_, nullptr, 0); + if (shmat_addr_ == reinterpret_cast(-1)) { + RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); + } + return Status::OK(); +} + +Status SharedMemory::Detach() { + if (shmat_addr_) { + auto err = shmdt(shmat_addr_); + if (err == -1) { + RETURN_STATUS_UNEXPECTED("Shared memory detach failed. Errno " + std::to_string(errno)); + } + } + shmat_addr_ = nullptr; + return Status::OK(); +} + +Status SharedMemory::Destroy() { + // Remove the shared memory and never mind about the return code. + auto err = shmctl(shm_id_, IPC_RMID, nullptr); + if (err == -1) { + std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_); + errMsg += ". 
Errno :" + std::to_string(errno); + errMsg += "\nPlesae remove it manually using ipcrm -m command"; + RETURN_STATUS_UNEXPECTED(errMsg); + } + return Status::OK(); +} + +Status SharedMemory::GetNumAttached(int32_t *num) { + RETURN_UNEXPECTED_IF_NULL(num); + struct shmid_ds ds {}; + auto err = shmctl(shm_id_, IPC_STAT, &ds); + if (err == -1) { + std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id_); + errMsg += "\nPlease remove it manually using ipcrm -m command"; + RETURN_STATUS_UNEXPECTED(errMsg); + } + *num = ds.shm_nattch; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h new file mode 100644 index 00000000000..ce174ff808b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h @@ -0,0 +1,207 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +/// A message queue structure between the parent and the child process +struct StatusMsgBuf { + int64_t mtype; + union { + char mtext[1]; + struct { + int32_t err_code; + char err_msg[kSharedMessageSize]; + } status; + } body; +}; + +class BaseIPC { + public: + BaseIPC() : remove_ipc_on_exit_(false) {} + virtual ~BaseIPC() {} + /// Indicate if we should remove the ipc resource on exit. Usually this is done by parent process. + void RemoveResourcesOnExit() { remove_ipc_on_exit_ = true; } + /// Copy constructors + BaseIPC(const BaseIPC &rhs) : remove_ipc_on_exit_(false) {} + BaseIPC &operator=(const BaseIPC &rhs) { + if (&rhs != this) { + remove_ipc_on_exit_ = false; + } + return *this; + } + /// Move constructors + BaseIPC(BaseIPC &&rhs) noexcept : remove_ipc_on_exit_(rhs.remove_ipc_on_exit_) { rhs.remove_ipc_on_exit_ = false; } + BaseIPC &operator=(BaseIPC &&rhs) noexcept { + if (&rhs != this) { + remove_ipc_on_exit_ = rhs.remove_ipc_on_exit_; + rhs.remove_ipc_on_exit_ = false; + } + return *this; + } + + protected: + bool remove_ipc_on_exit_; +}; + +/// \brief This wraps a shared message for the communication between processes. It is used primarily +/// for starting and stopping a server. 
+class SharedMessage : public BaseIPC { + public: + using queue_id_t = int; + SharedMessage() : msg_qid_(-1) {} + explicit SharedMessage(queue_id_t qid) : msg_qid_(qid) {} + ~SharedMessage() override; + + /// Copy constructors + SharedMessage(const SharedMessage &rhs) : BaseIPC(rhs), msg_qid_(rhs.msg_qid_) {} + SharedMessage &operator=(const SharedMessage &rhs) { + if (&rhs != this) { + msg_qid_ = rhs.msg_qid_; + BaseIPC::operator=(rhs); + } + return *this; + } + /// Move constructors + SharedMessage(SharedMessage &&rhs) noexcept : BaseIPC(std::move(rhs)) { + msg_qid_ = rhs.msg_qid_; + rhs.msg_qid_ = -1; + } + SharedMessage &operator=(SharedMessage &&rhs) noexcept { + if (&rhs != this) { + msg_qid_ = rhs.msg_qid_; + rhs.msg_qid_ = -1; + BaseIPC::operator=(std::move(rhs)); + } + return *this; + } + + /// Return the private id + queue_id_t GetMsgQueueId() const { return msg_qid_; } + + /// \brief Create a private message queue + Status Create(); + + /// Send a Status object + Status SendStatus(const Status &rc); + + /// Retrieve a Status object + Status ReceiveStatus(Status *rc); + + private: + queue_id_t msg_qid_; +}; + +/// \brief This wraps a shared memory for the communication between processes. It is used primarily +/// for transporting large tensor rows. +class SharedMemory : public BaseIPC { + public: + using shm_key_t = int; + using shm_id_t = int; + SharedMemory() : shm_id_(-1), shm_key_(-1), shmat_addr_(nullptr) {} + explicit SharedMemory(shm_key_t public_key) : shm_id_(-1), shm_key_(public_key), shmat_addr_(nullptr) {} + ~SharedMemory() override; + /// Copy constructors + SharedMemory(const SharedMemory &rhs) + : BaseIPC(rhs), shm_id_(rhs.shm_id_), shm_key_(rhs.shm_key_), shmat_addr_(rhs.shmat_addr_) {} + SharedMemory &operator=(const SharedMemory &rhs) { + if (&rhs != this) { + shm_id_ = rhs.shm_id_; + shm_key_ = rhs.shm_key_; + shmat_addr_ = rhs.shmat_addr_; + BaseIPC::operator=(rhs); + } + return *this; + } + /// Move constructors + SharedMemory(SharedMemory &&rhs) noexcept : BaseIPC(std::move(rhs)) { + shm_id_ = rhs.shm_id_; + shm_key_ = rhs.shm_key_; + shmat_addr_ = rhs.shmat_addr_; + rhs.shm_id_ = -1; + rhs.shm_key_ = -1; + rhs.shmat_addr_ = nullptr; + } + SharedMemory &operator=(SharedMemory &&rhs) noexcept { + if (&rhs != this) { + shm_id_ = rhs.shm_id_; + shm_key_ = rhs.shm_key_; + shmat_addr_ = rhs.shmat_addr_; + rhs.shm_id_ = -1; + rhs.shm_key_ = -1; + rhs.shmat_addr_ = nullptr; + BaseIPC::operator=(std::move(rhs)); + } + return *this; + } + /// \brief Set the public key + void SetPublicKey(key_t public_key) { shm_key_ = public_key; } + + /// \brief This returns where we attach to the shared memory. + /// \return Base address of the shared memory. + const void *SharedMemoryBaseAddr() const { return shmat_addr_; } + void *SharedMemoryBaseAddr() { return shmat_addr_; } + + /// \brief Attach to shared memory + /// \return Status object + Status Attach(); + + /// Detach from shared memory + /// \return Status object + Status Detach(); + + /// Create shared memory + /// \return Status object + Status Create(int64_t sz); + + /// Destroy shared memory + /// \return Status object + Status Destroy(); + + /// \brief Return the shared memory id + shm_id_t GetSharedMemoryId() const { return shm_id_; } + + /// \brief Get number of processes attached to the shared memory + /// \return Status object + Status GetNumAttached(int32_t *num); + + private: + shm_id_t shm_id_; + shm_key_t shm_key_; + void *shmat_addr_; +}; + +/// \brief Generate a shared memory key using the tcp/ip port. 
+/// \note It must be called after the cache server generates the unix socket or ftok will fail. +/// \note The caller must check the returned Status; a failed ftok is reported through it. +/// \param[in] port +/// \param[out] out The shared memory key generated from the port's unix socket path +/// \return Status object +Status PortToFtok(int port, SharedMemory::shm_key_t *); + +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc index 3de7b671105..956db70246f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc @@ -21,24 +21,136 @@ #include #endif #include - +#include +#include +#include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/engine/cache/cache_ipc.h" namespace ds = mindspore::dataset; -int main(int argc, char **argv) { +/// Send a synchronous command to the local server using tcp/ip. +/// We aren't using any client code because this binary is not necessarily linked with the client library. +/// So we just use the grpc call directly. +/// \param port tcp/ip port to use +/// \param type Type of command. +/// \param out grpc result +/// \return Status object +ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply, + grpc::Status *out) { + if (rq == nullptr) { + return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null"); + } + if (reply == nullptr) { + return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null"); + } + if (out == nullptr) { + return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null"); + } + const std::string hostname = "127.0.0.1"; + auto unix_socket = ds::PortToUnixSocketPath(port); +#if CACHE_LOCAL_CLIENT + const std::string target = "unix://" + unix_socket; +#else + const std::string target = hostname + ":" + std::to_string(port); +#endif + try { + rq->set_type(static_cast(type)); + grpc::ChannelArguments args; + grpc::ClientContext ctx; + grpc::CompletionQueue cq; + // Standard async rpc call + std::shared_ptr channel = + grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); + std::unique_ptr stub = ds::CacheServerGreeter::NewStub(channel); + std::unique_ptr> rpc = + stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq); + rpc->StartCall(); + // We need to pass a tag. Since this is the only request in the completion queue, we + // just pass a dummy + int64_t dummy; + void *tag; + bool success; + rpc->Finish(reply, out, &dummy); + // Now we wait on the completion queue synchronously. + auto r = cq.Next(&tag, &success); + if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { + if (!success || tag != &dummy) { + std::string errMsg = "Unexpected programming error "; + return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } + if (out->ok()) { + return ds::Status(static_cast(reply->rc()), reply->msg()); + } else { + auto error_code = out->error_code(); + std::string errMsg = out->error_message() + ". 
GRPC Code " + std::to_string(error_code); + return ds::Status(ds::StatusCode::kNetWorkError, errMsg); + } + } else { + std::string errMsg = "Unexpected queue rc = " + std::to_string(r); + return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } + } catch (const std::exception &e) { + return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); + } +} + +/// Stop the server +/// \param argv +/// \return Status object +ds::Status StopServer(int argc, char **argv) { ds::Status rc; ds::CacheServer::Builder builder; + std::string errMsg; + if (argc != 2) { + return ds::Status(ds::StatusCode::kSyntaxError); + } + int32_t port = strtol(argv[1], nullptr, 10); + // We will go through the builder to do some snaity check. We only need the port number + // to shut down the server. Null the root directory as we don't trigger the sanity code to write out anything + // to the spill directory. + builder.SetPort(port).SetRootDirectory(""); + // Part of the sanity check is check the shared memory. If the server is up and running, we expect + // the return code is kDuplicate. + rc = builder.SanityCheck(); + if (rc.IsOk()) { + errMsg = "Server is not up or has been shutdown already."; + return ds::Status(ds::StatusCode::kUnexpectedError, errMsg); + } else if (rc.get_code() != ds::StatusCode::kDuplicateKey) { + // Not OK, and no duplicate, just return the rc whatever it is. + return rc; + } else { + // Now we get some work to do. We will send a tcp/ip request to the given port. + // This binary is not linked with client side of code, so we will just call grpc directly. + ds::CacheRequest rq; + ds::CacheReply reply; + grpc::Status grpc_rc; + rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc); + // The request is like a self destruct message, the server will not send anything back and + // shutdown all incoming request. So we should expect some unexpected network error if + // all goes well and we expect to GRPC code 14. + auto err_code = grpc_rc.error_code(); + if (rc.get_code() != ds::StatusCode::kNetWorkError || err_code != grpc::StatusCode::UNAVAILABLE) { + return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__); + } + } + return ds::Status::OK(); +} - // This executable is not to be called directly, and should be invoked by cache_admin executable. - if (argc != 7) { - rc = ds::Status(ds::StatusCode::kSyntaxError); - std::cerr << rc.ToString() << std::endl; - return static_cast(rc.get_code()); +/// Start the server +/// \param argv +/// \return Status object +ds::Status StartServer(int argc, char **argv) { + ds::Status rc; + ds::CacheServer::Builder builder; + if (argc != 8) { + return ds::Status(ds::StatusCode::kSyntaxError); } + int32_t port = strtol(argv[3], nullptr, 10); builder.SetRootDirectory(argv[1]) .SetNumWorkers(strtol(argv[2], nullptr, 10)) - .SetPort(strtol(argv[3], nullptr, 10)) - .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)); + .SetPort(port) + .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)) + .SetMemoryCapRatio(strtof(argv[7], nullptr)); #ifdef USE_GLOG FLAGS_minloglevel = strtol(argv[5], nullptr, 10); @@ -52,36 +164,42 @@ int main(int argc, char **argv) { // is called. This is a standard procedure for daemonize a process on unix. if (chdir("/") == -1) { std::string errMsg = "Unable to change directory to /. 
Errno = " + std::to_string(errno); - std::cerr << errMsg << std::endl; - return -1; + return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); } - // Simple check of the parameters before we move on. - rc = builder.SanityCheck(); - if (rc.IsError()) { - std::cerr << rc.ToString() << std::endl; - return static_cast(rc.get_code()); - } - -#ifdef USE_GLOG - FLAGS_log_dir = "/tmp"; - google::InitGoogleLogging(argv[0]); -#endif - + // A message queue for communication between parent and child (if we fork). + ds::SharedMessage msg; if (daemonize) { - // fork the child process to become the daemon +#ifdef USE_GLOG + FLAGS_log_dir = "/tmp"; + google::InitGoogleLogging(argv[0]); +#endif + rc = msg.Create(); + if (rc.IsError()) { + return rc; + } pid_t pid = fork(); // failed to fork if (pid < 0) { - std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno); - std::cerr << err_msg << std::endl; - return errno; + std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno); + return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); } else if (pid > 0) { - // Parent + // Parent and will be responsible for remove the queue on exit. + msg.RemoveResourcesOnExit(); + // Sleep one second and we attach to the msg que + std::this_thread::sleep_for(std::chrono::seconds(1)); + ds::Status child_rc; + rc = msg.ReceiveStatus(&child_rc); + if (rc.IsError()) { + return rc; + } + if (child_rc.IsError()) { + return child_rc; + } std::cerr << "cache server daemon process has been created as process id: " << pid << "\nCheck log file for any start up error" << std::endl; signal(SIGCHLD, SIG_IGN); // ignore sig child signal. - return 0; + return ds::Status::OK(); } else { // Child process will continue from here if daemonize and parent has already exited. // If we are running in the foreground, none of the code in block below will be run. @@ -89,8 +207,8 @@ int main(int argc, char **argv) { umask(0); sid = setsid(); if (sid < 0) { - MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno); - return errno; + std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno); + return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); } close(0); close(1); @@ -100,22 +218,36 @@ int main(int argc, char **argv) { // Dump the summary MS_LOG(INFO) << builder << std::endl; + // Create the instance with some sanity checks built in rc = builder.Build(); if (rc.IsOk()) { + // If all goes well, kick off the threads. Loop forever and never return unless error. ds::CacheServer &cs = ds::CacheServer::GetInstance(); - // Kick off the threads. Loop forever and never return unless error. - rc = cs.Run(); - if (rc.get_code() == ds::StatusCode::kDuplicateKey) { - std::string errMsg = "Server is already started"; - MS_LOG(ERROR) << errMsg; - std::cerr << errMsg << std::endl; - return 0; - } + rc = cs.Run(msg.GetMsgQueueId()); + } else if (daemonize) { + // If we didn't pass the sanity check to at least create the instance, use + // the message queue to return the error message if this is the child daemon. + return msg.SendStatus(rc); } + return rc; +} + +int main(int argc, char **argv) { + ds::Status rc; + ds::CacheServer::Builder builder; + + // This executable is not to be called directly, and should be invoked by cache_admin executable. 
+ if (strcmp(argv[0], "-") == 0) { + rc = StopServer(argc, argv); + } else { + rc = StartServer(argc, argv); + } + // Check result if (rc.IsError()) { - MS_LOG(ERROR) << rc.ToString(); - std::cerr << rc.ToString() << std::endl; - return static_cast(rc.get_code()); + auto errCode = rc.get_code(); + auto errMsg = rc.ToString(); + std::cerr << errMsg << std::endl; + return static_cast(errCode); } return 0; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc index fc69a7eeabb..fe2641b10fb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -250,5 +250,27 @@ Status GetStatRequest::PostReply() { stat_.cache_service_state = msg->state(); return Status::OK(); } + +Status ListSessionsRequest::PostReply() { + auto *msg = flatbuffers::GetRoot(reply_.result().data()); + auto session_vector = msg->sessions(); + for (auto i = 0; i < session_vector->size(); ++i) { + SessionCacheInfo current_info; + CacheServiceStat stats; + auto current_session_info = session_vector->Get(i); + current_info.session_id = current_session_info->session_id(); + current_info.connection_id = current_session_info->connection_id(); + stats.num_mem_cached = current_session_info->stats()->num_mem_cached(); + stats.num_disk_cached = current_session_info->stats()->num_disk_cached(); + stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz(); + stats.min_row_id = current_session_info->stats()->min_row_id(); + stats.max_row_id = current_session_info->stats()->max_row_id(); + stats.cache_service_state = current_session_info->stats()->state(); + current_info.stats = stats; // fixed length struct. = operator is safe + session_info_list_.push_back(current_info); + } + + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h index 1177691e7b7..84117f95684 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -46,6 +46,13 @@ struct CacheServiceStat { int8_t cache_service_state; }; +/// \brief Info structure ListSessionsRequest +struct SessionCacheInfo { + session_id_type session_id; + connection_id_type connection_id; + CacheServiceStat stats; +}; + /// \brief CacheClient communicates with CacheServer using Requests. class BaseRequest { public: @@ -54,7 +61,7 @@ class BaseRequest { kCacheRow = 0, kBatchFetchRows = 1, kCreateCache = 2, - kPurgeCache = 3, + kGetCacheMissKeys = 3, kDestroyCache = 4, kGetStat = 5, kCacheSchema = 6, @@ -65,6 +72,9 @@ class BaseRequest { kAllocateSharedBlock = 11, kFreeSharedBlock = 12, kStopService = 13, + kHeartBeat = 14, + kToggleWriteMode = 15, + kListSessions = 16, // Add new request before it. 
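// [Editor's note - illustrative usage, not part of this patch] The new request types added above
// follow the existing pattern: arguments travel as plain strings in the protobuf buf_data and
// replies come back as flatbuffers that PostReply() unpacks. Only construction is shown here; the
// client-side machinery that actually transmits a BaseRequest is outside this hunk.
//
//   ToggleWriteModeRequest pause_writes(connection_id, false);  // buf_data(0) == "off"
//   ListSessionsRequest list_rq;  // connection id 0: not tied to any one cache
//   // ...send list_rq through the cache client...
//   // list_rq.PostReply() then fills a std::vector<SessionCacheInfo>,
//   // retrievable via list_rq.GetSessionCacheInfo().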
kRequestUnknown = 32767 }; @@ -73,6 +83,7 @@ class BaseRequest { friend class CacheServerRequest; friend class CacheClientGreeter; friend class CacheClientRequestTag; + friend class CacheClient; /// \brief Base class of a cache server request /// \param type Type of the request @@ -119,7 +130,7 @@ class FreeSharedBlockRequest : public BaseRequest { rq_.set_connection_id(connection_id); rq_.add_buf_data(std::to_string(addr)); } - ~FreeSharedBlockRequest() = default; + ~FreeSharedBlockRequest() override = default; }; /// \brief Request to cache a single TensorRow @@ -136,7 +147,7 @@ class CacheRowRequest : public BaseRequest { rq_.set_connection_id(connection_id); rq_.add_buf_data(cookie); } - ~CacheRowRequest() = default; + ~CacheRowRequest() override = default; /// \brief Serialize a TensorRow for streaming to the cache server /// \param row TensorRow @@ -183,7 +194,7 @@ class BatchFetchRequest : public BaseRequest { friend class CacheServer; friend class CacheService; BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id, bool local_bypass); - ~BatchFetchRequest() = default; + ~BatchFetchRequest() override = default; Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); private: @@ -203,7 +214,7 @@ class CreateCacheRequest : public BaseRequest { /// \param flag Attributes of the cache. explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, CreateCacheFlag flag = CreateCacheFlag::kNone); - ~CreateCacheRequest() = default; + ~CreateCacheRequest() override = default; void ParseResult(connection_id_type *id, std::string *out) { auto p = flatbuffers::GetRoot(reply_.result().data()); *id = p->connection_id(); @@ -218,14 +229,15 @@ class CreateCacheRequest : public BaseRequest { CreateCacheFlag flag_; }; -/// \brief Request to purge a cache. -class PurgeCacheRequest : public BaseRequest { +/// \brief Request to get all the keys not present at the server. +/// \note Only applicable to mappable case +class GetCacheMissKeysRequest : public BaseRequest { public: friend class CacheServer; - explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) { + explicit GetCacheMissKeysRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetCacheMissKeys) { rq_.set_connection_id(connection_id); } - ~PurgeCacheRequest() = default; + ~GetCacheMissKeysRequest() override = default; }; /// \brief Request to destroy a cache @@ -235,7 +247,7 @@ class DestroyCacheRequest : public BaseRequest { explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) { rq_.set_connection_id(connection_id); } - ~DestroyCacheRequest() = default; + ~DestroyCacheRequest() override = default; }; /// \brief Obtain the statistics of the current connection @@ -247,7 +259,7 @@ class GetStatRequest : public BaseRequest { rq_.set_connection_id(connection_id); } - ~GetStatRequest() = default; + ~GetStatRequest() override = default; /// \brief Override base function to process the result. 
Status PostReply() override; @@ -269,7 +281,7 @@ class CacheSchemaRequest : public BaseRequest { explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) { rq_.set_connection_id(connection_id); } - ~CacheSchemaRequest() = default; + ~CacheSchemaRequest() override = default; Status SerializeCacheSchemaRequest(const std::unordered_map &map); }; @@ -281,7 +293,7 @@ class FetchSchemaRequest : public BaseRequest { explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) { rq_.set_connection_id(connection_id); } - ~FetchSchemaRequest() = default; + ~FetchSchemaRequest() override = default; Status PostReply() override; @@ -300,7 +312,7 @@ class BuildPhaseDoneRequest : public BaseRequest { rq_.set_connection_id(connection_id); rq_.add_buf_data(cookie_); } - ~BuildPhaseDoneRequest() = default; + ~BuildPhaseDoneRequest() override = default; private: std::string cookie_; @@ -313,7 +325,7 @@ class DropSessionRequest : public BaseRequest { explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) { rq_.mutable_connection_info()->operator=(cinfo); } - ~DropSessionRequest() = default; + ~DropSessionRequest() override = default; }; class GenerateSessionIdRequest : public BaseRequest { @@ -325,11 +337,36 @@ class GenerateSessionIdRequest : public BaseRequest { rq_.set_connection_id(0); } - ~GenerateSessionIdRequest() = default; + ~GenerateSessionIdRequest() override = default; session_id_type GetSessionId() { return atoi(reply_.result().data()); } }; +class ListSessionsRequest : public BaseRequest { + public: + friend class CacheServer; + ListSessionsRequest() : BaseRequest(RequestType::kListSessions) { + // This request is not specific to any cache or session + rq_.set_connection_id(0); + } + + ~ListSessionsRequest() override = default; + + /// \brief Override base function to process the result. + Status PostReply() override; + + void GetSessionCacheInfo(std::vector *info) { + if (info != nullptr) { + (*info) = session_info_list_; + } + } + + std::vector GetSessionCacheInfo() { return session_info_list_; } + + private: + std::vector session_info_list_; +}; + class AllocateSharedBlockRequest : public BaseRequest { public: friend class CacheServer; @@ -338,7 +375,7 @@ class AllocateSharedBlockRequest : public BaseRequest { rq_.set_connection_id(connection_id); rq_.add_buf_data(std::to_string(requestedSz)); } - ~AllocateSharedBlockRequest() = default; + ~AllocateSharedBlockRequest() override = default; /// \brief On return from the server, we get the (relative) address where /// the free block is located. @@ -349,11 +386,15 @@ class AllocateSharedBlockRequest : public BaseRequest { } }; -class ShutdownRequest : public BaseRequest { +class ToggleWriteModeRequest : public BaseRequest { public: friend class CacheServer; - ShutdownRequest() : BaseRequest(RequestType::kStopService) {} - ~ShutdownRequest() = default; + explicit ToggleWriteModeRequest(connection_id_type connection_id, bool on_off) + : BaseRequest(RequestType::kToggleWriteMode) { + rq_.set_connection_id(connection_id); + rq_.add_buf_data(on_off ? 
"on" : "off"); + } + ~ToggleWriteModeRequest() override = default; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc index c181376c762..c7df81cf89e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc @@ -18,6 +18,7 @@ #include #include #include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/cache/cache_ipc.h" #include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/cache/cache_request.h" #include "minddata/dataset/util/bit.h" @@ -107,6 +108,8 @@ Status CacheServer::DoServiceStop() { // First stop all the threads. RETURN_IF_NOT_OK(vg_.ServiceStop()); // Clean up all the caches if any. + // Dump a message how much memory we have consumed in total. + MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes."; UniqueLock lck(&rwLock_); auto it = all_caches_.begin(); while (it != all_caches_.end()) { @@ -121,7 +124,6 @@ Status CacheServer::DoServiceStop() { } CacheService *CacheServer::GetService(connection_id_type id) const { - SharedLock lck(&rwLock_); auto it = all_caches_.find(id); if (it != all_caches_.end()) { return it->second.get(); @@ -134,6 +136,16 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { std::string cookie; auto session_id = rq->connection_info().session_id(); auto crc = rq->connection_info().crc(); + + // Before allowing the creation, make sure the session had already been created by the user + // Our intention is to add this cache to the active sessions list so leave the list locked during + // this entire function. + UniqueLock lock(&sessions_lock_); + auto session_it = active_sessions_.find(session_id); + if (session_it == active_sessions_.end()) { + RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!"); + } + // We concat both numbers to form the internal connection id. auto connection_id = GetConnectionID(session_id, crc); CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache"); @@ -172,10 +184,15 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { } catch (const std::bad_alloc &e) { return Status(StatusCode::kOutOfMemory); } + // Add the cache into the active session tracking. + // We have already validated that the session exists and that this is a new cache created. + session_it->second.insert(connection_id); + } else { duplicate = true; MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; } + off_cookie = fbb.CreateString(cookie); CreateCacheReplyMsgBuilder bld(fbb); bld.add_connection_id(connection_id); @@ -183,19 +200,18 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { auto off = bld.Finish(); fbb.Finish(off); reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); - // Track the history of all the sessions that we have created so far. - history_sessions_.insert(session_id); // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it // treat it as OK. return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK(); } -Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) { +Status CacheServer::DestroyCache(CacheRequest *rq) { // We need a strong lock to protect the map. 
UniqueLock lck(&rwLock_); + auto id = rq->connection_id(); + CacheService *cs = GetService(id); // it is already destroyed. Ignore it. if (cs != nullptr) { - auto id = rq->connection_id(); MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); // std::map will invoke the destructor of CacheService. So we don't need to do anything here. auto n = all_caches_.erase(id); @@ -204,11 +220,34 @@ Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) { MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; } } + + // Now that this cache is removed, we need to also remove it's connection id from active session tracking + auto session_id = GetSessionID(id); + UniqueLock sess_lck(&sessions_lock_); + + auto it = active_sessions_.find(session_id); + if (it == active_sessions_.end()) { + // The session was not found in the active sessions + RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!"); + } + + auto connection_it = it->second.find(id); + if (connection_it == it->second.end()) { + RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!"); + } + + // remove that connection id from the set + it->second.erase(connection_it); + MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id; + return Status::OK(); } -inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { +Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) { auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); if (cs == nullptr) { std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); @@ -236,8 +275,11 @@ inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { return Status::OK(); } -Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { +Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); auto shared_pool = comm_layer_->GetSharedMemoryPool(); auto *base = shared_pool->SharedMemoryBaseAddr(); // Ensure we got 3 pieces of data coming in @@ -270,8 +312,11 @@ Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply return rc; } -Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) { +Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. 
+ SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); if (cs == nullptr) { std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); @@ -325,8 +370,11 @@ Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheRepl return Status::OK(); } -inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { +Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) { auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); if (cs == nullptr) { std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); @@ -338,8 +386,8 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); - bld.add_max_row_id(svc_stat.max_); - bld.add_min_row_id(svc_stat.min_); + bld.add_max_row_id(svc_stat.stat_.max_key); + bld.add_min_row_id(svc_stat.stat_.min_key); bld.add_state(svc_stat.state_); auto offset = bld.Finish(); fbb.Finish(offset); @@ -348,8 +396,11 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { return Status::OK(); } -inline Status CacheSchema(CacheService *cs, CacheRequest *rq) { +Status CacheServer::CacheSchema(CacheRequest *rq) { auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); if (cs == nullptr) { std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); @@ -361,8 +412,11 @@ inline Status CacheSchema(CacheService *cs, CacheRequest *rq) { return Status::OK(); } -inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) { +Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) { auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); if (cs == nullptr) { std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); @@ -377,8 +431,11 @@ inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) return Status::OK(); } -inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) { +Status CacheServer::BuildPhaseDone(CacheRequest *rq) { auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. 
+ SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); if (cs == nullptr) { std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); @@ -396,15 +453,24 @@ inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) { return Status::OK(); } -Status CacheServer::PurgeCache(CacheService *cs) { +Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) { + auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. SharedLock lck(&rwLock_); - // If shutdown in progress, ignore the command. - if (global_shutdown_) { - return Status::OK(); - } - // it is already purged. Ignore it. - if (cs != nullptr) { - RETURN_IF_NOT_OK(cs->Purge()); + CacheService *cs = GetService(connection_id); + if (cs == nullptr) { + std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + std::vector gap; + RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap)); + flatbuffers::FlatBufferBuilder fbb; + auto off_t = fbb.CreateVector(gap); + TensorRowIdsBuilder bld(fbb); + bld.add_row_id(off_t); + auto off = bld.Finish(); + fbb.Finish(off); + reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); } return Status::OK(); } @@ -414,6 +480,72 @@ inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *re return Status::OK(); } +Status CacheServer::ToggleWriteMode(CacheRequest *rq) { + auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); + if (cs == nullptr) { + std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + // First piece of data is the on/off flag + CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag"); + const auto &action = rq->buf_data(0); + bool on_off = false; + if (strcmp(action.data(), "on") == 0) { + on_off = true; + } else if (strcmp(action.data(), "off") == 0) { + on_off = false; + } else { + RETURN_STATUS_UNEXPECTED("Unknown request: " + action); + } + RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off)); + } + return Status::OK(); +} + +Status CacheServer::ListSessions(CacheReply *reply) { + SharedLock lck(&sessions_lock_); + + flatbuffers::FlatBufferBuilder fbb; + std::vector> session_msgs_vector; + for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) { + session_id_type current_session_id = it->first; + // Loop over each cache inside this session + if (!it->second.empty()) { + for (auto current_conn_id : it->second) { + CacheService *cs = GetService(current_conn_id); + if (cs == nullptr) { + std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions."; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + CacheService::ServiceStat svc_stat; + RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); + auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, + svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key, + svc_stat.stat_.max_key, svc_stat.state_); + auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); + 
session_msgs_vector.push_back(current_session_info); + } + } + } else { + // If there is no cache created yet, assign a connection id of 0 along with empty stats + auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0); + auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats); + session_msgs_vector.push_back(current_session_info); + } + } + auto session_msgs = fbb.CreateVector(session_msgs_vector); + ListSessionsMsgBuilder s_builder(fbb); + s_builder.add_sessions(session_msgs); + auto offset = s_builder.Finish(); + fbb.Finish(offset); + reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); + + return Status::OK(); +} + /// \brief This is the main loop the cache server thread(s) are running. /// Each thread will pop a request and send the result back to the client using grpc /// \return @@ -426,12 +558,6 @@ Status CacheServer::ServerRequest(int32_t worker_id) { RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); auto &rq = cache_req->rq_; auto &reply = cache_req->reply_; - CacheService *cs = nullptr; - // Request comes in roughly two sets. One set is at the cache level with a connection id. - // The other set is working at a high level and without a connection id - if (!rq.has_connection_info()) { - cs = GetService(rq.connection_id()); - } // Except for creating a new session, we expect cs is not null. switch (cache_req->type_) { case BaseRequest::RequestType::kCacheRow: { @@ -439,42 +565,42 @@ Status CacheServer::ServerRequest(int32_t worker_id) { // call the appropriate method. auto flag = rq.flag(); if (BitTest(flag, kDataIsInSharedMemory)) { - cache_req->rc_ = FastCacheRow(cs, &rq, &reply); + cache_req->rc_ = FastCacheRow(&rq, &reply); } else { - cache_req->rc_ = CacheRow(cs, &rq, &reply); + cache_req->rc_ = CacheRow(&rq, &reply); } break; } case BaseRequest::RequestType::kBatchFetchRows: { - cache_req->rc_ = BatchFetchRows(cs, &rq, &reply); + cache_req->rc_ = BatchFetchRows(&rq, &reply); break; } case BaseRequest::RequestType::kCreateCache: { cache_req->rc_ = CreateService(&rq, &reply); break; } - case BaseRequest::RequestType::kPurgeCache: { - cache_req->rc_ = PurgeCache(cs); + case BaseRequest::RequestType::kGetCacheMissKeys: { + cache_req->rc_ = GetCacheMissKeys(&rq, &reply); break; } case BaseRequest::RequestType::kDestroyCache: { - cache_req->rc_ = DestroyCache(cs, &rq); + cache_req->rc_ = DestroyCache(&rq); break; } case BaseRequest::RequestType::kGetStat: { - cache_req->rc_ = GetStat(cs, &rq, &reply); + cache_req->rc_ = GetStat(&rq, &reply); break; } case BaseRequest::RequestType::kCacheSchema: { - cache_req->rc_ = CacheSchema(cs, &rq); + cache_req->rc_ = CacheSchema(&rq); break; } case BaseRequest::RequestType::kFetchSchema: { - cache_req->rc_ = FetchSchema(cs, &rq, &reply); + cache_req->rc_ = FetchSchema(&rq, &reply); break; } case BaseRequest::RequestType::kBuildPhaseDone: { - cache_req->rc_ = BuildPhaseDone(cs, &rq); + cache_req->rc_ = BuildPhaseDone(&rq); break; } case BaseRequest::RequestType::kDropSession: { @@ -498,6 +624,18 @@ Status CacheServer::ServerRequest(int32_t worker_id) { cache_req->rc_ = GlobalShutdown(); break; } + case BaseRequest::RequestType::kHeartBeat: { + cache_req->rc_ = Status::OK(); + break; + } + case BaseRequest::RequestType::kToggleWriteMode: { + cache_req->rc_ = ToggleWriteMode(&rq); + break; + } + case BaseRequest::RequestType::kListSessions: { + cache_req->rc_ = ListSessions(&reply); + break; + } default: std::string errMsg("Unknown request type : "); errMsg += 
std::to_string(static_cast(cache_req->type_)); @@ -526,18 +664,32 @@ session_id_type CacheServer::GetSessionID(connection_id_type connection_id) cons } CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, - int32_t shared_meory_sz_in_gb) + int32_t shared_meory_sz_in_gb, float memory_cap_ratio) : top_(spill_path), num_workers_(num_workers), port_(port), shared_memory_sz_in_gb_(shared_meory_sz_in_gb), - global_shutdown_(false) {} + global_shutdown_(false), + memory_cap_ratio_(memory_cap_ratio), + cur_mem_usage_(0) { + memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_; +} -Status CacheServer::Run() { - RETURN_IF_NOT_OK(ServiceStart()); +Status CacheServer::Run(int msg_qid) { + Status rc = ServiceStart(); + // If there is a message que, return the status now before we call join_all which will never return + if (msg_qid != -1) { + SharedMessage msg(msg_qid); + RETURN_IF_NOT_OK(msg.SendStatus(rc)); + } + if (rc.IsError()) { + return rc; + } // This is called by the main function and we shouldn't exit. Otherwise the main thread // will just shutdown. So we will call some function that never return unless error. // One good case will be simply to wait for all threads to return. + // note that after we have sent the initial status using the msg_qid, parent process will exit and + // remove it. So we can't use it again. RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking)); return Status::OK(); } @@ -567,32 +719,51 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) { Status CacheServer::DestroySession(CacheRequest *rq) { CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); auto drop_session_id = rq->connection_info().session_id(); - UniqueLock lck(&rwLock_); - for (auto &cs : all_caches_) { - auto connection_id = cs.first; - auto session_id = GetSessionID(connection_id); - // We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock. - // So we will just manually do it. - if (session_id == drop_session_id) { - // std::map will invoke the destructor of CacheService. So we don't need to do anything here. - auto n = all_caches_.erase(connection_id); - MS_LOG(INFO) << "Destroy " << n << " copies of cache with id " << connection_id; + + UniqueLock lck(&sessions_lock_); + + // First validate that this session exists + auto it = active_sessions_.find(drop_session_id); + if (it == active_sessions_.end()) { + RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!"); + } + + // Iterate over the set of connection id's for this session that we're dropping and erase each one. + { + UniqueLock rwlck(&rwLock_); + for (auto drop_connection_id : it->second) { + auto cache_drop_it = all_caches_.find(drop_connection_id); + if (cache_drop_it == all_caches_.end()) { + RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry."); + } + all_caches_.erase(cache_drop_it); + MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id; + // **Do not bother to remove the cache connection id from the active session because we will soon remove the + // entire session. 
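// [Editor's note - worked sketch, not part of this patch] The constructor above derives a hard
// memory cap once: memory_cap_ = GetTotalSystemMemory() * memory_cap_ratio_, and cur_mem_usage_
// is the atomic counter fed by UpdateMemoryUsage(). With the builder default ratio of 0.8 on a
// 16 GiB machine the cap works out to about 12.8 GiB. A standalone version of that arithmetic;
// enforcement of the cap is outside this hunk, so TryReserve below only illustrates comparing
// usage against it.
#include <atomic>
#include <cstdint>
#include <unistd.h>

int64_t PhysicalRamBytes() {
  // Same sysconf pair used by CacheServer::GetTotalSystemMemory().
  return static_cast<int64_t>(sysconf(_SC_PHYS_PAGES)) * static_cast<int64_t>(sysconf(_SC_PAGE_SIZE));
}

struct MemBudget {
  explicit MemBudget(float ratio) : cap(static_cast<int64_t>(PhysicalRamBytes() * ratio)), used(0) {}
  bool TryReserve(int64_t sz) {
    if (used.load() + sz > cap) return false;  // would exceed the cap
    used += sz;
    return true;
  }
  void Release(int64_t sz) { used -= sz; }
  int64_t cap;
  std::atomic<int64_t> used;
};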
} } + + // Finally remove the session itself + active_sessions_.erase(it); + MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; + return Status::OK(); } -session_id_type CacheServer::GenerateSessionID() const { - SharedLock lock(&rwLock_); +session_id_type CacheServer::GenerateSessionID() { + UniqueLock lock(&sessions_lock_); auto mt = GetRandomDevice(); std::uniform_int_distribution distribution(0, std::numeric_limits::max()); session_id_type session_id; bool duplicate = false; do { session_id = distribution(mt); - auto it = history_sessions_.find(session_id); - duplicate = (it != history_sessions_.end()); + auto it = active_sessions_.find(session_id); + duplicate = (it != active_sessions_.end()); } while (duplicate); + + // Add this session to our tracking of active sessions with initialized empty set of connections. + active_sessions_[session_id] = std::set(); return session_id; } @@ -637,19 +808,59 @@ Status CacheServer::GlobalShutdown() { vg_.interrupt_all(); // The next thing to do drop all the caches. UniqueLock lck(&rwLock_); - for (auto &it : all_caches_) { - auto id = it.first; + for (auto it = all_caches_.begin(); it != all_caches_.end();) { + auto id = it->first; MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); // Wait for all outstanding work to be finished. - auto &cs = it.second; + auto &cs = it->second; UniqueLock cs_lock(&cs->rw_lock_); - // std::map will invoke the destructor of CacheService. So we don't need to do anything here. - (void)all_caches_.erase(id); + it = all_caches_.erase(it); } } return Status::OK(); } +int64_t CacheServer::GetTotalSystemMemory() { + auto pages = sysconf(_SC_PHYS_PAGES); + auto page_size = sysconf(_SC_PAGE_SIZE); + auto total = static_cast(pages) * static_cast(page_size); + MS_LOG(INFO) << "Total physical RAM in bytes: " << total; + return total; +} + +Status CacheServer::Builder::IpcResourceCleanup() { + Status rc; + SharedMemory::shm_key_t shm_key; + auto unix_socket = PortToUnixSocketPath(port_); + rc = PortToFtok(port_, &shm_key); + // We are expecting the unix path doesn't exist. + if (rc.IsError()) { + return Status::OK(); + } + // Attach to the shared memory which we expect don't exist + SharedMemory mem(shm_key); + rc = mem.Attach(); + if (rc.IsError()) { + return Status::OK(); + } + int32_t num_attached; + RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached)); + if (num_attached == 0) { + // Stale shared memory from last time. + // Remove both the memory and the socket path + RETURN_IF_NOT_OK(mem.Destroy()); + Path p(unix_socket); + (void)p.Remove(); + } else { + // Server is already up. + std::string errMsg = "Cache server is already up and running"; + // We return a duplicate error. The main() will intercept + // and output a proper message + return Status(StatusCode::kDuplicateKey, errMsg); + } + return Status::OK(); +} + Status CacheServer::Builder::SanityCheck() { if (shared_memory_sz_in_gb_ <= 0) { RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive"); @@ -673,6 +884,12 @@ Status CacheServer::Builder::SanityCheck() { RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString()); } } + if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) { + RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1"); + } + + // Check if the shared memory. 
+ RETURN_IF_NOT_OK(IpcResourceCleanup()); return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h index 1b6a16f1c77..631a4abb088 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h @@ -17,6 +17,8 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ +#include +#include #include #include #include @@ -47,15 +49,16 @@ class CacheServer : public Service { using cache_index = std::map>; class Builder { public: - Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {} + Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {} ~Builder() = default; /// \brief Getter functions - const std::string &getTop() const { return top_; } - int32_t getNumWorkers() const { return num_workers_; } - int32_t getPort() const { return port_; } - int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; } + const std::string &GetTop() const { return top_; } + int32_t GetNumWorkers() const { return num_workers_; } + int32_t GetPort() const { return port_; } + int32_t GetSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; } + float GetMemoryCapRatio() const { return memory_cap_ratio_; } Builder &SetRootDirectory(std::string root) { top_ = std::move(root); @@ -73,15 +76,20 @@ class CacheServer : public Service { shared_memory_sz_in_gb_ = sz; return *this; } + Builder &SetMemoryCapRatio(float ratio) { + memory_cap_ratio_ = ratio; + return *this; + } Status SanityCheck(); void Print(std::ostream &out) const { out << "Summary of the cache server configuration\n" - << "Spill directory: " << getTop() << "\n" - << "Number of parallel workers: " << getNumWorkers() << "\n" - << "Tcp/ip port: " << getPort() << "\n" - << "Shared memory size (in GB): " << getSharedMemorySzInGb(); + << "Spill directory: " << GetTop() << "\n" + << "Number of parallel workers: " << GetNumWorkers() << "\n" + << "Tcp/ip port: " << GetPort() << "\n" + << "Shared memory size (in GB): " << GetSharedMemorySzInGb() << "\n" + << "Memory cap ratio: " << GetMemoryCapRatio(); } friend std::ostream &operator<<(std::ostream &out, const Builder &bld) { @@ -93,7 +101,8 @@ class CacheServer : public Service { RETURN_IF_NOT_OK(SanityCheck()); // We need to bring up the Task Manager by bringing up the Services singleton. RETURN_IF_NOT_OK(Services::CreateInstance()); - RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_)); + RETURN_IF_NOT_OK( + CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_, memory_cap_ratio_)); return Status::OK(); } @@ -102,20 +111,27 @@ class CacheServer : public Service { int32_t num_workers_; int32_t port_; int32_t shared_memory_sz_in_gb_; + float memory_cap_ratio_; + + /// \brief Sanity checks on the shared memory. 
+ /// \return Status object + Status IpcResourceCleanup(); }; + CacheServer(const CacheServer &) = delete; CacheServer &operator=(const CacheServer &) = delete; CacheServer(CacheServer &&) = delete; CacheServer &operator=(CacheServer &) = delete; Status DoServiceStart() override; Status DoServiceStop() override; - ~CacheServer() { (void)ServiceStop(); } + ~CacheServer() override { (void)ServiceStop(); } static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port, - int32_t shared_memory_sz) { + int32_t shared_memory_sz, float memory_cap_ratio) { std::call_once(init_instance_flag_, [&]() -> Status { - auto &svcManager = Services::GetInstance(); - RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz)); + auto &SvcManager = Services::GetInstance(); + RETURN_IF_NOT_OK( + SvcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz, memory_cap_ratio)); return Status::OK(); }); return Status::OK(); @@ -133,7 +149,7 @@ class CacheServer : public Service { } /// \\brief Kick off server threads. Never return unless error out. - Status Run(); + Status Run(SharedMessage::queue_id_t msg_qid); /// \brief Get a free tag /// \param q[in] pointer to a pointer to a CacheServerRequest @@ -145,13 +161,35 @@ class CacheServer : public Service { /// \return Status object static Status ReturnRequestTag(CacheServerRequest *p); + /// \brief This returns the size (in bytes) of the physical RAM on the machine. + /// \return the size (in bytes) of the physical RAM on the machine. + static int64_t GetTotalSystemMemory(); + + /// \brief Internally this is how much we will try to use without exceeding the limit + /// \return Internal cap maximum + int64_t GetAvailableSystemMemory() { return memory_cap_; } + + /// \brief Find out the current memory usage + int64_t GetMemoryUsage() { return cur_mem_usage_; } + + /// \brief This updates our current memory usage. + enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 }; + void UpdateMemoryUsage(int64_t sz, MemUsageOp op) { + if (op == MemUsageOp::kAllocate) { + cur_mem_usage_ += sz; + } else { + cur_mem_usage_ -= sz; + } + } + private: static std::once_flag init_instance_flag_; static CacheServer *instance_; mutable RWLock rwLock_; + mutable RWLock sessions_lock_; std::string top_; cache_index all_caches_; - std::set history_sessions_; + std::map> active_sessions_; std::shared_ptr> cache_q_; std::shared_ptr> free_list_; std::vector>>> tag_; @@ -162,11 +200,15 @@ class CacheServer : public Service { int32_t port_; int32_t shared_memory_sz_in_gb_; std::atomic global_shutdown_; + float memory_cap_ratio_; + int64_t memory_cap_; + std::atomic cur_mem_usage_; /// \brief Constructor /// \param spill_path Top directory for spilling buffers to. /// \param num_workers Number of threads for handling requests. - explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb); + explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb, + float memory_cap_ratio); /// \brief Locate a cache service from connection id. /// \return Pointer to cache service. 
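The new memory-accounting surface (GetTotalSystemMemory, GetAvailableSystemMemory, GetMemoryUsage, UpdateMemoryUsage) boils down to one atomic byte counter compared against a cap of physical RAM times memory_cap_ratio. Here is a rough sketch of that bookkeeping, assuming Linux sysconf() as in GetTotalSystemMemory; the class and method names are illustrative, not the patch's.

    #include <unistd.h>
    #include <atomic>
    #include <cstdint>
    #include <iostream>

    class MemoryBudget {
     public:
      explicit MemoryBudget(float cap_ratio) {
        const int64_t pages = sysconf(_SC_PHYS_PAGES);
        const int64_t page_size = sysconf(_SC_PAGE_SIZE);
        cap_ = static_cast<int64_t>(static_cast<double>(pages * page_size) * cap_ratio);
      }
      int64_t Cap() const { return cap_; }               // roughly what GetAvailableSystemMemory reports
      int64_t Usage() const { return usage_.load(); }    // roughly what GetMemoryUsage reports
      void Allocate(int64_t sz) { usage_.fetch_add(sz); }
      void Free(int64_t sz) { usage_.fetch_sub(sz); }

     private:
      int64_t cap_ = 0;
      std::atomic<int64_t> usage_{0};
    };

    int main() {
      MemoryBudget budget(0.8f);   // same default ratio as the Builder above
      budget.Allocate(64 << 20);   // a cache service claims 64 MB
      std::cout << budget.Usage() << " of " << budget.Cap() << " bytes in use\n";
      budget.Free(64 << 20);
      return 0;
    }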
Null if not found @@ -179,11 +221,9 @@ class CacheServer : public Service { Status CreateService(CacheRequest *rq, CacheReply *reply); /// \brief Destroy a cache service - /// \param cs /// \param rq - /// \return - Status DestroyCache(CacheService *cs, CacheRequest *rq); - Status PurgeCache(CacheService *cs); + /// \return Status object + Status DestroyCache(CacheRequest *rq); /// \brief Entry point for all internal server threads. Status ServerRequest(int32_t worker_id); @@ -207,7 +247,7 @@ class CacheServer : public Service { /// \brief Generate a session ID for the client /// \return Session ID - session_id_type GenerateSessionID() const; + session_id_type GenerateSessionID(); /// \brief Handle kAllocateSharedBlock request /// \param rq CacheRequest @@ -220,20 +260,55 @@ class CacheServer : public Service { /// \return Status object Status FreeSharedMemory(CacheRequest *rq); - /// \brief Handle kFastCacheRow request + /// \brief Handle CacheRow request + /// \note There are two different implementation depends if shared memory is used for transportation. /// \return Status object - Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply); + Status FastCacheRow(CacheRequest *rq, CacheReply *reply); + Status CacheRow(CacheRequest *rq, CacheReply *reply); /// \brief Internal function to do row batch fetch - /// \param cs CacheService /// \param rq Request /// \param reply Reply - /// \return - Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply); + /// \return Status object + Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); + + /// \brief Internal function to get statistics + /// \param rq + /// \param reply + /// \return Status object + Status GetStat(CacheRequest *rq, CacheReply *reply); + + /// \brief Cache a schema request + /// \param rq + /// \return Status object + Status CacheSchema(CacheRequest *rq); + + /// \brief Fetch a schema request + /// \param rq + /// \param reply + /// \return Status object + Status FetchSchema(CacheRequest *rq, CacheReply *reply); + + /// \brief Mark Build phase done (for non-mappable case) + /// \param rq + /// \return Status object + Status BuildPhaseDone(CacheRequest *rq); /// \brief A proper shutdown of the server /// \return Status object Status GlobalShutdown(); + + /// \brief Find keys that will be cache miss + /// \return Status object + Status GetCacheMissKeys(CacheRequest *rq, CacheReply *reply); + + /// \brief Toggle write mode for a service + Status ToggleWriteMode(CacheRequest *rq); + + /// \brief List the sessions and their caches + /// \param reply + /// \return Status object + Status ListSessions(CacheReply *reply); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc index ee6e835dc67..727f9e736f6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/engine/cache/cache_server.h" #include "minddata/dataset/util/slice.h" namespace mindspore { @@ -22,42 +23,62 @@ CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool genera : root_(root), cache_mem_sz_(mem_sz), cp_(nullptr), - map_(nullptr), next_id_(0), generate_id_(generate_id), - schema_key_(-1), - st_(generate_id ? 
State::kBuildPhase : State::kNone) {} + st_(generate_id ? State::kBuildPhase : State::kNone), + cur_mem_usage_(0), + cur_disk_usage_(0) {} + CacheService::~CacheService() { (void)ServiceStop(); } + bool CacheService::UseArena() { // If fixed size, use Arena instead of the pool from global context. return (cache_mem_sz_ > 0); } + Status CacheService::DoServiceStart() { std::shared_ptr mp_; + CacheServer &cs = CacheServer::GetInstance(); if (UseArena()) { + auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L; + if (cache_mem_sz_ > avail_mem) { + // Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway. + MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem + << " MB"; + } // Create a fixed size arena based on the parameter. std::shared_ptr arena; RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); mp_ = std::move(arena); + // update the global usage only. + cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate); } else { // Unlimited size. Simply use a system pool. Another choice is CircularPool. mp_ = std::make_shared(); } // Put together a CachePool for backing up the Tensor - cp_ = std::make_shared(CachePool::value_allocator(mp_), root_); + cp_ = std::make_shared(CachePool::value_allocator(mp_), UseArena(), root_); RETURN_IF_NOT_OK(cp_->ServiceStart()); - // Set up the B+ tree as well. But use the system pool instead. - map_ = std::make_shared(); // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. cookie_ = cp_->MyName(); return Status::OK(); } + Status CacheService::DoServiceStop() { if (cp_ != nullptr) { RETURN_IF_NOT_OK(cp_->ServiceStop()); } + CacheServer &cs = CacheServer::GetInstance(); + if (UseArena()) { + cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree); + } else { + MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage() + << " bytes."; + cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree); + } return Status::OK(); } + Status CacheService::CacheRow(const std::vector &buf, row_id_type *row_id_generated) { SharedLock rw(&rw_lock_); RETURN_UNEXPECTED_IF_NULL(row_id_generated); @@ -66,6 +87,11 @@ Status CacheService::CacheRow(const std::vector &buf, row_id_type // allow other to cache more rows. RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); } + if (st_ == State::kNoLocking) { + // We ignore write this request once we turn off locking on the B+ tree. So we will just + // return out of memory from now on. + return Status(StatusCode::kOutOfMemory); + } try { // The first buffer is a flatbuffer which describes the rest of the buffers follow auto fb = buf.front(); @@ -86,6 +112,7 @@ Status CacheService::CacheRow(const std::vector &buf, row_id_type *row_id_generated = msg->row_id(); } auto size_of_this = msg->size_of_this(); + size_t total_sz = size_of_this; auto column_hdr = msg->column(); // Number of tensor buffer should match the number of columns plus one. if (buf.size() != column_hdr->size() + 1) { @@ -99,16 +126,28 @@ Status CacheService::CacheRow(const std::vector &buf, row_id_type all_data.emplace_back(fb, size_of_this); for (auto i = 0; i < column_hdr->size(); ++i) { all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); + total_sz += msg->data_sz()->Get(i); } - // Now we cache the flat buffer. 
- CachePool::key_type key; - RETURN_IF_NOT_OK(cp_->Insert(all_data, &key)); - Status rc = map_->DoInsert(*row_id_generated, key); + // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. + // Otherwise, we check how much (globally) how much we use and may simply spill to disk + // directly. + CacheServer &cs = CacheServer::GetInstance(); + bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); + Status rc = cp_->Insert(*row_id_generated, all_data, write_to_disk_directly); if (rc == Status(StatusCode::kDuplicateKey)) { MS_LOG(DEBUG) << "Ignoring duplicate key."; } else { RETURN_IF_NOT_OK(rc); } + // All good, then update the memory usage local and global (if not using arena) + if (write_to_disk_directly) { + cur_disk_usage_ += total_sz; + } else { + cur_mem_usage_ += total_sz; + if (!UseArena()) { + cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); + } + } return Status::OK(); } catch (const std::exception &e) { RETURN_STATUS_UNEXPECTED(e.what()); @@ -123,6 +162,11 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ // allow other to cache more rows. RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); } + if (st_ == State::kNoLocking) { + // We ignore write this request once we turn off locking on the B+ tree. So we will just + // return out of memory from now on. + return Status(StatusCode::kOutOfMemory); + } try { // If we don't need to generate id, we need to find it from the buffer. if (generate_id_) { @@ -139,20 +183,33 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ } *row_id_generated = msg->row_id(); } - // Now we cache the flat buffer. - CachePool::key_type key; - RETURN_IF_NOT_OK(cp_->Insert({src}, &key)); - Status rc = map_->DoInsert(*row_id_generated, key); + // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. + // Otherwise, we check how much (globally) how much we use and may simply spill to disk + // directly. + auto total_sz = src.GetSize(); + CacheServer &cs = CacheServer::GetInstance(); + bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); + Status rc = cp_->Insert(*row_id_generated, {src}, write_to_disk_directly); if (rc == Status(StatusCode::kDuplicateKey)) { MS_LOG(DEBUG) << "Ignoring duplicate key."; } else { RETURN_IF_NOT_OK(rc); } + // All good, then update the memory usage local and global (if not using arena) + if (write_to_disk_directly) { + cur_disk_usage_ += total_sz; + } else { + cur_mem_usage_ += total_sz; + if (!UseArena()) { + cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); + } + } return Status::OK(); } catch (const std::exception &e) { RETURN_STATUS_UNEXPECTED(e.what()); } } + std::ostream &operator<<(std::ostream &out, const CacheService &cs) { // Then show any custom derived-internal stuff out << "\nCache memory size: " << cs.cache_mem_sz_; @@ -164,34 +221,29 @@ std::ostream &operator<<(std::ostream &out, const CacheService &cs) { } return out; } + Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } -Status CacheService::Purge() { - // First we must lock exclusively. No one else can cache/restore anything. 
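Both CacheRow and FastCacheRow above now make the same placement decision: without a fixed-size Arena, a row stays in memory only while the server-wide usage plus the row would remain under the cap, otherwise it is written to disk directly, and the per-service and global counters are updated to match. The following is a small self-contained sketch of that decision with illustrative names, not the patch's code.

    #include <atomic>
    #include <cstdint>
    #include <iostream>

    // Returns true when the row should bypass memory and go straight to disk.
    bool PlaceRow(bool use_arena, int64_t row_sz, std::atomic<int64_t> *global_mem,
                  int64_t global_cap, int64_t *svc_mem, int64_t *svc_disk) {
      const bool to_disk = use_arena ? false : (row_sz + global_mem->load() > global_cap);
      if (to_disk) {
        *svc_disk += row_sz;                              // spilled directly to the backing file
      } else {
        *svc_mem += row_sz;                               // kept in memory
        if (!use_arena) global_mem->fetch_add(row_sz);    // arena usage was already charged at start-up
      }
      return to_disk;
    }

    int main() {
      std::atomic<int64_t> global_mem{900};
      int64_t svc_mem = 0, svc_disk = 0;
      bool spilled = PlaceRow(false, 200, &global_mem, /*global_cap=*/1000, &svc_mem, &svc_disk);
      std::cout << (spilled ? "spilled to disk\n" : "cached in memory\n");
      return 0;
    }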
- UniqueLock rw(&rw_lock_); - RETURN_IF_NOT_OK(cp_->ServiceStop()); - auto new_map = std::make_shared(); - map_.reset(); - map_ = std::move(new_map); - next_id_ = 0; - RETURN_IF_NOT_OK(cp_->ServiceStart()); + +Status CacheService::FindKeysMiss(std::vector *out) { + RETURN_UNEXPECTED_IF_NULL(out); + std::unique_lock lock(get_key_miss_mux_); + if (key_miss_results_ == nullptr) { + // Just do it once. + key_miss_results_ = std::make_shared>(); + auto stat = cp_->GetStat(true); + key_miss_results_->push_back(stat.min_key); + key_miss_results_->push_back(stat.max_key); + key_miss_results_->insert(key_miss_results_->end(), stat.gap.begin(), stat.gap.end()); + } + out->insert(out->end(), key_miss_results_->begin(), key_miss_results_->end()); return Status::OK(); } + Status CacheService::GetStat(CacheService::ServiceStat *out) { SharedLock rw(&rw_lock_); RETURN_UNEXPECTED_IF_NULL(out); - if (st_ == State::kNone || st_ == State::kFetchPhase) { - out->stat_ = cp_->GetStat(); - out->state_ = static_cast(st_); - auto it = map_->begin(); - if (it != map_->end()) { - out->min_ = it.key(); - auto end_it = map_->end(); - --end_it; - out->max_ = end_it.key(); - } - } else { - out->state_ = static_cast(st_); - } + out->stat_ = cp_->GetStat(); + out->state_ = static_cast(st_); return Status::OK(); } @@ -204,19 +256,12 @@ Status CacheService::PreBatchFetch(const std::vector &v, std::vecto *mem_sz = (num_elements + 1) * sizeof(int64_t); (*out).reserve(num_elements); for (auto row_id : v) { - auto r = map_->Search(row_id); - if (r.second) { - auto &it = r.first; - CachePool::key_type key = it.value(); - auto sz = cp_->GetSize(key); - if (sz == 0) { - std::string errMsg = "Key not found: "; - errMsg += std::to_string(key); - RETURN_STATUS_UNEXPECTED(errMsg); - } - (*out).emplace_back(key, sz); + auto sz = cp_->GetSize(row_id); + if (sz > 0) { + (*out).emplace_back(row_id, sz); (*mem_sz) += sz; } else { + // key not found (*out).emplace_back(-1, 0); } } @@ -252,27 +297,19 @@ Status CacheService::BatchFetch(const std::vector &v, const std::ve } return Status::OK(); } + Status CacheService::CacheSchema(const void *buf, int64_t len) { - SharedLock rw(&rw_lock_); - if (st_ == State::kFetchPhase) { - // For this kind of cache service, once we are done with the build phase into fetch phase, we can't - // allow other to cache more rows. - RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); - } - // This is a special request and we need to remember where we store it. + UniqueLock rw(&rw_lock_); // In case we are calling the same function from multiple threads, only // the first one is considered. Rest is ignored. - CachePool::key_type cur_key = schema_key_; - CachePool::key_type key; - if (cur_key < 0) { - RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key)); - auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key); - MS_LOG(DEBUG) << "Caching Schema. Result = " << result; + if (schema_.empty()) { + schema_.assign(static_cast(buf), len); } else { MS_LOG(DEBUG) << "Caching Schema already done"; } return Status::OK(); } + Status CacheService::FetchSchema(std::string *out) const { SharedLock rw(&rw_lock_); if (st_ == State::kBuildPhase) { @@ -283,32 +320,44 @@ Status CacheService::FetchSchema(std::string *out) const { // We are going to use std::string to allocate and hold the result which will be eventually // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose // to minimize memory copy. 
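The FindKeysMiss implementation above reports guaranteed misses as the cached key range (min, max) followed by the row ids missing inside that range, and it memoizes the result under a mutex because computing it is expensive. The sketch below is only a conceptual illustration of how such a gap list could be derived from a set of cached row ids; it is not the CachePool::GetStat implementation.

    #include <cstdint>
    #include <iostream>
    #include <set>
    #include <vector>

    std::vector<int64_t> FindGaps(const std::set<int64_t> &cached) {
      std::vector<int64_t> out;
      if (cached.empty()) return out;
      out.push_back(*cached.begin());    // min cached key
      out.push_back(*cached.rbegin());   // max cached key
      // Every id between min and max that is not cached is a definite miss.
      for (int64_t id = *cached.begin(); id <= *cached.rbegin(); ++id) {
        if (cached.find(id) == cached.end()) out.push_back(id);
      }
      return out;
    }

    int main() {
      std::set<int64_t> cached = {0, 1, 2, 5, 6, 9};
      for (auto v : FindGaps(cached)) std::cout << v << " ";   // prints: 0 9 3 4 7 8
      std::cout << "\n";
      return 0;
    }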
- std::string mem; - if (schema_key_ >= 0) { - auto len = cp_->GetSize(schema_key_); - try { - mem.resize(len); - CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error"); - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory); - } - auto slice = WritableSlice(mem.data(), len); - RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); + std::string mem(schema_); + if (!mem.empty()) { *out = std::move(mem); } else { return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); } return Status::OK(); } + Status CacheService::BuildPhaseDone() { if (HasBuildPhase()) { // Exclusive lock to switch phase UniqueLock rw(&rw_lock_); st_ = State::kFetchPhase; + cp_->SetLocking(false); return Status::OK(); } else { RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); } } + +Status CacheService::ToggleWriteMode(bool on_off) { + UniqueLock rw(&rw_lock_); + if (HasBuildPhase()) { + RETURN_STATUS_UNEXPECTED("Not applicable to non-mappable dataset"); + } else { + // If we stop accepting write request, we turn off locking for the + // underlying B+ tree. All future write request we will return kOutOfMemory. + if (st_ == State::kNone && !on_off) { + st_ = State::kNoLocking; + cp_->SetLocking(on_off); + MS_LOG(WARNING) << "Locking mode is switched off."; + } else if (st_ == State::kNoLocking && on_off) { + st_ = State::kNone; + cp_->SetLocking(on_off); + } + } + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h index 824d24975f7..ab8a50775b3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -44,9 +45,8 @@ using key_size_pair = std::pair; class CacheService : public Service { public: friend class CacheServer; - using row_map = BPlusTree; - enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase }; + enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking }; /// \brief Constructor /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited @@ -97,11 +97,9 @@ class CacheService : public Service { class ServiceStat { public: using state_type = std::underlying_type::type; - ServiceStat() : min_(0), max_(0), state_(0) {} + ServiceStat() : state_(0) {} ~ServiceStat() = default; CachePool::CacheStat stat_{}; - row_id_type min_; - row_id_type max_; state_type state_; }; /// \brief Statistics for the current service @@ -117,9 +115,9 @@ class CacheService : public Service { /// \param out A contiguous memory that contains the serialized form of schema. /// \return Status object Status FetchSchema(std::string *out) const; - /// \brief Purge the content of a cache + /// \brief Return a set of keys that are definitely cache miss /// \return Status object - Status Purge(); + Status FindKeysMiss(std::vector *out); /// \brief Overload the << operator to print a cache service /// \param out std::ostream /// \param cs A cache service @@ -136,19 +134,33 @@ class CacheService : public Service { /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. 
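ToggleWriteMode above only applies to caches without a build phase: turning writes off moves the service into the new kNoLocking state and disables locking on the underlying pool, after which write requests are answered with kOutOfMemory, while turning writes back on restores kNone. A compact sketch of that two-state transition with simplified names follows.

    #include <iostream>

    enum class State { kNone, kBuildPhase, kFetchPhase, kNoLocking };

    // Only the kNone <-> kNoLocking transitions are meaningful for this request.
    State ToggleWriteMode(State st, bool on_off) {
      if (st == State::kNone && !on_off) return State::kNoLocking;   // stop accepting writes
      if (st == State::kNoLocking && on_off) return State::kNone;    // accept writes again
      return st;                                                     // anything else is left unchanged
    }

    int main() {
      State st = ToggleWriteMode(State::kNone, false);
      std::cout << "writes " << (st == State::kNoLocking ? "off" : "on") << "\n";
      return 0;
    }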
/// \return Status object Status BuildPhaseDone(); + /// \brief Find out the current memory usage + int64_t GetMemoryUsage() { return cur_mem_usage_; } + /// \brief Find out the current disk usage + int64_t GetDiskUsage() { return cur_disk_usage_; } + /// \brief For kToggleWriteMode request + Status ToggleWriteMode(bool on_off); private: mutable RWLock rw_lock_; std::string root_; uint64_t cache_mem_sz_; std::shared_ptr cp_; - std::shared_ptr map_; std::atomic next_id_; bool generate_id_; - std::atomic schema_key_; std::string cookie_; State st_; - + std::string schema_; + // If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it. + // Otherwise we track how much is in memory and how much is on disk (if root_ is not empty). + // We use them to control when we should stop caching in memory in the case when there is no + // Arena. + std::atomic cur_mem_usage_; + std::atomic cur_disk_usage_; + // We also cache the result from calling FindKeysMiss because it is expensive. Besides user make + // this request after we hit memory full or disk full. So the result is unlikely to change. + std::mutex get_key_miss_mux_; + std::shared_ptr> key_miss_results_; /// \brief Private function to generate a row id /// \return Row id assigned. row_id_type GetNextRowId() { return next_id_.fetch_add(1); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs index 5d24995ed1c..5986f379f77 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs @@ -92,3 +92,13 @@ table CreateCacheReplyMsg { connection_id:int64; cookie:string; } + +table ListSessionMsg { + session_id:uint32; + connection_id:uint64; + stats:ServiceStatMsg; +} + +table ListSessionsMsg { + sessions:[ListSessionMsg]; +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc index 554eb2b19b3..ceca08558a6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc @@ -53,22 +53,35 @@ CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t row num_cache_miss_(0), cache_client_(std::move(cache_client)), rows_per_buffer_(rows_per_buf), - // We can cause deadlock if this internal Connector size is too small. - keys_miss_(num_workers_, 1, connector_capacity_), - prefetch_size_(cache_client_->getPrefetchSize()) { + prefetch_size_(rows_per_buffer_), + num_prefetchers_(num_workers_) { + // Adjust the prefetch size based on the number of workers. + auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_; + if (prefetch_size_ < prefetch_sz_per_thread) { + prefetch_size_ = prefetch_sz_per_thread; + MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_; + } io_block_queues_.Init(num_workers, op_connector_size); - prefetch_queues_.Init(num_workers, op_connector_size); - sampler_queue_ = std::make_unique>>(op_connector_size); + prefetch_queues_.Init(num_prefetchers_, op_connector_size); + // We can cause deadlock if this internal Connector size is too small. 
+ keys_miss_ = std::make_unique>>(num_prefetchers_, 1, connector_capacity_); } // Common function to fetch samples from the sampler and send them using the io_block_queues to // the parallel workers Status CacheBase::FetchSamplesToWorkers() { int64_t buf_cnt = 0; int64_t wait_cnt = 0; + int64_t prefetch_cnt = 0; // Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_ // is too small (1 by default) and won't help performance. - RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this))); - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); + auto send_to_que = [](QueueList> &qList, int32_t worker_id, + std::vector &keys) -> Status { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(qList[worker_id]->Add(std::move(blk))); + return Status::OK(); + }; // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them // to the WorkerEntry. do { @@ -82,33 +95,54 @@ Status CacheBase::FetchSamplesToWorkers() { ++wait_cnt; std::vector keys; keys.reserve(rows_per_buffer_); + std::vector prefetch_keys; + prefetch_keys.reserve(prefetch_size_); std::unique_ptr sampler_buffer; RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); while (!sampler_buffer->eoe()) { TensorRow sample_row; RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); std::shared_ptr sample_ids = sample_row[0]; - // Send the sampler tensor to other thread for prefetching. We are using shared pointer so it - // won't go out scope until it is really not in use. - RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids)); for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { - keys.push_back(*itr); ++row_cnt_; - if (row_cnt_ % rows_per_buffer_ == 0) { - auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); - keys.clear(); + prefetch_keys.push_back(*itr); + // Batch enough rows for performance reason. + if (row_cnt_ % prefetch_size_ == 0) { + RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys)); + // Now we tell the WorkerEntry to wait for them to come back. If prefetch_size_ is a multiple + // of rows_per_buffer_, the keys vector will always be empty. But it can be partially filled. + // The only requirement we set up is rows_per_buffer_ is less than or equal to prefetch_size_. + for (auto row_id : prefetch_keys) { + keys.push_back(row_id); + if (keys.size() == rows_per_buffer_) { + RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); + keys.clear(); + } + } + prefetch_keys.clear(); } } RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } + // Deal with any partial keys left. 
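FetchSamplesToWorkers now batches the sampled row ids twice: ids accumulate into prefetch_size_-sized batches that go round-robin to the prefetcher queues, and the same ids are then re-chunked into rows_per_buffer_-sized io blocks for the worker queues, which is why rows_per_buffer_ must not exceed prefetch_size_. Below is a stripped-down sketch of that double batching over plain vectors, with the queue plumbing and eoe handling omitted; the names are illustrative.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    using Batch = std::vector<int64_t>;

    // Splits ids into prefetch batches and, within each, into io blocks.
    void Distribute(const std::vector<int64_t> &ids, size_t prefetch_size, size_t rows_per_buffer,
                    std::vector<Batch> *prefetch_batches, std::vector<Batch> *io_blocks) {
      Batch prefetch, keys;
      for (int64_t id : ids) {
        prefetch.push_back(id);
        if (prefetch.size() == prefetch_size) {
          prefetch_batches->push_back(prefetch);
          for (int64_t k : prefetch) {           // re-chunk the same ids into io blocks
            keys.push_back(k);
            if (keys.size() == rows_per_buffer) {
              io_blocks->push_back(keys);
              keys.clear();
            }
          }
          prefetch.clear();
        }
      }
      if (!prefetch.empty()) {                   // deal with any partial batch left
        prefetch_batches->push_back(prefetch);
        for (int64_t k : prefetch) {
          keys.push_back(k);
          if (keys.size() == rows_per_buffer) {
            io_blocks->push_back(keys);
            keys.clear();
          }
        }
      }
      if (!keys.empty()) io_blocks->push_back(keys);
    }

    int main() {
      std::vector<int64_t> ids = {0, 1, 2, 3, 4, 5, 6};
      std::vector<Batch> prefetch_batches, io_blocks;
      Distribute(ids, /*prefetch_size=*/4, /*rows_per_buffer=*/2, &prefetch_batches, &io_blocks);
      std::cout << prefetch_batches.size() << " prefetch batches, " << io_blocks.size() << " io blocks\n";
      return 0;
    }

With seven ids, a prefetch size of 4 and two rows per buffer, this yields two prefetch batches and four io blocks, the last one partially filled.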
+ if (!prefetch_keys.empty()) { + RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys)); + for (auto row_id : prefetch_keys) { + keys.push_back(row_id); + if (keys.size() == rows_per_buffer_) { + RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); + keys.clear(); + } + } + } if (!keys.empty()) { - auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); + RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); } // send the eoe RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add( + std::make_unique(IOBlock::kDeIoBlockFlagEoe))); // If repeat but the not last repeat, wait for reset. if (!IsLastIteration()) { MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; @@ -123,8 +157,6 @@ Status CacheBase::FetchSamplesToWorkers() { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); // Shutdown threads - std::shared_ptr empty; - RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty))); for (int32_t i = 0; i < num_workers_; i++) { RETURN_IF_NOT_OK( io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); @@ -145,13 +177,6 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { if (blk->eof()) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); } else if (blk->eoe()) { - if (AllowCacheMiss()) { - // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from - // a sampler, send a eoe to physical leaf op as well. - std::vector eoe; - eoe.push_back(eoe_row_id); - RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe)); - } RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); } else { std::vector keys; @@ -162,22 +187,21 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { } std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); std::unique_ptr que = std::make_unique(); - std::vector cache_miss; - cache_miss.reserve(keys.size()); for (auto row_id : keys) { TensorRow row; // Block until the row shows up in the pool. - RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row)); + RETURN_IF_NOT_OK(GetPrefetchRow(row_id, &row)); if (row.empty()) { - cache_miss.push_back(row_id); + if (AllowCacheMiss()) { + ++num_cache_miss_; + } else { + std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; + RETURN_STATUS_UNEXPECTED(errMsg); + } } que->push_back(std::move(row)); } db->set_tensor_table(std::move(que)); - if (AllowCacheMiss()) { - // Because of the way connector works, we push unconditionally even cache_miss can be empty. 
- RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss)); - } RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); buffer_id += num_workers_; } @@ -189,7 +213,6 @@ Status CacheBase::RegisterResources() { RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); return Status::OK(); } @@ -208,73 +231,97 @@ Status CacheBase::UpdateColumnMapFromCache() { return rc; } -Status CacheBase::Dispatcher() { - TaskManager::FindMe()->Post(); - int64_t buf_cnt = 0; - int64_t num_row = 0; - std::vector keys; - keys.reserve(prefetch_size_); - do { - keys.clear(); - std::shared_ptr sample_ids; - RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids)); - if (sample_ids == nullptr) { - // A null shared pointer signal times to quit. - // Also signal all prefetchers to quit. - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - prefetch_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - break; - } - // Now we distribute the sampler ids to each prefetcher according to the prefetch size. - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { - keys.push_back(*itr); - ++num_row; - if (num_row % prefetch_size_ == 0) { - auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); - RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); - keys.clear(); - } - } - // Send the remaining sample id - if (!keys.empty()) { - auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); - RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); - } - } while (true); +Status CacheBase::GetPrefetchRow(row_id_type row_id, TensorRow *out) { + RETURN_UNEXPECTED_IF_NULL(out); + CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0, "Expect positive row id"); + RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, out)); return Status::OK(); } +Status CacheBase::PrefetchRows(const std::vector &keys, std::vector *cache_miss) { + RETURN_UNEXPECTED_IF_NULL(cache_miss); + std::vector prefetch_keys; + prefetch_keys.reserve(keys.size()); + + // Filter out all those keys that unlikely we will find at the server + for (auto row_id : keys) { + if (cache_client_->KeyIsCacheMiss(row_id)) { + // Just put an empty row in the cache. + TensorRow row; + row.setId(row_id); + RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); + cache_miss->push_back(row_id); + } else { + prefetch_keys.push_back(row_id); + } + } + // Early exit if nothing to fetch + if (prefetch_keys.empty()) { + return Status::OK(); + } + // Get the rows from the server + TensorTable ttbl; + Status rc = cache_client_->GetRows(prefetch_keys, &ttbl); + if (rc.IsOk()) { + auto row_it = ttbl.begin(); + for (auto row_id : prefetch_keys) { + auto &row = *row_it; + if (row.empty()) { + cache_miss->push_back(row_id); + } + // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row + RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); + ++row_it; + } + } else { + // In case any thread is waiting for the rows to come back and blocked on a semaphore, + // we will put an empty row in the local cache. 
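PrefetchRows parks a row in the prefetch pool for every requested id, including an empty placeholder for known or failed misses, so that a WorkerEntry blocked inside GetPrefetchRow always wakes up instead of hanging. The patch relies on the QueueMap utility for this; the following is only a minimal sketch of such a blocking pool built from a map, a mutex and a condition variable.

    #include <condition_variable>
    #include <cstdint>
    #include <map>
    #include <mutex>
    #include <string>
    #include <thread>
    #include <vector>

    template <typename Row>
    class BlockingRowPool {
     public:
      void Add(int64_t id, Row row) {
        {
          std::lock_guard<std::mutex> lk(mu_);
          rows_[id].push_back(std::move(row));
        }
        cv_.notify_all();
      }
      // Blocks until a row (possibly an empty placeholder) arrives for this id.
      Row PopFront(int64_t id) {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [&] { return !rows_[id].empty(); });
        Row r = std::move(rows_[id].front());
        rows_[id].erase(rows_[id].begin());
        return r;
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      std::map<int64_t, std::vector<Row>> rows_;
    };

    int main() {
      BlockingRowPool<std::string> pool;
      std::thread producer([&] { pool.Add(7, ""); });   // an empty string stands in for a missed row
      std::string row = pool.PopFront(7);               // blocks until id 7 shows up
      producer.join();
      return row.empty() ? 0 : 1;
    }

The important property is that the producer adds something for every id the consumer asked for, even on failure, so PopFront never blocks forever.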
+ for (auto row_id : prefetch_keys) { + TensorRow row; + row.setId(row_id); + RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); + cache_miss->push_back(row_id); + } + } + return rc; +} + Status CacheBase::Prefetcher(int32_t worker_id) { TaskManager::FindMe()->Post(); std::vector prefetch_keys; prefetch_keys.reserve(prefetch_size_); + std::vector cache_miss; + cache_miss.reserve(prefetch_size_); do { prefetch_keys.clear(); + cache_miss.clear(); std::unique_ptr blk; RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk)); - RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys)); - if (prefetch_keys.empty()) { - // Empty keys mean time to quit. - break; - } - TensorTable ttbl; - RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); - auto row_it = ttbl.begin(); - for (auto row_id : prefetch_keys) { - auto &row = *row_it; - if (row.empty()) { - if (AllowCacheMiss()) { - ++num_cache_miss_; - } else { - std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; - RETURN_STATUS_UNEXPECTED(errMsg); + CHECK_FAIL_RETURN_UNEXPECTED(!blk->eof(), "Expect eoe or a regular io block"); + if (!blk->eoe()) { + RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys)); + Status rc; + const int32_t max_retries = 5; + int32_t retry_count = 0; + do { + rc = PrefetchRows(prefetch_keys, &cache_miss); + if (rc.IsNetWorkError() && retry_count < max_retries) { + // If we get some network error, we will attempt some retries + retry_count++; + } else if (rc.IsError()) { + return rc; } + } while (rc.IsNetWorkError()); + } else { + if (AllowCacheMiss()) { + // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from + // a sampler, send a eoe to physical leaf op as well. + cache_miss.push_back(eoe_row_id); } - // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row - RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); - ++row_it; + } + if (AllowCacheMiss()) { + // Because of the way connector works, we push unconditionally even cache_miss can be empty. + RETURN_IF_NOT_OK(keys_miss_->Push(worker_id, cache_miss)); } } while (true); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h index 2225d4f3350..4c0c8016293 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h @@ -22,6 +22,7 @@ #include #include #include +#include "minddata/dataset/engine/connector.h" #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" @@ -90,8 +91,7 @@ class CacheBase : public ParallelOp { std::shared_ptr cache_client_; WaitPost epoch_sync_; int32_t rows_per_buffer_; - Connector> keys_miss_; - QueueMap prefetch_; + std::unique_ptr>> keys_miss_; /// \brief Common function to register resources for interrupt /// \note Derived should override this function for extra resources to be registered @@ -111,13 +111,16 @@ class CacheBase : public ParallelOp { constexpr static int32_t connector_capacity_ = 1024; int32_t prefetch_size_; QueueList> io_block_queues_; + int32_t num_prefetchers_; QueueList> prefetch_queues_; - std::unique_ptr>> sampler_queue_; + QueueMap prefetch_; - Status Dispatcher(); /// \brief Prefetcher. It prefetch the rows from cache server /// \return Status object. 
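The Prefetcher above now tolerates transient failures: a batch that fails with a network error is retried a bounded number of times (max_retries is 5 in the patch) before the error is surfaced, while any other error aborts immediately. The same control flow is shown in isolation below; Result, DoRequest and the transient flag are placeholders, not the patch's Status API.

    #include <iostream>

    struct Result {
      bool ok = false;
      bool transient = false;   // e.g. a network error worth retrying
    };

    Result DoRequest(int attempt) {           // placeholder: fails twice, then succeeds
      return Result{attempt >= 2, attempt < 2};
    }

    Result FetchWithRetry(int max_retries) {
      Result rc;
      int attempt = 0;
      do {
        rc = DoRequest(attempt);
        if (rc.ok) break;
        if (!rc.transient || attempt >= max_retries) return rc;   // give up
        ++attempt;                                                // transient failure: try again
      } while (true);
      return rc;
    }

    int main() {
      Result rc = FetchWithRetry(/*max_retries=*/5);
      std::cout << (rc.ok ? "succeeded after retries\n" : "gave up\n");
      return 0;
    }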
Status Prefetcher(int32_t worker_id); + /// \brief Functions used by prefetcher and WorkerEntry + Status PrefetchRows(const std::vector &keys, std::vector *cache_miss); + Status GetPrefetchRow(row_id_type row_id, TensorRow *out); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc index 5fd882dce7b..0e38ad63333 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc @@ -87,10 +87,10 @@ Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } Status CacheLookupOp::GetNextSample(std::unique_ptr *out_buffer) { std::vector cache_miss; - RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); + RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); // Ignore the case we have no cache miss, we can't return empty samples. while (cache_miss.empty()) { - RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); + RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); } // Special code for eoe if (cache_miss.at(0) == eoe_row_id) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index 2de9f30b5aa..0de3767f45f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -25,6 +25,7 @@ #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/system_pool.h" #include "minddata/dataset/util/task_manager.h" namespace mindspore { @@ -48,7 +49,8 @@ CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t std::shared_ptr cache_client, const std::shared_ptr &sampler) : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), - cache_client_(std::move(cache_client)) {} + cache_client_(std::move(cache_client)), + cache_missing_rows_(true) {} Status CacheMergeOp::operator()() { // A queue of row id to let cleaner send cache miss rows to the cache server @@ -129,17 +131,19 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { std::string errMsg = "Expect positive row id: " + std::to_string(row_id); RETURN_STATUS_UNEXPECTED(errMsg); } - // Technically number of this row shows up in the cache miss stream is equal to the number - // of P() call. However the cleaner wants it too. So we need an extra copy. - TensorRowCacheRequest *rq; - RETURN_IF_NOT_OK(GetRq(row_id, &rq)); - if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { - // We will send the request async. But any error we most - // likely ignore and continue. - Status rc; - rc = rq->AsyncSendCacheRequest(cache_client_, row); - if (rc.IsOk()) { - RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); + if (cache_missing_rows_) { + // Technically number of this row shows up in the cache miss stream is equal to the number + // of P() call. However the cleaner wants it too. So we need an extra copy. + TensorRowCacheRequest *rq; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { + // We will send the request async. But any error we most + // likely ignore and continue. 
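CacheMergeOp now degrades gracefully: the new cache_missing_rows_ flag, checked in CacheMissWorkerEntry above, is cleared by the Cleaner in the hunk just below when the server answers a cached write with out-of-memory or no-space, and the client is told via ServerRunningOutOfResources, so later cache-miss rows skip the async cache request instead of failing the pipeline. A tiny sketch of that gate pattern with simplified names:

    #include <atomic>
    #include <iostream>

    std::atomic<bool> cache_missing_rows{true};

    enum class Rc { kOk, kOutOfMemory, kNoSpace, kOther };

    void OnCacheResult(Rc rc) {
      if (rc == Rc::kOutOfMemory || rc == Rc::kNoSpace) {
        cache_missing_rows = false;   // degrade: keep the pipeline running without caching new rows
      }
      // Any other error is logged and the request slot is recycled (not shown here).
    }

    int main() {
      OnCacheResult(Rc::kNoSpace);
      std::cout << (cache_missing_rows ? "still caching\n" : "caching disabled\n");
      return 0;
    }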
+ Status rc; + rc = rq->AsyncSendCacheRequest(cache_client_, row); + if (rc.IsOk()) { + RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); + } } } RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row))); @@ -168,13 +172,18 @@ Status CacheMergeOp::Cleaner() { Status rc = rq->CheckCacheResult(); if (rc.IsError()) { // If interrupt, time to quit. - if (rc.get_code() == StatusCode::kInterrupted) { + if (rc.IsInterrupted()) { return Status::OK(); + } else if (rc.IsOutofMemory() || rc.IsNoSpace()) { + // The server is hitting some limit and we will turn off caching from now on. + cache_missing_rows_ = false; + cache_client_->ServerRunningOutOfResources(); + } else { + MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); + // Bad rc should not bring down the pipeline. We will simply continue and + // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. + rq->SetState(TensorRowCacheRequest::State::kEmpty); } - MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); - // Bad rc should not bring down the pipeline. We will simply continue and - // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. - rq->SetState(TensorRowCacheRequest::State::kEmpty); } } return Status::OK(); @@ -253,7 +262,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) { Status CacheMergeOp::EoeReceived(int32_t worker_id) { // If we are in a repeat path, send the eoe up. // Otherwise ignore it. - if (op_total_repeats_ > 1) { + if (op_total_repeats_ != 1) { return DatasetOp::EoeReceived(worker_id); } return Status::OK(); @@ -281,7 +290,7 @@ Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheReque *out = it->second.GetMutablePointer(); } else { // We will create a new one. - auto alloc = Services::GetAllocator(); + auto alloc = SystemPool::GetAllocator(); auto r = io_request_.emplace(row_id, MemGuard>(alloc)); if (r.second) { auto &mem = r.first->second; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h index 4c62af1d5c5..db702c03dbc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h @@ -202,6 +202,7 @@ class CacheMergeOp : public ParallelOp { std::unique_ptr> io_que_; std::shared_ptr cache_client_; int32_t num_cleaners_; + std::atomic cache_missing_rows_; /// \brief Locate the cache request from the io_request_ map /// \param row_id diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index f5aaa545d2a..652d78bfbc6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -16,6 +16,7 @@ #include "minddata/dataset/engine/datasetops/cache_op.h" #include +#include #include #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/constants.h" @@ -64,7 +65,7 @@ Status CacheOp::Builder::Build(std::shared_ptr *ptr) { // Constructor of CacheOp CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, std::shared_ptr cache_client, std::shared_ptr sampler) - : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), + : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)), num_guys_in_(0), phase_(Phase::kBuildPhase) {} @@ -174,7 +175,7 @@ 
Status CacheOp::WorkerEntry(int32_t worker_id) { Status CacheOp::RegisterResources() { RETURN_IF_NOT_OK(CacheBase::RegisterResources()); RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks())); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc index 9f2294bf16b..5233af631a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc @@ -20,6 +20,7 @@ #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/datasetops/concat_op.h" +#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/engine/execution_tree.h" @@ -188,5 +189,11 @@ Status ConcatOp::ComputeColMap() { } return Status::OK(); } + +// Visitor pre-accept method for NodePass +Status ConcatOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->PreRunOnNode(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h index e2fc6735778..7df5e0ae6f5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h @@ -105,6 +105,12 @@ class ConcatOp : public PipelineOp { // @return - Status Status ComputeColMap() override; + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + private: Status Verify(int32_t id, const std::unique_ptr &buf); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 1b877779274..723a5bcc388 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -243,6 +243,7 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_ << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; if (sampler_) { + out << "\nSampler:\n"; sampler_->Print(out, show_all); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc index dd672468b11..454989a3f2a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc @@ -268,5 +268,11 @@ Status FilterOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor return p->RunOnNode(shared_from_base(), modified); } + +// Visitor pre-accept method for NodePass +Status FilterOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->PreRunOnNode(shared_from_base(), modified); +} } // namespace dataset } // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h index 8cc0cd55ff4..e41011bab93 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h @@ -121,6 +121,12 @@ class FilterOp : public ParallelOp { // @param show_all A bool to control if you want to show all info or just a summary. void Print(std::ostream &out, bool show_all) const override; + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + // Base-class override for NodePass visitor acceptor. // @param p - Pointer to the NodePass to be accepted. // @param modified - Whether this node visit modified the pipeline. diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc index 3effedafbc7..f8c628647ff 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc @@ -458,6 +458,12 @@ Status MapOp::Accept(NodePass *p, bool *modified) { return p->RunOnNode(shared_from_base(), modified); } +// Visitor pre-accept method for NodePass +Status MapOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->PreRunOnNode(shared_from_base(), modified); +} + Status MapOp::WaitForWorkers() { // reset num_paused workers to 0 num_workers_paused_ = 0; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h index de99a2587ea..59b1811b68b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h @@ -177,10 +177,16 @@ class MapOp : public ParallelOp { // @return the number of threads consuming data from previous op's output Connector. int32_t num_consumers() const override; - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for NodePass visitor acceptor. + /// \param[in] p Pointer to the NodePass to be accepted. + /// \param[out] modified Whether this node visit modified the pipeline. + /// \return - Status of the node visit. 
Status Accept(NodePass *p, bool *modified) override; // Op name getter diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc index abb827aea85..5e02516e01a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc @@ -52,9 +52,9 @@ Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { void ParallelOp::Print(std::ostream &out, bool show_all) const { // Summary 1-liner print if (!show_all) { - out << " [workers: " << num_workers_ << "]"; // Call super class printer DatasetOp::Print(out, show_all); + out << " [workers: " << num_workers_ << "]"; } else { // Detailed print DatasetOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc index fff5ba19e7c..6e4f533eb72 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc @@ -27,14 +27,14 @@ PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr sampl void PipelineOp::Print(std::ostream &out, bool show_all) const { // Summary 1-liner print if (!show_all) { + // Call super class printer + DatasetOp::Print(out, show_all); out << " [workers: "; if (this->inlined()) { out << "0 (inlined)]"; } else { out << "1]"; // Pipeline ops only have 1 worker } - // Call super class printer - DatasetOp::Print(out, show_all); } else { // Detailed print DatasetOp::Print(out, show_all); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc index cee51bbe1ae..3b1fec02e0b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc @@ -235,6 +235,12 @@ Status ZipOp::EoeReceived(int32_t) { return Status::OK(); } +// Visitor pre-accept method for NodePass +Status ZipOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->PreRunOnNode(shared_from_base(), modified); +} + // Visitor accept method for NodePass Status ZipOp::Accept(NodePass *p, bool *modified) { // Downcast shared pointer then call visitor diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h index 2995b49c23c..f2cc2823997 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h @@ -104,10 +104,16 @@ class ZipOp : public PipelineOp { // @return Status - The error code return Status operator()() override; - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for NodePass visitor acceptor. + /// \param[in] p Pointer to the NodePass to be accepted. + /// \param[out] modified Whether this node visit modified the pipeline. + /// \return - Status of the node visit. 
Status Accept(NodePass *p, bool *modified) override; // Op name getter diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc index d2eba5ff64c..8a435749976 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc @@ -26,6 +26,7 @@ #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" #include "minddata/dataset/engine/opt/post/repeat_pass.h" #endif +#include "minddata/dataset/engine/opt/pre/cache_error_pass.h" #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "minddata/dataset/engine/perf/profiling.h" @@ -235,6 +236,7 @@ Status ExecutionTree::PrepareTreePreAction() { std::vector> pre_actions; // Construct pre actions MS_LOG(INFO) << "Running pre pass loops."; + pre_actions.push_back(std::make_unique()); pre_actions.push_back(std::make_unique()); pre_actions.push_back(std::make_unique()); #ifndef ENABLE_ANDROID diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index 50346ffad81..a9b439cd392 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE add_library(engine-opt OBJECT pass.cc post/repeat_pass.cc + pre/cache_error_pass.cc pre/cache_transform_pass.cc pre/epoch_injection_pass.cc pre/removal_pass.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc index 847be38d5da..1ee5d2b68c4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -23,6 +23,7 @@ #include "minddata/dataset/engine/datasetops/cache_merge_op.h" #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" #endif +#include "minddata/dataset/engine/datasetops/concat_op.h" #include "minddata/dataset/engine/datasetops/dataset_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h" #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" @@ -143,40 +144,6 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { return RunOnNode(std::static_pointer_cast(node), modified); } -#ifndef ENABLE_ANDROID -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} -#endif - -#ifdef ENABLE_PYTHON -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by 
default - return RunOnNode(std::static_pointer_cast(node), modified); -} -#endif - Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); @@ -207,13 +174,6 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { return RunOnNode(std::static_pointer_cast(node), modified); } -#ifndef ENABLE_ANDROID -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} -#endif - Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); @@ -239,18 +199,6 @@ Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { return RunOnNode(std::static_pointer_cast(node), modified); } -#ifndef ENABLE_ANDROID -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} -#endif - Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return RunOnNode(std::static_pointer_cast(node), modified); @@ -261,18 +209,6 @@ Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { return PreRunOnNode(std::static_pointer_cast(node), modified); } -#ifndef ENABLE_ANDROID -Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return PreRunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return PreRunOnNode(std::static_pointer_cast(node), modified); -} -#endif - Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return PreRunOnNode(std::static_pointer_cast(node), modified); @@ -283,12 +219,88 @@ Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified return PreRunOnNode(std::static_pointer_cast(node), modified); } +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + #ifndef ENABLE_ANDROID +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status 
NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { // Fallback to base class visitor by default return PreRunOnNode(std::static_pointer_cast(node), modified); } #endif +#ifdef ENABLE_PYTHON +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} +#endif } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h index 626dd7687e2..33433f6e490 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -37,18 +37,6 @@ class SkipOp; class ShuffleOp; -#ifndef ENABLE_ANDROID -class MindRecordOp; - -class TFReaderOp; -#endif - -#ifdef ENABLE_PYTHON -class FilterOp; - -class GeneratorOp; -#endif - class AlbumOp; class RandomDataOp; @@ -63,10 +51,6 @@ class DeviceQueueOp; class ImageFolderOp; -#ifndef ENABLE_ANDROID -class CacheOp; -#endif - class MnistOp; class ManifestOp; @@ -79,20 +63,32 @@ class CocoOp; class CelebAOp; -#ifndef ENABLE_ANDROID -class CacheMergeOp; - -class CacheLookupOp; -#endif - class EpochCtrlOp; class BuildVocabOp; +class ConcatOp; + #ifndef ENABLE_ANDROID +class MindRecordOp; + +class TFReaderOp; + +class CacheOp; + +class CacheMergeOp; + +class CacheLookupOp; + class BuildSentencePieceVocabOp; #endif +#ifdef ENABLE_PYTHON +class FilterOp; + +class GeneratorOp; +#endif + // The base class Pass is the basic unit of tree transformation. // The actual implementation of the passes will be derived from here. 
class Pass : public std::enable_shared_from_this { @@ -168,22 +164,6 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#ifndef ENABLE_ANDROID - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#endif - -#ifdef ENABLE_PYTHON - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#endif - virtual Status RunOnNode(std::shared_ptr node, bool *modified); virtual Status RunOnNode(std::shared_ptr node, bool *modified); @@ -194,10 +174,6 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#ifndef ENABLE_ANDROID - virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#endif - virtual Status RunOnNode(std::shared_ptr node, bool *modified); virtual Status RunOnNode(std::shared_ptr node, bool *modified); @@ -210,32 +186,50 @@ class NodePass : public Pass { virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#ifndef ENABLE_ANDROID - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#endif - virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#ifndef ENABLE_ANDROID - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); -#endif - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); -#ifndef ENABLE_ANDROID - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); -#endif - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + #ifndef ENABLE_ANDROID + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); #endif +#ifdef ENABLE_PYTHON + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); +#endif + private: // Helper function to perform DFS visit Status DFSNodeVisit(std::shared_ptr node, bool *modified); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc index d737a0fa1bc..d4139b5b834 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc @@ -225,13 +225,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // 
Turns off the tracking for operations under merge op Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // If there was not any repeat in the merge cache miss leg, then the cache_lookup + // would not have been consumed yet. In that case, we need to set its total repeats for it. + if (cache_lookup_) { + cache_lookup_->set_total_repeats(num_repeats_); + cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); + } // Setting the flag is needed since we didn't call the base class DatasetOp version if (is_repeated_) { // If there was not any repeat in the merge cache miss leg, then the cache_lookup // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack if (cache_lookup_) { - cache_lookup_->set_total_repeats(num_repeats_); - node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); AddToEOEOpStack(std::move(cache_lookup_)); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc new file mode 100644 index 00000000000..b3bb275dcfe --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/zip_op.h" +#include "minddata/dataset/engine/datasetops/map_op/map_op.h" +#include "minddata/dataset/engine/opt/pre/cache_error_pass.h" + +namespace mindspore { +namespace dataset { + +// Constructor +CacheErrorPass::CacheErrorPass() : is_cached_(false) {} + +// Identifies the subtree below this node as being cached +Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that we're under a merge op + is_cached_ = true; + return Status::OK(); +} + +// Returns an error if ZipOp exists under a cache +Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("ZipOp is currently not supported as a descendant operator under a cache."); + } + + return Status::OK(); +} + +// Returns an error if MapOp with non-deterministic TensorOps exists under a cache +Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + if (is_cached_) { + auto tfuncs = node->TFuncs(); + for (size_t i = 0; i < tfuncs.size(); i++) { + if (!tfuncs[i]->Deterministic()) { + RETURN_STATUS_UNEXPECTED( + "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache."); + } + } + } + return Status::OK(); +} + +// Returns an error if ConcatOp exists under a cache +Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("ConcatOp is currently not supported as a descendant operator under a cache."); + } + + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +// Returns an error if FilterOp exists under a cache +Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + if (is_cached_) { + RETURN_STATUS_UNEXPECTED("FilterOp is currently not supported as a descendant operator under a cache."); + } + + return Status::OK(); +} +#endif +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h new file mode 100644 index 00000000000..5a20d2f35eb --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ + +#include +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +/// \class CacheErrorPass cache_error_pass.h +/// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures. 
+class CacheErrorPass : public NodePass { + public: + /// \brief Constructor + CacheErrorPass(); + + /// \brief Destructor + ~CacheErrorPass() = default; + + /// \brief Identifies the subtree below this node as being cached + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if ZipOp exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Returns an error if ConcatOp exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + +#ifdef ENABLE_PYTHON + /// \brief Returns an error if FilterOp exists under a cache + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; +#endif + + private: + bool is_cached_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc index 12e47a3c514..953aeab5376 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc @@ -155,50 +155,77 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr n // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for AlbumOp under cache."); + } + return Status::OK(); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache."); + } + return Status::OK(); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache."); + } + return Status::OK(); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + 
RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache."); + } + return Status::OK(); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache."); + } + return Status::OK(); } #ifndef ENABLE_ANDROID // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for MindRecordOp under cache."); + } + return Status::OK(); } #endif #ifdef ENABLE_PYTHON // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for GeneratorOp under cache."); + } + return Status::OK(); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache."); + } + return Status::OK(); } // Perform leaf node cache transform identification Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache."); + } + return Status::OK(); } #endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc index d202b9a1747..34e225de626 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.cc @@ -40,13 +40,6 @@ Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr node, bool *modified) { - injection_point_ = nullptr; - return Status::OK(); -} #endif Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr node, bool *modified) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h index a80c0d6650b..ae1beead919 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h @@ -54,13 +54,6 @@ class EpochInjectionPass : public TreePass { /// \param[inout] modified Indicator if the node was changed at all /// \return Status The error code return Status PreRunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Temporary code to prevent the injection of epoch control when cache op is present. - /// Remove this code in cache op phase 2 - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr node, bool *modified) override; #endif /// \brief Register the DeviceQueueOp for further action. 
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.cc index 9fe1d875e21..9eb4a1d2e65 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/random_apply_op.cc @@ -62,6 +62,7 @@ Status RandomApplyOp::Compute(const TensorRow &input, TensorRow *output) { RandomApplyOp::RandomApplyOp(double prob, const std::vector> &ops) : prob_(prob), gen_(GetSeed()), rand_double_(0, 1) { compose_ = std::make_unique(ops); + is_deterministic_ = false; } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc index ee505b0dc28..b9444a3298c 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/random_choice_op.cc @@ -92,6 +92,7 @@ RandomChoiceOp::RandomChoiceOp(const std::vector> &ops } else if (ops_.size() == 1) { MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time."; } + is_deterministic_ = false; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.cc index 55ba93895b6..c57065e304b 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_affine_op.cc @@ -44,6 +44,7 @@ RandomAffineOp::RandomAffineOp(std::vector degrees, std::vector &input, std::shared_ptr *output) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc index 6dbf30c33e3..05f3d4ea056 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc @@ -36,6 +36,7 @@ RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_f hue_factor_start_(s_hue_factor), hue_factor_end_(e_hue_factor) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } Status RandomColorAdjustOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc index 7400ab1fa14..acfb46ab905 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_op.cc @@ -19,7 +19,9 @@ namespace mindspore { namespace dataset { -RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {} +RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) { + is_deterministic_ = false; +} Status RandomColorOp::Compute(const std::shared_ptr &in, std::shared_ptr *out) { IO_CHECK(in, out); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.cc index 8a7364d6667..0991db197b9 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.cc @@ -41,6 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, 
int32_t targ aspect_ub_(aspect_ub), max_iter_(max_iter) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } Status RandomCropAndResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc index c53e0c06d22..975dbf56fd6 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc @@ -46,6 +46,7 @@ RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_ fill_g_(fill_g), fill_b_(fill_b) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } Status RandomCropOp::ImagePadding(const std::shared_ptr &input, std::shared_ptr *pad_image, diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h index 53c11df1a6e..eb08424ad12 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h @@ -33,6 +33,7 @@ class RandomHorizontalFlipOp : public TensorOp { static const float kDefProbability; explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) { + is_deterministic_ = false; rnd_.seed(GetSeed()); } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h index f8e1e847f66..27e3d1e1cef 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h @@ -35,6 +35,7 @@ class RandomHorizontalFlipWithBBoxOp : public TensorOp { explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } ~RandomHorizontalFlipWithBBoxOp() override = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_posterize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_posterize_op.cc index 605b5942a02..5d2304f5ca2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_posterize_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_posterize_op.cc @@ -29,6 +29,7 @@ const std::vector RandomPosterizeOp::kBitRange = {4, 8}; RandomPosterizeOp::RandomPosterizeOp(const std::vector &bit_range) : PosterizeOp(bit_range[0]), bit_range_(bit_range) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } Status RandomPosterizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h index 77dee5b4d9c..4451eb820cd 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h @@ -35,6 +35,7 @@ class RandomResizeOp : public ResizeOp { explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) { random_generator_.seed(GetSeed()); + is_deterministic_ = false; } ~RandomResizeOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h 
b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h index dbca032520e..61a0d1b868a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h @@ -36,6 +36,7 @@ class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { static const int32_t kDefTargetWidth; explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { random_generator_.seed(GetSeed()); + is_deterministic_ = false; } ~RandomResizeWithBBoxOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc index b2cb4facae7..872e0d5f0b2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc @@ -46,6 +46,7 @@ RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float c fill_g_(fill_g), fill_b_(fill_b) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } // main function call for random rotation : Generate the random degrees diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.cc index d01231f1f84..43eb839cebb 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_select_subpolicy_op.cc @@ -90,6 +90,7 @@ RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector &p if (policy_.empty()) { MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty."; } + is_deterministic_ = false; } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_sharpness_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_sharpness_op.cc index bf7b5a7ab84..d8711448f28 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_sharpness_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_sharpness_op.cc @@ -31,6 +31,7 @@ const float RandomSharpnessOp::kDefEndDegree = 1.9; RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree) : start_degree_(start_degree), end_degree_(end_degree) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } /// main function call for random sharpness : Generate the random degrees diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h index d85fd1889ac..49c2f5596f0 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_solarize_op.h @@ -32,7 +32,10 @@ namespace dataset { class RandomSolarizeOp : public SolarizeOp { public: // Pick a random threshold value to solarize the image with - explicit RandomSolarizeOp(std::vector threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); } + explicit RandomSolarizeOp(std::vector threshold = {0, 255}) : threshold_(threshold) { + rnd_.seed(GetSeed()); + is_deterministic_ = false; + } ~RandomSolarizeOp() = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h index 1724d7a57d7..6dba1742446 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h +++ 
b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h @@ -34,6 +34,7 @@ class RandomVerticalFlipOp : public TensorOp { explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } ~RandomVerticalFlipOp() override = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h index f46101cc488..01840692547 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h @@ -34,6 +34,7 @@ class RandomVerticalFlipWithBBoxOp : public TensorOp { // @param probability: Probablity of Image flipping, 0.5 by default explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { rnd_.seed(GetSeed()); + is_deterministic_ = false; } ~RandomVerticalFlipWithBBoxOp() override = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 1013bd9397d..2294e72bb81 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -168,6 +168,10 @@ class TensorOp { // @return true/false bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } + // Returns true oif the TensorOp produces deterministic result. + // @return true/false + bool Deterministic() { return is_deterministic_; } + // Function to determine the number of inputs the TensorOp can take. 0: means undefined. // @return uint32_t virtual uint32_t NumInput() { return 1; } @@ -191,6 +195,9 @@ class TensorOp { virtual Status OutputType(const std::vector &inputs, std::vector &outputs); virtual std::string Name() const = 0; + + protected: + bool is_deterministic_{true}; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/allocator.h b/mindspore/ccsrc/minddata/dataset/util/allocator.h index 76b98a45b03..220d0a6d1b7 100644 --- a/mindspore/ccsrc/minddata/dataset/util/allocator.h +++ b/mindspore/ccsrc/minddata/dataset/util/allocator.h @@ -88,21 +88,21 @@ class Allocator { std::shared_ptr pool_; }; /// \brief It is a wrapper of unique_ptr with a custom Allocator class defined above -template -Status MakeUnique(std::unique_ptr> *out, Allocator alloc, size_t n, Args &&... args) { +template , typename... Args> +Status MakeUnique(std::unique_ptr> *out, C alloc, size_t n, Args &&... 
args) { RETURN_UNEXPECTED_IF_NULL(out); CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive"); try { T *data = alloc.allocate(n); if (!std::is_arithmetic::value) { for (auto i = 0; i < n; i++) { - std::allocator_traits>::construct(alloc, &(data[i]), std::forward(args)...); + std::allocator_traits::construct(alloc, &(data[i]), std::forward(args)...); } } - auto deleter = [](T *p, Allocator f_alloc, size_t f_n) { + auto deleter = [](T *p, C f_alloc, size_t f_n) { if (!std::is_arithmetic::value && std::is_destructible::value) { for (auto i = 0; i < f_n; ++i) { - std::allocator_traits>::destroy(f_alloc, &p[i]); + std::allocator_traits::destroy(f_alloc, &p[i]); } } f_alloc.deallocate(p, f_n); @@ -129,7 +129,7 @@ class MemGuard { MemGuard(const MemGuard &) = delete; MemGuard &operator=(const MemGuard &) = delete; // On the other hand, We can support move constructor - MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} + MemGuard(MemGuard &&lhs) noexcept : n_(lhs.n_), alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)) {} MemGuard &operator=(MemGuard &&lhs) noexcept { if (this != &lhs) { this->deallocate(); diff --git a/mindspore/ccsrc/minddata/dataset/util/arena.cc b/mindspore/ccsrc/minddata/dataset/util/arena.cc index 85ce35e6610..b4864e26873 100644 --- a/mindspore/ccsrc/minddata/dataset/util/arena.cc +++ b/mindspore/ccsrc/minddata/dataset/util/arena.cc @@ -37,7 +37,8 @@ struct MemHdr { ArenaImpl::ArenaImpl(void *ptr, size_t sz) : size_in_bytes_(sz), ptr_(ptr) { // Divide the memory into blocks. Ignore the last partial block. uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ; - MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << "."; + MS_LOG(DEBUG) << "Arena memory pool is created. Number of blocks : " << num_blks << ". Block size : " << ARENA_BLK_SZ + << "."; tr_.Insert(0, num_blks); } @@ -233,9 +234,9 @@ std::ostream &operator<<(std::ostream &os, const ArenaImpl &s) { Status Arena::Init() { try { - auto sz = size_in_MB_ * 1048576L; - mem_ = std::make_unique(sz); - impl_ = std::make_unique(mem_.get(), sz); + int64_t sz = size_in_MB_ * 1048576L; + RETURN_IF_NOT_OK(mem_.allocate(sz)); + impl_ = std::make_unique(mem_.GetMutablePointer(), sz); } catch (std::bad_alloc &e) { return Status(StatusCode::kOutOfMemory); } diff --git a/mindspore/ccsrc/minddata/dataset/util/arena.h b/mindspore/ccsrc/minddata/dataset/util/arena.h index 132ff0e7eb2..1409b91d823 100644 --- a/mindspore/ccsrc/minddata/dataset/util/arena.h +++ b/mindspore/ccsrc/minddata/dataset/util/arena.h @@ -19,6 +19,7 @@ #include #include #include +#include "minddata/dataset/util/allocator.h" #include "minddata/dataset/util/memory_pool.h" #include "minddata/dataset/util/treap.h" @@ -140,7 +141,7 @@ class Arena : public MemoryPool { protected: mutable std::mutex mux_; std::unique_ptr impl_; - std::unique_ptr mem_; + MemGuard mem_; size_t size_in_MB_; explicit Arena(size_t val_in_MB = 4096); diff --git a/mindspore/ccsrc/minddata/dataset/util/btree.h b/mindspore/ccsrc/minddata/dataset/util/btree.h index e1ac8520dce..69723ac2f6a 100644 --- a/mindspore/ccsrc/minddata/dataset/util/btree.h +++ b/mindspore/ccsrc/minddata/dataset/util/btree.h @@ -131,6 +131,23 @@ class BPlusTree { tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {} }; + /// \brief Statistics functions + /// \return Return the height of the tree + auto GetHeight() const { return empty() ? 
0 : stats_.level_ + 1; } + /// \return Order of the B+ tree + auto GetOrder() const { return traits::kLeafSlots; } + /// \return Number of leaves nodes + auto GetNumLeaves() const { return stats_.leaves_; } + /// \return Number of inner nodes + auto GetNumInnerNodes() const { return stats_.inner_nodes_; } + + /// \brief Toggle locking + /// \note Once locking is off. It is user's responsibility to ensure concurrency + void SetLocking(bool on_off) { + UniqueLock lck(&rw_lock_); + acquire_lock_ = on_off; + } + private: // Abstract class of a node (leaf or inner) class BaseNode { @@ -288,6 +305,17 @@ class BPlusTree { key_compare key_less_; // Stat tree_stats stats_; + // lock mode + bool acquire_lock_; + + void Init() { + typename LeafNode::alloc_type alloc(alloc_); + auto *p = alloc.allocate(1); + root_ = new (p) LeafNode(alloc_); + all_.Prepend(p); + leaf_nodes_.Append(p); + stats_.leaves_++; + } bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); } @@ -350,11 +378,11 @@ class BPlusTree { ~Iterator(); - explicit Iterator(const Iterator &); + Iterator(const Iterator &); Iterator &operator=(const Iterator &lhs); - explicit Iterator(Iterator &&); + Iterator(Iterator &&) noexcept; Iterator &operator=(Iterator &&lhs); @@ -399,11 +427,11 @@ class BPlusTree { ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) : cur_(leaf), slot_(slot), locked_(locked) {} - explicit ConstIterator(const ConstIterator &); + ConstIterator(const ConstIterator &); ConstIterator &operator=(const ConstIterator &lhs); - explicit ConstIterator(ConstIterator &&); + ConstIterator(ConstIterator &&) noexcept; ConstIterator &operator=(ConstIterator &&lhs); diff --git a/mindspore/ccsrc/minddata/dataset/util/btree_impl.tpp b/mindspore/ccsrc/minddata/dataset/util/btree_impl.tpp index fc3b05e3a1e..7f4f2a2090d 100644 --- a/mindspore/ccsrc/minddata/dataset/util/btree_impl.tpp +++ b/mindspore/ccsrc/minddata/dataset/util/btree_impl.tpp @@ -413,11 +413,16 @@ typename BPlusTree::IndexRc BPlusTree::Locate(RWLo } template -BPlusTree::BPlusTree() : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {} +BPlusTree::BPlusTree() + : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) { + Init(); +} template BPlusTree::BPlusTree(const Allocator &alloc) - : alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {} + : alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) { + Init(); +} template BPlusTree::~BPlusTree() noexcept { @@ -446,20 +451,6 @@ BPlusTree::~BPlusTree() noexcept { template Status BPlusTree::DoInsert(const key_type &key, std::unique_ptr &&value) { IndexRc rc; - if (root_ == nullptr) { - UniqueLock lck(&rw_lock_); - // Check again after we get the lock. Other thread may have created the root node already. - if (root_ == nullptr) { - LeafNode *leaf = nullptr; - rc = AllocateLeaf(&leaf); - if (rc != IndexRc::kOk) { - return IndexRc2Status(rc); - } - leaf_nodes_.Append(leaf); - root_ = leaf; - } - // lock will be unlocked when it goes out of scope. - } bool retry = false; do { // Track all the paths to the target and lock each internal node in S. 
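The BPlusTree changes above add an acquire_lock_ switch (SetLocking) so that a caller which can guarantee exclusive access, for example during a single-threaded build phase, may skip the reader/writer lock entirely; once the lock is off, concurrency is the caller's responsibility. A minimal standalone sketch of that optional-locking pattern, using plain standard-library types rather than the MindSpore classes, looks like this:

#include <iostream>
#include <map>
#include <shared_mutex>

// Optional-locking container: the reader/writer lock is taken only while
// acquire_lock_ is true. As with BPlusTree::SetLocking above, turning the
// lock off makes the caller responsible for guaranteeing exclusive access.
class LockTogglingMap {
 public:
  void SetLocking(bool on_off) {
    std::unique_lock<std::shared_mutex> lck(rw_lock_);
    acquire_lock_ = on_off;
  }

  void Insert(int key, int value) {
    if (acquire_lock_) {
      std::unique_lock<std::shared_mutex> lck(rw_lock_);
      data_[key] = value;
      return;
    }
    data_[key] = value;  // caller promised single-threaded access
  }

  bool Find(int key, int *out) const {
    if (acquire_lock_) {
      std::shared_lock<std::shared_mutex> lck(rw_lock_);
      return FindNoLock(key, out);
    }
    return FindNoLock(key, out);
  }

 private:
  bool FindNoLock(int key, int *out) const {
    auto it = data_.find(key);
    if (it == data_.end()) return false;
    *out = it->second;
    return true;
  }

  mutable std::shared_mutex rw_lock_;
  bool acquire_lock_ = true;
  std::map<int, int> data_;
};

int main() {
  LockTogglingMap m;
  m.SetLocking(false);  // single-threaded build phase: skip lock overhead
  for (int i = 0; i < 4; ++i) m.Insert(i, i * i);
  m.SetLocking(true);   // switch back before sharing across threads
  int v = 0;
  if (m.Find(3, &v)) std::cout << "3 -> " << v << std::endl;
  return 0;
}

CachePool exposes the same switch (its SetLocking simply forwards to the underlying tree), so the same single-threaded-build trade-off is available one layer up.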
@@ -468,7 +459,7 @@ Status BPlusTree::DoInsert(const key_type &key, std::unique_ptr BPlusTree::DoUpdate(const key_type &key, std:: if (root_ != nullptr) { LeafNode *leaf = nullptr; slot_type slot; - RWLock *myLock = &this->rw_lock_; - // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. - myLock->LockShared(); + RWLock *myLock = nullptr; + if (acquire_lock_) { + myLock = &this->rw_lock_; + // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. + myLock->LockShared(); + } IndexRc rc = Locate(myLock, true, root_, key, &leaf, &slot); if (rc == IndexRc::kOk) { // All locks from the tree to the parent of leaf are all gone. We still have a X lock @@ -521,7 +515,9 @@ std::unique_ptr BPlusTree::DoUpdate(const key_type &key, std:: // Swap out the old value and replace it with new value. std::unique_ptr old = std::move(leaf->data_[leaf->slot_dir_[slot]]); leaf->data_[leaf->slot_dir_[slot]] = std::move(new_value); - leaf->rw_lock_.Unlock(); + if (acquire_lock_) { + leaf->rw_lock_.Unlock(); + } return old; } else { MS_LOG(DEBUG) << "Key not found. rc = " << static_cast(rc) << "."; diff --git a/mindspore/ccsrc/minddata/dataset/util/btree_iterator.tpp b/mindspore/ccsrc/minddata/dataset/util/btree_iterator.tpp index cffef3fa7a6..3e72013911b 100644 --- a/mindspore/ccsrc/minddata/dataset/util/btree_iterator.tpp +++ b/mindspore/ccsrc/minddata/dataset/util/btree_iterator.tpp @@ -109,7 +109,7 @@ BPlusTree::Iterator::Iterator(const BPlusTree::Ite } template -BPlusTree::Iterator::Iterator(BPlusTree::Iterator &&lhs) { +BPlusTree::Iterator::Iterator(BPlusTree::Iterator &&lhs) noexcept { this->cur_ = lhs.cur_; this->slot_ = lhs.slot_; this->locked_ = lhs.locked_; @@ -241,7 +241,7 @@ BPlusTree::ConstIterator::ConstIterator(const BPlusTree -BPlusTree::ConstIterator::ConstIterator(BPlusTree::ConstIterator &&lhs) { +BPlusTree::ConstIterator::ConstIterator(BPlusTree::ConstIterator &&lhs) noexcept { this->cur_ = lhs.cur_; this->slot_ = lhs.slot_; this->locked_ = lhs.locked_; @@ -290,9 +290,12 @@ std::pair::ConstIterator, bool> BPlusTreerw_lock_; - // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. - myLock->LockShared(); + RWLock *myLock = nullptr; + if (acquire_lock_) { + myLock = &this->rw_lock_; + // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. + myLock->LockShared(); + } IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); bool find = (rc == IndexRc::kOk); return std::make_pair(ConstIterator(leaf, slot, find), find); @@ -306,9 +309,12 @@ std::pair::Iterator, bool> BPlusTreerw_lock_; - // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. - myLock->LockShared(); + RWLock *myLock = nullptr; + if (acquire_lock_) { + myLock = &this->rw_lock_; + // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. 
+ myLock->LockShared(); + } IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); bool find = (rc == IndexRc::kOk); return std::make_pair(Iterator(leaf, slot, find), find); diff --git a/mindspore/ccsrc/minddata/dataset/util/buddy.cc b/mindspore/ccsrc/minddata/dataset/util/buddy.cc index bbca344bb67..0cc2baed5f2 100644 --- a/mindspore/ccsrc/minddata/dataset/util/buddy.cc +++ b/mindspore/ccsrc/minddata/dataset/util/buddy.cc @@ -69,7 +69,7 @@ Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) n *p = addr; return Status::OK(); } else { - return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore."); + return Status(StatusCode::kBuddySpaceFull, "BuddySpace full. Not an error. Please ignore."); } } diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc index 18016d6cea6..5972ada0220 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc @@ -20,8 +20,13 @@ namespace mindspore { namespace dataset { -CachePool::CachePool(const value_allocator &alloc, const std::string &root) - : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} +CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root) + : alloc_(alloc), + root_(root), + subfolder_(Services::GetUniqueID()), + sm_(nullptr), + tree_(nullptr), + custom_arena_(ourOwnArena) {} Status CachePool::DoServiceStart() { tree_ = std::make_shared(); @@ -45,9 +50,12 @@ Status CachePool::DoServiceStop() { } } sm_.reset(); - for (auto &bl : *tree_) { - if (bl.ptr != nullptr) { - alloc_.deallocate(bl.ptr, bl.sz); + // If it is our own arena, skip freeing individual pieces. + if (!custom_arena_) { + for (auto &bl : *tree_) { + if (bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, bl.sz); + } } } tree_.reset(); @@ -68,7 +76,7 @@ Status CachePool::DoServiceStop() { return rc2; } CachePool::~CachePool() noexcept { (void)ServiceStop(); } -Status CachePool::Insert(const std::vector &buf, CachePool::key_type *key) { +Status CachePool::Insert(CachePool::key_type key, const std::vector &buf, bool writeToDiskDirectly) { DataLocator bl; Status rc; size_t sz = 0; @@ -78,22 +86,31 @@ Status CachePool::Insert(const std::vector &buf, CachePool::key_t } bl.sz = sz; try { - bl.ptr = alloc_.allocate(sz); - // We will do a piecewise copy. - WritableSlice dest(bl.ptr, bl.sz); - size_t pos = 0; - for (auto &v : buf) { - WritableSlice out(dest, pos); - rc = WritableSlice::Copy(&out, v); - if (rc.IsError()) { - break; + if (!writeToDiskDirectly) { + bl.ptr = alloc_.allocate(sz); + // We will do a piecewise copy. + WritableSlice dest(bl.ptr, bl.sz); + size_t pos = 0; + for (auto &v : buf) { + WritableSlice out(dest, pos); + rc = WritableSlice::Copy(&out, v); + if (rc.IsError()) { + break; + } + pos += v.GetSize(); } - pos += v.GetSize(); - } - if (rc.IsError()) { - alloc_.deallocate(bl.ptr, sz); - bl.ptr = nullptr; - return rc; + if (rc.IsError()) { + alloc_.deallocate(bl.ptr, sz); + bl.ptr = nullptr; + return rc; + } + } else if (sm_ != nullptr) { + MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes."; + RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); + } else { + // If asked to spill to disk instead but there is no storage set up, simply return no memory + // instead. 
+ return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); } } catch (std::bad_alloc &e) { if (sm_ != nullptr) { @@ -102,7 +119,13 @@ Status CachePool::Insert(const std::vector &buf, CachePool::key_t return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); } } - rc = tree_->insert(bl, key); + // Insert into the B+ tree. We may still get out of memory error. So need to catch it. + try { + rc = tree_->DoInsert(key, bl); + } catch (const std::bad_alloc &e) { + rc = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + // Duplicate key is treated as error and we will also free the memory. if (rc.IsError() && bl.ptr != nullptr) { alloc_.deallocate(bl.ptr, sz); } @@ -138,15 +161,26 @@ Path CachePool::GetSpillPath() const { auto spill = Path(root_) / subfolder_; return spill; } -CachePool::CacheStat CachePool::GetStat() const { - CacheStat cs{0}; +CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { + CacheStat cs{-1, -1, 0, 0, 0}; int64_t total_sz = 0; - for (auto &it : *tree_) { - total_sz += it.sz; - if (it.ptr != nullptr) { - ++cs.num_mem_cached; - } else { - ++cs.num_disk_cached; + if (tree_->begin() != tree_->end()) { + cs.min_key = tree_->begin().key(); + cs.max_key = cs.min_key; // will adjust later. + for (auto it = tree_->begin(); it != tree_->end(); ++it) { + total_sz += it.value().sz; + if (it.value().ptr != nullptr) { + ++cs.num_mem_cached; + } else { + ++cs.num_disk_cached; + } + auto cur_key = it.key(); + if (GetMissingKeys) { + for (auto i = cs.max_key + 1; i < cur_key; ++i) { + cs.gap.push_back((i)); + } + } + cs.max_key = cur_key; } } if (total_sz > 0) { diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h index 3989941a337..77c1c06f24f 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h @@ -25,13 +25,13 @@ #include "minddata/dataset/util/slice.h" #include "minddata/dataset/util/storage_manager.h" #include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/util/btree.h" namespace mindspore { namespace dataset { /// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of /// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to -/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to -/// restore the buffer. +/// disk (if a disk directory is provided). User must provide a key to insert the buffer. /// \see ReadableSlice class CachePool : public Service { public: @@ -73,22 +73,25 @@ class CachePool : public Service { StorageManager::key_type storage_key; }; - using data_index = AutoIndexObj; + using data_index = BPlusTree; using key_type = data_index::key_type; using bl_alloc_type = typename value_allocator::template rebind::other; /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and /// how many elements are spilled to disk. 
struct CacheStat { + key_type min_key; + key_type max_key; int64_t num_mem_cached; int64_t num_disk_cached; int64_t average_cache_sz; + std::vector gap; }; /// \brief Constructor /// \param alloc Allocator to allocate memory from /// \param root Optional disk folder to spill - explicit CachePool(const value_allocator &alloc, const std::string &root = ""); + explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = ""); CachePool(const CachePool &) = delete; CachePool(CachePool &&) = delete; @@ -103,10 +106,11 @@ class CachePool : public Service { /// \brief Insert a sequence of ReadableSlice objects into the pool. /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. + /// \param[in] key User supplied key /// \param[in] buf A sequence of ReadableSlice objects. - /// \param[out] key Generated key + /// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory /// \return Error code - Status Insert(const std::vector &buf, key_type *key); + Status Insert(key_type key, const std::vector &buf, bool writeToDiskDirectly); /// \brief Restore a cached buffer (from memory or disk) /// \param[in] key A previous key returned from Insert /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice @@ -122,18 +126,23 @@ class CachePool : public Service { /// \brief Get statistics. /// \return CacheStat object - CacheStat GetStat() const; + CacheStat GetStat(bool GetMissingKeys = false) const; const value_allocator &get_allocator() const; std::string MyName() const { return subfolder_; } + /// \brief Toggle locking + /// \note Once locking is off. It is user's responsibility to ensure concurrency + void SetLocking(bool on_off) { tree_->SetLocking(on_off); } + private: value_allocator alloc_; Path root_; const std::string subfolder_; std::shared_ptr sm_; std::shared_ptr tree_; + bool custom_arena_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc b/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc index f99e6de2f1f..5f57d596eb0 100644 --- a/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc +++ b/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc @@ -133,12 +133,13 @@ void CircularPool::Deallocate(void *p) { // Lock in the chain in shared mode and find out which // segment it comes from SharedLock lock(&rw_lock_); - auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { + auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr &b) -> bool { char *q = reinterpret_cast(p); - char *base = const_cast(reinterpret_cast(b->get_base_addr())); - return (q > base && q < base + b->get_max_size()); + auto *base = reinterpret_cast(b->get_base_addr()); + return (q > base && q < base + arena_size_ * 1048576L); }); lock.Unlock(); + MS_ASSERT(it != mem_segments_.end()); it->get()->Deallocate(p); } @@ -150,10 +151,10 @@ Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) { } void *p = *pp; SharedLock lock(&rw_lock_); - auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { + auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr &b) -> bool { char *q = reinterpret_cast(p); - char *base = const_cast(reinterpret_cast(b->get_base_addr())); - return (q > base && q < base + 
b->get_max_size()); + auto *base = reinterpret_cast(b->get_base_addr()); + return (q > base && q < base + arena_size_ * 1048576L); }); lock.Unlock(); MS_ASSERT(it != mem_segments_.end()); diff --git a/mindspore/ccsrc/minddata/dataset/util/queue_map.h b/mindspore/ccsrc/minddata/dataset/util/queue_map.h index 3951ec14ce2..f04f7996044 100644 --- a/mindspore/ccsrc/minddata/dataset/util/queue_map.h +++ b/mindspore/ccsrc/minddata/dataset/util/queue_map.h @@ -16,11 +16,14 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ +#include #include +#include #include #include #include #include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/system_pool.h" #include "minddata/dataset/util/semaphore.h" #include "minddata/dataset/util/services.h" namespace mindspore { @@ -37,7 +40,7 @@ class QueueMap { using key_type = K; using value_type = T; - QueueMap() = default; + QueueMap() : num_rows_(0) {} virtual ~QueueMap() = default; /// Add an element to the map and wake up any consumer that is waiting @@ -48,6 +51,7 @@ class QueueMap { RequestQueue *rq = nullptr; RETURN_IF_NOT_OK(GetRq(key, &rq)); RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload))); + ++num_rows_; return Status::OK(); } @@ -56,9 +60,35 @@ class QueueMap { RequestQueue *rq = nullptr; RETURN_IF_NOT_OK(GetRq(key, &rq)); RETURN_IF_NOT_OK(rq->Wait(out)); + --num_rows_; return Status::OK(); } + /// Get the number of elements in the container + /// \return The number of elements in the container + int64_t size() const { return num_rows_; } + + /// \return if the container is empty + bool empty() const { return num_rows_ == 0; } + + /// Print out some useful information about the container + friend std::ostream &operator<<(std::ostream &out, const QueueMap &qm) { + std::unique_lock lck(qm.mux_); + out << "Number of elements: " << qm.num_rows_ << "\n"; + out << "Dumping internal info:\n"; + int64_t k = 0; + for (auto &it : qm.all_) { + auto key = it.first; + const RequestQueue *rq = it.second.GetPointer(); + out << "(k:" << key << "," << *rq << ") "; + ++k; + if (k % 6 == 0) { + out << "\n"; + } + } + return out; + } + protected: /// This is a handshake structure between producer and consumer class RequestQueue { @@ -86,8 +116,13 @@ class QueueMap { return Status::OK(); } + friend std::ostream &operator<<(std::ostream &out, const RequestQueue &rq) { + out << "sz:" << rq.row_.size() << ",uc:" << rq.use_count_.Peek(); + return out; + } + private: - std::mutex dq_mux_; + mutable std::mutex dq_mux_; Semaphore use_count_; std::deque row_; }; @@ -104,7 +139,7 @@ class QueueMap { *out = it->second.GetMutablePointer(); } else { // We will create a new one. 
- auto alloc = Services::GetAllocator(); + auto alloc = SystemPool::GetAllocator(); auto r = all_.emplace(key, MemGuard>(alloc)); if (r.second) { auto &mem = r.first->second; @@ -118,8 +153,9 @@ class QueueMap { } private: - std::mutex mux_; + mutable std::mutex mux_; std::map>> all_; + std::atomic num_rows_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.cc b/mindspore/ccsrc/minddata/dataset/util/semaphore.cc index 5dadd98f3ca..15e6a04ef1f 100644 --- a/mindspore/ccsrc/minddata/dataset/util/semaphore.cc +++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.cc @@ -29,10 +29,7 @@ void Semaphore::V() { ++value_; wait_cond_.NotifyOne(); } -int Semaphore::Peek() { - std::unique_lock lck(mutex_); - return value_; -} +int Semaphore::Peek() const { return value_; } Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.h b/mindspore/ccsrc/minddata/dataset/util/semaphore.h index 6d604d00438..35fd09ed5a3 100644 --- a/mindspore/ccsrc/minddata/dataset/util/semaphore.h +++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.h @@ -38,7 +38,7 @@ class Semaphore { void V(); /// \brief Peek the internal value /// \return The internal value - int Peek(); + int Peek() const; Status Register(TaskGroup *vg); Status Deregister(); void ResetIntrpState(); diff --git a/mindspore/ccsrc/minddata/dataset/util/status.cc b/mindspore/ccsrc/minddata/dataset/util/status.cc index 0ad45f71244..b2ab3362e95 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.cc +++ b/mindspore/ccsrc/minddata/dataset/util/status.cc @@ -51,6 +51,12 @@ std::string CodeAsString(const StatusCode c) { case StatusCode::kSyntaxError: s = "Syntax error"; break; + case StatusCode::kBuddySpaceFull: + s = "BuddySpace full"; + break; + case StatusCode::kNetWorkError: + s = "Network error"; + break; case StatusCode::kUnexpectedError: default: s = "Unexpected error"; diff --git a/mindspore/ccsrc/minddata/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h index 0caa6e6895c..5bb9b5815e8 100644 --- a/mindspore/ccsrc/minddata/dataset/util/status.h +++ b/mindspore/ccsrc/minddata/dataset/util/status.h @@ -82,6 +82,8 @@ enum class StatusCode : char { kBoundingBoxInvalidShape = 12, kSyntaxError = 13, kTimeOut = 14, + kBuddySpaceFull = 14, + kNetWorkError = 15, // Make this error code the last one. Add new error code above it. kUnexpectedError = 127 }; @@ -137,6 +139,8 @@ class Status { bool IsNoSpace() const { return (get_code() == StatusCode::kNoSpace); } + bool IsNetWorkError() const { return (get_code() == StatusCode::kNetWorkError); } + private: StatusCode code_; std::string err_msg_; diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.cc b/mindspore/ccsrc/minddata/dataset/util/storage_container.cc index b64926e5969..b0855512cdc 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_container.cc +++ b/mindspore/ccsrc/minddata/dataset/util/storage_container.cc @@ -99,7 +99,11 @@ Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const #endif if (r_sz != sz) { errno_t err = (r_sz == 0) ? 
EOF : errno; - RETURN_STATUS_UNEXPECTED(strerror(err)); + if (errno == ENOSPC) { + return Status(StatusCode::kNoSpace, __LINE__, __FILE__); + } else { + RETURN_STATUS_UNEXPECTED(strerror(err)); + } } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc b/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc index 82f39fd7ae7..8c99e5a91e9 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc @@ -71,10 +71,11 @@ Status StorageManager::Write(key_type *key, const std::vector &bu key_type out_key; value_type out_value; bool create_new_container = false; + size_t last_num_container = -1; do { SharedLock lock_s(&rw_lock_); size_t num_containers = containers_.size(); - if (create_new_container) { + if (create_new_container && (num_containers == last_num_container)) { // Upgrade to exclusvie lock. lock_s.Upgrade(); create_new_container = false; @@ -95,8 +96,11 @@ Status StorageManager::Write(key_type *key, const std::vector &bu cont = containers_.at(num_containers - 1); off64_t offset; Status rc = cont->Insert(buf, &offset); - if (rc.IsNoSpace()) { + if (rc.get_code() == StatusCode::kBuddySpaceFull) { create_new_container = true; + // Remember how many containers we saw. In the next iteration we will do a comparision to see + // if someone has already created it. + last_num_container = num_containers; } else if (rc.IsOk()) { out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py index 3128432b06c..2abe7aa4b35 100644 --- a/mindspore/dataset/engine/cache_client.py +++ b/mindspore/dataset/engine/cache_client.py @@ -15,6 +15,7 @@ """Cache client """ +import os import copy from ..core.validator_helpers import type_check, check_uint32, check_uint64 @@ -25,11 +26,11 @@ class DatasetCache: A client to interface with tensor caching service """ - def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20): + def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None, + prefetch_size=None): check_uint32(session_id, "session_id") check_uint64(size, "size") type_check(spilling, (bool,), "spilling") - check_uint32(prefetch_size, "prefetch size") self.session_id = session_id self.size = size @@ -37,8 +38,13 @@ class DatasetCache: self.hostname = hostname self.port = port self.prefetch_size = prefetch_size - # temporary disable cache feature in the current release - self.cache_client = None + self.num_connections = num_connections + if os.getenv('MS_ENABLE_CACHE') != 'TRUE': + # temporary disable cache feature in the current release + self.cache_client = None + else: + from mindspore._c_dataengine import CacheClient + self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size) def GetStat(self): return self.cache_client.GetStat() @@ -55,5 +61,6 @@ class DatasetCache: new_cache.hostname = copy.deepcopy(self.hostname, memodict) new_cache.port = copy.deepcopy(self.port, memodict) new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) + new_cache.num_connections = copy.deepcopy(self.num_connections, memodict) new_cache.cache_client = self.cache_client return new_cache diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 
2082f5be2ca..c17857e3f6f 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1234,5 +1234,8 @@ def check_paddeddataset(method): def check_cache_option(cache): """Sanity check for cache parameter""" if cache is not None: - # temporary disable cache feature in the current release - raise ValueError("Caching is disabled in the current release") + if os.getenv('MS_ENABLE_CACHE') != 'TRUE': + # temporary disable cache feature in the current release + raise ValueError("Caching is disabled in the current release") + from . import cache_client + type_check(cache, (cache_client.DatasetCache,), "cache") diff --git a/setup.py b/setup.py index fed9ffa0e83..8e548bf3c52 100644 --- a/setup.py +++ b/setup.py @@ -156,6 +156,24 @@ def update_permissions(path): if filename == "ms_serving": os.chmod(file_fullpath, stat.S_IREAD | stat.S_IEXEC) +def bin_files(): + """ + Gets the binary files to be installed. + """ + data_files = [] + binary_files = [] + + cache_server_bin = os.path.join('mindspore', 'bin', 'cache_server') + if not os.path.exists(cache_server_bin): + return data_files + binary_files.append(cache_server_bin) + cache_admin_bin = os.path.join('mindspore', 'bin', 'cache_admin') + if not os.path.exists(cache_admin_bin): + return data_files + binary_files.append(cache_admin_bin) + data_files.append(('bin', binary_files)) + return data_files + class EggInfo(egg_info): """Egg info.""" @@ -192,6 +210,7 @@ setup( 'framework that could be used for mobile, edge and cloud scenarios.', long_description="\n\n".join([readme, release]), long_description_content_type="text/markdown", + data_files=bin_files(), packages=find_packages(), package_data=package_data, include_package_data=True, diff --git a/tests/ut/cpp/dataset/btree_test.cc b/tests/ut/cpp/dataset/btree_test.cc index 9fa4fce8126..5e309354cf9 100644 --- a/tests/ut/cpp/dataset/btree_test.cc +++ b/tests/ut/cpp/dataset/btree_test.cc @@ -24,27 +24,26 @@ #include "utils/log_adapter.h" using namespace mindspore::dataset; -using mindspore::MsLogLevel::INFO; -using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; // For testing purposes, we will make the branching factor very low. struct mytraits { - using slot_type = uint16_t; - static const slot_type kLeafSlots = 6; - static const slot_type kInnerSlots = 3; + using slot_type = uint16_t; + static const slot_type kLeafSlots = 6; + static const slot_type kInnerSlots = 3; }; - class MindDataTestBPlusTree : public UT::Common { public: - MindDataTestBPlusTree() = default; + MindDataTestBPlusTree() = default; }; // Test serial insert. TEST_F(MindDataTestBPlusTree, Test1) { Allocator alloc(std::make_shared()); - BPlusTree, std::less, mytraits> btree(alloc); + BPlusTree, std::less<>, mytraits> btree(alloc); Status rc; for (int i = 0; i < 100; i++) { uint64_t key = 2 * i; @@ -109,23 +108,24 @@ TEST_F(MindDataTestBPlusTree, Test1) { // Test concurrent insert. TEST_F(MindDataTestBPlusTree, Test2) { Allocator alloc(std::make_shared()); - BPlusTree, std::less, mytraits> btree(alloc); + BPlusTree, std::less<>, mytraits> btree(alloc); TaskGroup vg; auto f = [&](int k) -> Status { TaskManager::FindMe()->Post(); - for (int i = 0; i < 100; i++) { - uint64_t key = k * 100 + i; - std::ostringstream oss; - oss << "Hello World. 
I am " << key; - Status rc = btree.DoInsert(key, oss.str()); - EXPECT_TRUE(rc.IsOk()); - } - return Status::OK(); + for (int i = 0; i < 100; i++) { + uint64_t key = k * 100 + i; + std::ostringstream oss; + oss << "Hello World. I am " << key; + Status rc = btree.DoInsert(key, oss.str()); + EXPECT_TRUE(rc.IsOk()); + } + return Status::OK(); }; auto g = [&](int k) -> Status { TaskManager::FindMe()->Post(); for (int i = 0; i < 1000; i++) { - uint64_t key = rand() % 10000;; + uint64_t key = rand() % 10000; + ; auto it = btree.Search(key); } return Status::OK(); @@ -226,3 +226,22 @@ TEST_F(MindDataTestBPlusTree, Test4) { EXPECT_EQ(cnt, 1000); } } + +TEST_F(MindDataTestBPlusTree, TestPerfNoLocking) { + AutoIndexObj btree; + // No locking test + btree.SetLocking(false); + // Insert a million entries using the default traits. + for (auto i = 0; i < 1000000; ++i) { + ASSERT_TRUE(btree.insert(i)); + } + std::cout << "Tree height : " << btree.GetHeight() << std::endl; + std::cout << "Tree Order : " << btree.GetOrder() << std::endl; + std::cout << "Number of leaves : " << btree.GetNumLeaves() << std::endl; + std::cout << "Number of inner nodes : " << btree.GetNumInnerNodes() << std::endl; + + auto r = btree.Search(3); + EXPECT_TRUE(r.second); + r = btree.Search(999999); + EXPECT_TRUE(r.second); +} diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index d408ae3e72a..fcfc42a1be3 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -35,6 +35,23 @@ using mindspore::dataset::TaskGroup; using mindspore::ExceptionType::NoExceptionType; using mindspore::MsLogLevel::INFO; +// Helper function to get the session id from SESSION_ID env variable +Status GetSessionFromEnv(session_id_type *session_id) { + RETURN_UNEXPECTED_IF_NULL(session_id); + if (const char *session_env = std::getenv("SESSION_ID")) { + std::string session_id_str(session_env); + try { + *session_id = std::stoul(session_id_str); + } catch (const std::exception &e) { + std::string err_msg = "Invalid numeric value for session id in env var: " + session_id_str; + return Status(StatusCode::kSyntaxError, err_msg); + } + } else { + RETURN_STATUS_UNEXPECTED("Test case requires a session id to be provided via SESSION_ID environment variable."); + } + return Status::OK(); +} + class MindDataTestCacheOp : public UT::DatasetOpTesting { public: void SetUp() override { @@ -46,8 +63,12 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting { TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { Status rc; CacheClient::Builder builder; + session_id_type env_session; + rc = GetSessionFromEnv(&env_session); + ASSERT_TRUE(rc.IsOk()); + // use arbitrary session of 1, size of 0, spilling// is true - builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); + builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); std::shared_ptr myClient; rc = builder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); @@ -118,9 +139,6 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { cmp = (map_out == map); ASSERT_TRUE(cmp); - // Test Purge and Destroy - rc = myClient->PurgeCache(); - ASSERT_TRUE(rc.IsOk()); rc = myClient->DestroyCache(); ASSERT_TRUE(rc.IsOk()); } @@ -130,10 +148,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { (void)TaskManager::GetMasterThreadRc(); TaskGroup vg; Status rc; + + session_id_type env_session; + rc = GetSessionFromEnv(&env_session); + ASSERT_TRUE(rc.IsOk()); + // use arbitrary session of 1, size 1, spilling is true CacheClient::Builder 
builder; // use arbitrary session of 1, size of 0, spilling// is true - builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true); + builder.SetSessionId(env_session).SetCacheMemSz(1).SetSpill(true); std::shared_ptr myClient; rc = builder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); @@ -199,8 +222,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { // RandomDataOp // TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { + // Clear the rc of the master thread if any + (void)TaskManager::GetMasterThreadRc(); Status rc; int32_t rank = 0; // not used + + session_id_type env_session; + rc = GetSessionFromEnv(&env_session); + ASSERT_TRUE(rc.IsOk()); + MS_LOG(INFO) << "UT test TestRandomDataCache1"; // Start with an empty execution tree auto myTree = std::make_shared(); @@ -236,8 +266,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { // CacheOp // size of 0, spilling is true CacheClient::Builder builder; - // use arbitrary session of 1, size of 0, spilling// is true - builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); + builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); std::shared_ptr myClient; rc = builder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); @@ -273,7 +302,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { ASSERT_TRUE(rc.IsOk()); MS_LOG(INFO) << "Launching tree and begin iteration"; - rc = myTree->Prepare(); + rc = myTree->Prepare(1); ASSERT_TRUE(rc.IsOk()); // quick check to see what tree looks like @@ -314,9 +343,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { //// RandomDataOp //// TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { + // Clear the rc of the master thread if any + (void)TaskManager::GetMasterThreadRc(); Status rc; int32_t rank = 0; // not used MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; + + session_id_type env_session; + rc = GetSessionFromEnv(&env_session); + ASSERT_TRUE(rc.IsOk()); + // Start with an empty execution tree auto myTree = std::make_shared(); @@ -353,8 +389,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { int64_t start_index = 0; auto seq_sampler = std::make_shared(num_samples, start_index); CacheClient::Builder builder; - // use arbitrary session of 1, size of 0, spilling// is true - builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); + builder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); std::shared_ptr myClient; rc = builder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); @@ -386,7 +421,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { ASSERT_TRUE(rc.IsOk()); MS_LOG(INFO) << "Launching tree and begin iteration"; - rc = myTree->Prepare(); + rc = myTree->Prepare(1); ASSERT_TRUE(rc.IsOk()); std::cout << *myClient << std::endl; @@ -413,14 +448,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { } TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { + // Clear the rc of the master thread if any + (void)TaskManager::GetMasterThreadRc(); Status rc; int64_t num_samples = 0; int64_t start_index = 0; + + session_id_type env_session; + rc = GetSessionFromEnv(&env_session); + ASSERT_TRUE(rc.IsOk()); + auto seq_sampler = std::make_shared(num_samples, start_index); CacheClient::Builder ccbuilder; - // use arbitrary session of 1, size of 0, spilling// is true - ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); + ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); std::shared_ptr myClient; rc = ccbuilder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); @@ 
-468,7 +509,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { rc = myCacheOp->AddChild(so); ASSERT_TRUE(rc.IsOk()); - rc = myTree->Prepare(); + rc = myTree->Prepare(1); ASSERT_TRUE(rc.IsOk()); rc = myTree->Launch(); ASSERT_TRUE(rc.IsOk()); @@ -507,10 +548,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { //// RandomDataOp //// TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { + // Clear the rc of the master thread if any + (void)TaskManager::GetMasterThreadRc(); Status rc; int32_t rank = 0; // not used MS_LOG(INFO) << "UT test TestCacheInheritSampler"; + session_id_type env_session; + rc = GetSessionFromEnv(&env_session); + ASSERT_TRUE(rc.IsOk()); + int64_t num_samples = 0; int64_t start_index = 0; auto seq_sampler = std::make_shared(num_samples, start_index); @@ -550,7 +597,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { // CacheOp CacheClient::Builder ccbuilder; // use arbitrary session of 1, size of 0, spilling// is true - ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); + ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); std::shared_ptr myClient; rc = ccbuilder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); @@ -577,7 +624,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { ASSERT_TRUE(rc.IsOk()); MS_LOG(INFO) << "Launching tree and begin iteration"; - rc = myTree->Prepare(); + rc = myTree->Prepare(1); ASSERT_TRUE(rc.IsOk()); std::cout << *myClient << std::endl; diff --git a/tests/ut/cpp/dataset/memory_pool_test.cc b/tests/ut/cpp/dataset/memory_pool_test.cc index b5907655dc7..2981a63708b 100644 --- a/tests/ut/cpp/dataset/memory_pool_test.cc +++ b/tests/ut/cpp/dataset/memory_pool_test.cc @@ -25,13 +25,13 @@ using namespace mindspore::dataset; class MindDataTestMemoryPool : public UT::Common { public: - std::shared_ptr mp_; - MindDataTestMemoryPool() {} + std::shared_ptr mp_; + MindDataTestMemoryPool() {} - void SetUp() { - Status rc = CircularPool::CreateCircularPool(&mp_, 1, 1, true); - ASSERT_TRUE(rc.IsOk()); - } + void SetUp() { + Status rc = CircularPool::CreateCircularPool(&mp_, 1, 1, true); + ASSERT_TRUE(rc.IsOk()); + } }; TEST_F(MindDataTestMemoryPool, DumpPoolInfo) { @@ -40,7 +40,7 @@ TEST_F(MindDataTestMemoryPool, DumpPoolInfo) { TEST_F(MindDataTestMemoryPool, TestOperator1) { Status rc; - int *p = new(&rc, mp_) int; + int *p = new (&rc, mp_) int; ASSERT_TRUE(rc.IsOk()); *p = 2048; ::operator delete(p, mp_); @@ -61,12 +61,11 @@ TEST_F(MindDataTestMemoryPool, TestOperator3) { TEST_F(MindDataTestMemoryPool, TestAllocator) { class A { public: - explicit A (int x) : a(x) {} - int val_a() const { - return a; - } + explicit A(int x) : a(x) {} + int val_a() const { return a; } + private: - int a; + int a; }; Allocator alloc(mp_); std::shared_ptr obj_a = std::allocate_shared(alloc, 3); @@ -74,3 +73,16 @@ TEST_F(MindDataTestMemoryPool, TestAllocator) { ASSERT_EQ(v, 3); MS_LOG(DEBUG) << *(std::dynamic_pointer_cast(mp_)) << std::endl; } + +TEST_F(MindDataTestMemoryPool, TestMemGuard) { + MemGuard mem; + // Try some large value. + int64_t sz = 5LL * 1024LL * 1024LL * 1024LL; + Status rc = mem.allocate(sz); + ASSERT_TRUE(rc.IsOk() || rc.IsOutofMemory()); + if (rc.IsOk()) { + // Try write a character half way. 
+ auto *p = mem.GetMutablePointer(); + p[sz / 2] = 'a'; + } +} diff --git a/tests/ut/python/cachetests/cachetest.sh b/tests/ut/python/cachetests/cachetest.sh new file mode 100755 index 00000000000..8db7d12c400 --- /dev/null +++ b/tests/ut/python/cachetests/cachetest.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# This script is the driver of the individual test scenarios +CURRPATH=$(cd $(dirname $0); pwd) + +echo "----------------------------------------------" +echo "Invalid syntax and cache_admin failure testing" +echo "----------------------------------------------" +echo +${CURRPATH}/cachetest_args.sh +num_failures=$? +echo +echo "Invalid syntax and cache_admin failure testing complete. Number of failures: $num_failures" +echo + +echo "----------------------------------------------" +echo "Test pipelines with cache (python)" +echo "----------------------------------------------" +echo +${CURRPATH}/cachetest_py.sh +num_failures=$? +echo +echo "Test pipelines with cache complete. Number of failures: $num_failures" +echo + +echo "----------------------------------------------" +echo "Cache cpp tests" +echo "----------------------------------------------" +echo +${CURRPATH}/cachetest_cpp.sh +num_failures=$? +echo +echo "Cache cpp tests complete. Number of failures: $num_failures" +echo diff --git a/tests/ut/python/cachetests/cachetest_args.sh b/tests/ut/python/cachetests/cachetest_args.sh new file mode 100755 index 00000000000..7f6f3082f83 --- /dev/null +++ b/tests/ut/python/cachetests/cachetest_args.sh @@ -0,0 +1,207 @@ +#!/bin/bash +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# source the globals and functions for use with cache testing +SKIP_ADMIN_COUNTER=false +. cachetest_lib.sh +echo + +################################################################################ +# Cache testing: cache_admin argument testing # +# Summary: Various tests that expect to get failure messages returned # +################################################################################ + +# Double-command test +cmd="${CACHE_ADMIN} --start --stop" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# missing command test +cmd="${CACHE_ADMIN} --port 50082" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 
0 0 + +# bad arg test +cmd="${CACHE_ADMIN} -p abc --start" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# missing arg test +cmd="${CACHE_ADMIN} -p --start" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# invalid command +cmd="${CACHE_ADMIN} -p 50082 --start --not_exist_cmd" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# spill directory does not exist +cmd="${CACHE_ADMIN} --start --spilldir /path_that_does_not_exist" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# start cache server twice +StartServer +HandleRcExit $? 1 1 +# start the cache server again, however, this time we expect an error +cmd="${CACHE_ADMIN} --start" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +StopServer +HandleRcExit $? 1 1 + +# start cache server twice with different ports +# this one starts with the default port 50052 +StartServer +HandleRcExit $? 1 1 +# this one starts with port 50053 +cmd="${CACHE_ADMIN} --start -p 50053" +CacheAdminCmd "${cmd}" 0 +HandleRcExit $? 1 1 +# stop the cache server with default port +StopServer +HandleRcExit $? 1 1 +# stop the cache server with port 50053 +cmd="${CACHE_ADMIN} --stop -p 50053" +CacheAdminCmd "${cmd}" 0 +HandleRcExit $? 1 1 + +# stop the cache server without bringing it up +cmd="${CACHE_ADMIN} --stop" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 + +# start the cache server with illegal hostname +cmd="${CACHE_ADMIN} --start -h 0.0.0.0" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +cmd="${CACHE_ADMIN} --start -h illegal" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +cmd="${CACHE_ADMIN} --start -h" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +cmd="${CACHE_ADMIN} --start -h --hostname" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +cmd="${CACHE_ADMIN} --start -h --hostname 127.0.0.1" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 + +# start the cache server with illegal port +cmd="${CACHE_ADMIN} --start -p 0" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +cmd="${CACHE_ADMIN} --start -p -1" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +cmd="${CACHE_ADMIN} --start -p 65536" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +cmd="${CACHE_ADMIN} --start -p illegal" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 +cmd="${CACHE_ADMIN} --start -p" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 1 + +# find a port that is occupied using netstat +if [ -x "$(command -v netstat)" ]; then + port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1) + if [ ${port} -gt 1025 ]; then + # start cache server with occupied port + cmd="${CACHE_ADMIN} --start -p ${port}" + CacheAdminCmd "${cmd}" 1 + HandleRcExit $? 0 1 + fi +fi + +# generate session before starting the cache server +cmd="${CACHE_ADMIN} -g" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# illegal generate session command +StartServer +HandleRcExit $? 1 1 +cmd="${CACHE_ADMIN} -g 1" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# illegal destroy session command +cmd="${CACHE_ADMIN} -d -2" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +cmd="${CACHE_ADMIN} -d illegal" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +cmd="${CACHE_ADMIN} -d" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +# destroy a non-existing session +cmd="${CACHE_ADMIN} -d 99999" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# stop cache server at this point +StopServer +HandleRcExit $? 1 1 + +# illegal number of workers +cmd="${CACHE_ADMIN} --start -w 0" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 
0 0 +cmd="${CACHE_ADMIN} --start -w -1" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +cmd="${CACHE_ADMIN} --start -w illegal" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +cmd="${CACHE_ADMIN} --start -w 101" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +cmd="${CACHE_ADMIN} --start -w 9999999" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +cmd="${CACHE_ADMIN} --start -w" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# illegal spill path +cmd="${CACHE_ADMIN} --start -s" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +# spill path without writing perm +if [ "$EUID" -ne 0 ]; then + cmd="${CACHE_ADMIN} --start -s /" + CacheAdminCmd "${cmd}" 1 + HandleRcExit $? 0 0 +fi + +# illegal log level +cmd="${CACHE_ADMIN} --start -l 4" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +cmd="${CACHE_ADMIN} --start -l -1" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 +cmd="${CACHE_ADMIN} --start -l" +CacheAdminCmd "${cmd}" 1 +HandleRcExit $? 0 0 + +exit ${failed_tests} diff --git a/tests/ut/python/cachetests/cachetest_cpp.sh b/tests/ut/python/cachetests/cachetest_cpp.sh new file mode 100755 index 00000000000..875ebbd5b55 --- /dev/null +++ b/tests/ut/python/cachetests/cachetest_cpp.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# source the globals and functions for use with cache testing +SKIP_ADMIN_COUNTER=true +. cachetest_lib.sh +echo + +################################################################################ +# Cache testing: cache cpp test driver # +# Summary: A launcher for invoking cpp cache tests # +################################################################################ + +UT_TEST_DIR="${BUILD_PATH}/mindspore/tests/ut/cpp" +DateStamp=$(date +%Y%m%d_%H%M%S); +CPP_TEST_LOG_OUTPUT="/tmp/ut_tests_cache_${DateStamp}.log" + +# Start a basic cache server to be used for all tests +StartServer +HandleRcExit $? 1 1 + +# Set the environment variable to enable these pytests +export RUN_CACHE_TEST=TRUE +GTEST_FILTER_OLD=$GTEST_FILTER +export GTEST_FILTER="MindDataTestCacheOp.*" +export GTEST_ALSO_RUN_DISABLED_TESTS=1 + +# All of the cpp tests run under the same session +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +test_count=$(($test_count+1)) +cd ${UT_TEST_DIR} +cmd="${UT_TEST_DIR}/ut_tests" +echo "Test ${test_count}: ${cmd}" +MsgEnter "Run test ${test_count}" +${cmd} > ${CPP_TEST_LOG_OUTPUT} 2>&1 +rc=$? +if [ ${rc} -ne 0 ]; then + MsgFail "FAILED" + MsgError "Invoking cpp tests failed!" "${rc}" "See log: ${CPP_TEST_LOG_OUTPUT}" +else + MsgOk "OK" +fi +echo +HandleRcExit $rc 1 0 + +cd ${CURRPATH} + +StopServer +HandleRcExit $? 
1 0 + +# restore old env var +export GTEST_FILTER=$GTEST_FILTER_OLD +unset RUN_CACHE_TEST +unset GTEST_ALSO_RUN_DISABLED_TESTS + +exit ${failed_tests} diff --git a/tests/ut/python/cachetests/cachetest_lib.sh b/tests/ut/python/cachetests/cachetest_lib.sh new file mode 100755 index 00000000000..e9bfcfc7756 --- /dev/null +++ b/tests/ut/python/cachetests/cachetest_lib.sh @@ -0,0 +1,336 @@ +#!/bin/bash +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# This file is a collection of functions and globals that are common to the +# test scenarios for cache op testing. + +# Set any path variables here +CURRPATH=$(cd $(dirname $0); pwd) +TESTPATH=$(cd ${CURRPATH}/../dataset; pwd) +PROJECT_PATH=$(cd ${CURRPATH}/../../../..; pwd) + +if [ "x${BUILD_PATH}" == "x" ]; then + BUILD_PATH=${PROJECT_PATH}/build +fi +echo "Using test path: ${TESTPATH}" +echo "Using build path: ${BUILD_PATH}" + +# Point to the cache_admin from the build path. The user may also have installed the wheel file but we don't know that. +CACHE_ADMIN="${BUILD_PATH}/package/mindspore/bin/cache_admin" +PYTHON_PYTEST="python -m pytest ${TESTPATH}/" + +# These are globals that all testcases use and may get updated during testcase running +failed_tests=0 +test_count=0 +session_id=0 + +# sanity check on the cache_admin + +if [ ! -f ${CACHE_ADMIN} ]; then + echo "Could not find cache_admin binary. ${CACHE_ADMIN}" + exit 1 +fi + +################################################################################# +# Function: MsgEnter # +# Description: Display the leading text before entering a block of logic. # +################################################################################# +MsgEnter() +{ + printf "%-60s : " "${1}" +} + +################################################################################# +# Function: MsgOK # +# Description: Display input msg with a green format for success # +################################################################################# +MsgOk() +{ + echo -e '\E[32m'"\033[1m$1\033[0m" +} + +################################################################################# +# Function: MsgFail # +# Description: Display intput msg with a red format for a failure # +################################################################################# +MsgFail() +{ + echo -e '\E[31m'"\033[1m$1\033[0m" +} + +################################################################################# +# Function: MsgError # +# Description: If something is not successful, display some info about it # +# # +# Arguments are optional with defaults. You should pass empty string for any # +# args not being used so that it chooses the defaults. # +# # +# Optional arguments: arg 1: An error message. # +# arg 2: The return code. 
# +# arg 3: The error details # +# # +################################################################################# +MsgError() +{ + msg=${1:-none} + err_rc=${2:-none} + err_detail=${3:-none} + + if [ "${msg}" != "none" ] ; then + echo "${msg}" + fi + + if [ "${err_rc}" != "none" ] ; then + echo "Return code: ${err_rc}" + fi + + if [ "${err_detail}" != "none" ] ; then + echo "Error detail:" + echo "{$err_detail}" + fi + echo +} + +################################################################################# +# Function: ServerCleanup # +# Description: This is a non-code method to clean up a running cache server. # +# The intended use is for cases when some command has failed, we # +# want to check for any stale process or resources and forcefully # +# remove those resources. # +################################################################################# +ServerCleanup() +{ + echo "ServerCleanup is running" + server_pid=$(ps -elf | grep ${USER} | grep cache_server | grep -v grep | awk '{print $4}') + if [ "x${server_pid}" != "x" ]; then + echo "Found a running cache server pid ${server_pid}. Killing this process" + kill -9 ${server_pid} + # small sleep to allow some cleanup time + sleep 2 + fi + + for i in `ipcs -m | grep ${USER} | awk '{print $2}'` + do + ipcrm -m ${i} + done + + echo "ServerCleanup complete." +} + +################################################################################# +# Function: CacheAdminCmd # +# Description: Wrapper function for executing cache_admin commands # +# Caller must use HandleRcExit to process the return code. # +# # +# Arguments: arg 1: The command to run # +# arg 2: value of 0 means that we check rc for success. a value # +# of 1 means that we expect a failure from the command. # +################################################################################# +CacheAdminCmd() +{ + if [ $# -ne 2 ]; then + echo "Test script invalid. Bad CacheAdminCmd function args." + exit 1 + fi + cmd=$1 + expect_fail=$2 + if [ "${SKIP_ADMIN_COUNTER}" != "true" ]; then + test_count=$(($test_count+1)) + echo "Test ${test_count}: ${cmd}" + MsgEnter "Run test ${test_count}" + fi + result=$(${cmd} 2>&1) + rc=$? + if [ ${expect_fail} -eq 0 -a ${rc} -ne 0 ]; then + MsgFail "FAILED" + MsgError "cache_admin command failure!" "${rc}" "${result}" + return 1 + elif [ ${expect_fail} -eq 1 -a ${rc} -eq 0 ]; then + MsgFail "FAILED" + MsgError "Expected failure but got success!" "${rc}" "${result}" + return 1 + else + if [ "${SKIP_ADMIN_COUNTER}" != "true" ]; then + MsgOk "OK" + fi + fi + echo + return 0 +} + +################################################################################# +# Function: PytestCmd # +# Description: Wrapper function for executing pytest # +# Caller must use HandleRcExit to process the return code. # +# # +# Arguments: arg 1: The python script name # +# arg 2: The python function name # +################################################################################# +PytestCmd() +{ + test_count=$(($test_count+1)) + py_script=$1 + py_func=$2 + pattern=${3:-0} + # python scripts require special relative paths + cd .. + if [ ${pattern} -eq 0 ]; then + cmd="${PYTHON_PYTEST}${py_script}::${py_func}" + elif [ ${pattern} -eq 1 ]; then + cmd="${PYTHON_PYTEST}${py_script} -k ${py_func}" + else + echo "Invalid Pytest command test script error" + exit 1 + fi + echo "Test ${test_count}: ${cmd}" + MsgEnter "Run test ${test_count}" + result=$(${cmd} 2>&1) + rc=$? 
+ if [ ${rc} -ne 0 ]; then + MsgFail "FAILED" + MsgError "pytest call had failure!" "${rc}" "${result}" + cd ${CURRPATH} + return 1 + else + MsgOk "OK" + fi + echo + cd ${CURRPATH} + return 0 +} + +################################################################################# +# Function: StartServer # +# Description: Helper function to call cache_admin to start a default server # +# Caller must use HandleRcExit to process the return code. # +################################################################################# +StartServer() +{ + cmd="${CACHE_ADMIN} --start" + CacheAdminCmd "${cmd}" 0 + sleep 1 + return $? +} + +################################################################################# +# Function: StopServer # +# Description: Helper function to call cache_admin to stop cache server # +# Caller must use HandleRcExit to process the return code. # +################################################################################# +StopServer() +{ + cmd="${CACHE_ADMIN} --stop" + CacheAdminCmd "${cmd}" 0 + return $? +} + +################################################################################# +# Function: GetSession # +# Description: Helper function to call cache_admin to generate a session # +# Caller must use HandleRcExit to process the return code. # +################################################################################# +GetSession() +{ + # Cannot use CacheAdminCmd for this one because we have special action to set + # the global variable for session id. + cmd="${CACHE_ADMIN} --generate_session" + if [ "${SKIP_ADMIN_COUNTER}" != "true" ]; then + test_count=$(($test_count+1)) + echo "Test ${test_count}: ${cmd}" + MsgEnter "Run test ${test_count}" + fi + result=$(${cmd} 2>&1) + rc=$? + if [ ${rc} -ne 0 ]; then + MsgFail "FAILED" + MsgError "cache_admin command failure!" "${rc}" "${result}" + return 1 + else + session_id=$(echo $result | awk '{print $NF}') + if [ "${SKIP_ADMIN_COUNTER}" != "true" ]; then + MsgOk "OK" + echo "Generated session id: ${session_id}" + echo + fi + fi + return 0 +} + +################################################################################# +# Function: DestroySession # +# Description: Helper function to call cache_admin to destroy a session # +# Caller must use HandleRcExit to process the return code. # +################################################################################# +DestroySession() +{ + cmd="${CACHE_ADMIN} --destroy_session ${session_id}" + CacheAdminCmd "${cmd}" 0 + return $? +} + +################################################################################# +# Function: HandlerRcExit # +# Description: handles a return code if you used one of the above helper funcs # +# It updates the global test counters and chooses to quit or not # +# depending on the setting of exit_on_fail argument # +# # +# Arguments: arg 1: The rc to handle # +# arg 2: Set to 1 to cause error exit. 0 for no exit # +# arg 3: Set to 1 to invoke server cleanup on error case # +################################################################################# +HandleRcExit() +{ + if [ $# -ne 3 ]; then + echo "Test script invalid. Bad CacheAdminCmd function args." 
+ exit 1 + fi + + err_rc=$1 + exit_on_fail=$2 + clean_on_fail=$3 + + if [ ${err_rc} -ne 0 ]; then + failed_tests=$(($failed_tests+1)) + + if [ ${clean_on_fail} -eq 1 ]; then + ServerCleanup + fi + + if [ ${exit_on_fail} -eq 1 ]; then + exit $failed_tests + else + return 1 + fi + fi + + return 0 +} + +################################################################################# +# Function: ExitHandler # +# Description: Invokes final display message of the script before quitting # +################################################################################# +ExitHandler() +{ + success_count=$(($test_count-$failed_tests)) + echo "------------------------------------" + echo "${test_count} tests run in total." + echo "${success_count} tests ran successfully." + echo "${failed_tests} failed tests." + exit ${failed_tests} +} + +trap ExitHandler EXIT SIGINT diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh new file mode 100755 index 00000000000..1094e104866 --- /dev/null +++ b/tests/ut/python/cachetests/cachetest_py.sh @@ -0,0 +1,378 @@ +#!/bin/bash +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# source the globals and functions for use with cache testing +SKIP_ADMIN_COUNTER=true +. cachetest_lib.sh +echo + +################################################################################ +# Cache testing: cache python test driver # +# Summary: Various tests for running the python testcases for caching # +################################################################################ + +StartServer +HandleRcExit $? 1 1 + +# Set the environment variable to enable these pytests +export RUN_CACHE_TEST=TRUE + +# Each of these tests will create session, use it, then destroy it after the test +for i in $(seq 1 6) +do + test_name="test_cache_map_basic${i}" + GetSession + HandleRcExit $? 1 1 + export SESSION_ID=$session_id + + PytestCmd "test_cache_map.py" "${test_name}" + HandleRcExit $? 0 0 + + DestroySession $session_id + HandleRcExit $? 1 1 +done + +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +# use pytest pattern match to run all the tests that match the name test_cache_map_failure. +# All of these tests will interact with the same cache session and may result in multiple +# caches under the common session handle (although these are failure tests so probably not) +PytestCmd "test_cache_map.py" "test_cache_map_failure" 1 +HandleRcExit $? 0 0 + +# DatasetCache parameter check +PytestCmd "test_cache_map.py" "test_cache_map_parameter_check" +HandleRcExit $? 0 0 + +# Executing the same pipeline for twice under the same session +# Executing the same pipeline for twice (from python) +PytestCmd "test_cache_map.py" "test_cache_map_running_twice1" +HandleRcExit $? 0 0 +# Executing the same pipeline for twice (from shell) +PytestCmd "test_cache_map.py" "test_cache_map_running_twice2" +HandleRcExit $? 
0 0 +PytestCmd "test_cache_map.py" "test_cache_map_running_twice2" +HandleRcExit $? 0 0 + +# Executing the same pipeline for twice under the different session +# Executing the same pipeline for twice (from shell) +PytestCmd "test_cache_map.py" "test_cache_map_running_twice2" +HandleRcExit $? 0 0 +DestroySession $session_id +HandleRcExit $? 1 1 +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id +PytestCmd "test_cache_map.py" "test_cache_map_running_twice2" +HandleRcExit $? 0 0 + +# Set size parameter of DatasetCache to a extra small value +PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_no_image" +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_parallel_workers" +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_num_connections" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_prefetch_size" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_to_device" +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1 +HandleRcExit $? 0 0 + +# Run two parallel pipelines (sharing cache) +for i in $(seq 1 2) +do + test_name="test_cache_map_parallel_pipeline${i}" + GetSession + HandleRcExit $? 1 1 + export SESSION_ID=$session_id + + PytestCmd "test_cache_map.py" "${test_name} --shard 0" & + pids+=("$!") + PytestCmd "test_cache_map.py" "${test_name} --shard 1" & + pids+=("$!") + + for pid in "${pids[@]}"; do + wait ${pid} + HandleRcExit $? 0 0 + done + + # Running those PytestCmd in the background will not get our test_count updated. So we need to manually update it here. + test_count=$(($test_count+1)) + DestroySession $session_id + HandleRcExit $? 1 1 +done + +StopServer +HandleRcExit $? 1 1 +sleep 1 + +# test cache server with --workers 1 +cmd="${CACHE_ADMIN} --start --workers 1" +CacheAdminCmd "${cmd}" 0 +sleep 1 +HandleRcExit $? 0 0 + +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +PytestCmd "test_cache_map.py" "test_cache_map_server_workers_1" +HandleRcExit $? 0 0 +StopServer +HandleRcExit $? 0 1 + +# test cache server with --workers 100 +cmd="${CACHE_ADMIN} --start --workers 100" +CacheAdminCmd "${cmd}" 0 +sleep 1 +HandleRcExit $? 0 0 + +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +PytestCmd "test_cache_map.py" "test_cache_map_server_workers_100" +HandleRcExit $? 0 0 +StopServer +HandleRcExit $? 0 1 + +# The next set of testing is for the non-mappable cases. +StartServer +HandleRcExit $? 1 1 + +# This runs all of the basic tests. These will all share the same and we do not destroy +# the session in between each. +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_basic" 1 +HandleRcExit $? 0 0 + +DestroySession $session_id +HandleRcExit $? 1 1 + +# run the small shared cache tests +for i in $(seq 1 4) +do + test_name="test_cache_nomap_allowed_share${i}" + GetSession + HandleRcExit $? 1 1 + export SESSION_ID=$session_id + + PytestCmd "test_cache_nomap.py" "${test_name}" + HandleRcExit $? 0 0 + + DestroySession $session_id + HandleRcExit $? 1 1 +done + +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_disallowed_share" 1 +HandleRcExit $? 0 0 + +DestroySession $session_id +HandleRcExit $? 1 1 + +GetSession +HandleRcExit $? 
1 1 +export SESSION_ID=$session_id + +# Executing the same pipeline for twice under the same session +# Executing the same pipeline for twice (from python) +PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice1" +HandleRcExit $? 0 0 +# Executing the same pipeline for twice (from shell) +PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice2" +HandleRcExit $? 0 0 +PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice2" +HandleRcExit $? 0 0 + +# Executing the same pipeline for twice under the different session +# Executing the same pipeline for twice (from shell) +PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice2" +HandleRcExit $? 0 0 +DestroySession $session_id +HandleRcExit $? 1 1 +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id +PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice2" +HandleRcExit $? 0 0 + +# Set size parameter of DatasetCache to a extra small value +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id +PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1 +HandleRcExit $? 0 0 +DestroySession $session_id +HandleRcExit $? 1 1 + +# Run two parallel pipelines (sharing cache) +for i in $(seq 1 2) +do + test_name="test_cache_nomap_parallel_pipeline${i}" + GetSession + HandleRcExit $? 1 1 + export SESSION_ID=$session_id + + PytestCmd "test_cache_nomap.py" "${test_name} --shard 0" & + pids+=("$!") + PytestCmd "test_cache_nomap.py" "${test_name} --shard 1" & + pids+=("$!") + PytestCmd "test_cache_nomap.py" "${test_name} --shard 2" & + pids+=("$!") + + for pid in "${pids[@]}"; do + wait ${pid} + HandleRcExit $? 0 0 + done + + # Running those PytestCmd in the background will not get our test_count updated. So we need to manually update it here. + test_count=$(($test_count+1)) + DestroySession $session_id + HandleRcExit $? 1 1 +done + +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_parallel_workers" +HandleRcExit $? 0 0 + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_num_connections" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_prefetch_size" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_to_device" +HandleRcExit $? 0 0 + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1 +HandleRcExit $? 0 0 + +for i in $(seq 1 3) +do + test_name="test_cache_nomap_multiple_cache${i}" + GetSession + HandleRcExit $? 1 1 + export SESSION_ID=$session_id + + PytestCmd "test_cache_nomap.py" "${test_name}" + HandleRcExit $? 0 0 + + DestroySession $session_id + HandleRcExit $? 1 1 +done + +# Create session, run train and eval pipeline concurrently with different cache +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id +PytestCmd "test_cache_nomap.py" "test_cache_nomap_multiple_cache_train" & +pids+=("$!") +PytestCmd "test_cache_nomap.py" "test_cache_nomap_multiple_cache_eval" & +pids+=("$!") + +for pid in "${pids[@]}"; do + wait ${pid} + HandleRcExit $? 0 0 +done + +# Running those PytestCmd in the background will not get our test_count updated. So we need to manually update it here. +test_count=$(($test_count+1)) +DestroySession $session_id +HandleRcExit $? 1 1 + +# Create session, use it to run a pipeline, and destroy the session while pipeline is running +GetSession +HandleRcExit $? 
1 1 +export SESSION_ID=$session_id + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_session_destroy" & +pid=("$!") + +sleep 10 +DestroySession $session_id +HandleRcExit $? 1 1 +wait ${pid} +# Running those PytestCmd in the background will not get our test_count updated. So we need to manually update it here. +test_count=$(($test_count+1)) + +# Stop cache server while pipeline is running +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +PytestCmd "test_cache_nomap.py" "test_cache_nomap_server_stop" & +pid=("$!") + +sleep 10 +StopServer +HandleRcExit $? 1 1 +sleep 1 +wait ${pid} +# Running those PytestCmd in the background will not get our test_count updated. So we need to manually update it here. +test_count=$(($test_count+1)) + +# test cache server with --workers 1 +cmd="${CACHE_ADMIN} --start --workers 1" +CacheAdminCmd "${cmd}" 0 +sleep 1 +HandleRcExit $? 0 0 +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id +PytestCmd "test_cache_nomap.py" "test_cache_nomap_server_workers_1" +HandleRcExit $? 0 0 +StopServer +HandleRcExit $? 0 1 + +# test cache server with --workers 100 +cmd="${CACHE_ADMIN} --start --workers 100" +CacheAdminCmd "${cmd}" 0 +sleep 1 +HandleRcExit $? 0 0 +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id +PytestCmd "test_cache_nomap.py" "test_cache_nomap_server_workers_100" +HandleRcExit $? 0 0 +StopServer +HandleRcExit $? 0 1 + +unset RUN_CACHE_TEST +unset SESSION_ID + +exit ${failed_tests} diff --git a/tests/ut/python/conftest.py b/tests/ut/python/conftest.py index 6850f956a5b..87e5d5692ac 100644 --- a/tests/ut/python/conftest.py +++ b/tests/ut/python/conftest.py @@ -29,6 +29,10 @@ def pytest_addoption(parser): "--runmode", action="store", default="nosimu", help="simu:simulator backend & nosimu for no backend" ) + parser.addoption( + "--shard", action="store", default="0", + help="shard id for parallel pipeline" + ) @pytest.fixture @@ -39,6 +43,14 @@ def test_with_simu(request): return request.config.getoption("--runmode") == "simu" +@pytest.fixture +def shard(request): + """ + specify shard id for parallel pipeline testcases + """ + return request.config.getoption("--shard") + + # https://stackoverflow.com/questions/14121657/how-to-get-test-name-and-test-result-during-run-time-in-pytest def pytest_runtest_protocol(item, nextitem): reports = runtestprotocol(item, nextitem=nextitem) diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 7640fd8bf04..e703b7d3099 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -23,6 +23,9 @@ from mindspore import log as logger from util import save_and_check_md5 DATA_DIR = "../data/dataset/testImageNetData/train/" +COCO_DATA_DIR = "../data/dataset/testCOCO/train/" +COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" +NO_IMAGE_DIR = "../data/dataset/testRandomData/" GENERATE_GOLDEN = False @@ -42,8 +45,12 @@ def test_cache_map_basic1(): """ logger.info("Test cache map basic 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + session_id = 1 - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -56,6 +63,7 @@ def test_cache_map_basic1(): logger.info("test_cache_map_basic1 Ended.\n") + 
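
For reference, a minimal sketch of how the SESSION_ID-gated pytest cases in this file are expected to be driven, following the StartServer/GetSession/PytestCmd flow in cachetest_lib.sh and cachetest_py.sh above. The cache_admin location and the working directory are illustrative assumptions, not taken verbatim from this patch:

  # Assumed flow; mirrors the helper functions in cachetest_lib.sh
  cache_admin --start
  session_id=$(cache_admin --generate_session | awk '{print $NF}')
  export RUN_CACHE_TEST=TRUE     # enables the @pytest.mark.skipif-guarded cases
  export MS_ENABLE_CACHE=TRUE    # lets DatasetCache build a real CacheClient (see cache_client.py)
  export SESSION_ID=$session_id  # consumed by the testcases in this file
  python -m pytest test_cache_map.py::test_cache_map_basic1
  cache_admin --destroy_session $session_id
  cache_admin --stop
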
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_basic2(): """ @@ -71,8 +79,12 @@ def test_cache_map_basic2(): """ logger.info("Test cache map basic 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -85,6 +97,7 @@ def test_cache_map_basic2(): logger.info("test_cache_map_basic2 Ended.\n") + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_basic3(): """ @@ -100,8 +113,12 @@ def test_cache_map_basic3(): """ logger.info("Test cache basic 3") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -119,13 +136,18 @@ def test_cache_map_basic3(): assert num_iter == 8 logger.info('test_cache_basic3 Ended.\n') + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_basic4(): """ Test different rows result in core dump """ logger.info("Test cache basic 4") - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -142,7 +164,76 @@ def test_cache_map_basic4(): logger.info("Number of data in ds1: {} ".format(num_iter)) assert num_iter == 8 - logger.info('test_cache_basic3 Ended.\n') + logger.info('test_cache_basic4 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_basic5(): + """ + Test Map with non-deterministic TensorOps above cache + + repeat + | + Map(decode, randomCrop) + | + Cache + | + ImageFolder + + """ + logger.info("Test cache failure 5") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) + decode_op = c_vision.Decode() + + data = data.map(input_columns=["image"], operations=decode_op) + data = data.map(input_columns=["image"], operations=random_crop_op) + data = data.repeat(4) + + num_iter = 0 + for _ in data.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info('test_cache_failure5 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up 
cache server") +def test_cache_map_basic6(): + """ + Test cache as root node + + cache + | + ImageFolder + """ + logger.info("Test cache basic 6") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + logger.info("get data from dataset") + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 2 + logger.info('test_cache_basic6 Ended.\n') + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_failure1(): @@ -161,8 +252,12 @@ def test_cache_map_failure1(): """ logger.info("Test cache failure 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -170,26 +265,983 @@ def test_cache_map_failure1(): ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) ds1 = ds1.repeat(4) - try: + with pytest.raises(RuntimeError) as e: num_iter = 0 for _ in ds1.create_dict_iterator(num_epochs=1): num_iter += 1 - except RuntimeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "Nested cache operations is not supported!" in str(e) + assert "Nested cache operations is not supported!" 
in str(e.value) assert num_iter == 0 logger.info('test_cache_failure1 Ended.\n') +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_failure2(): + """ + Test zip under cache (failure) + + repeat + | + Cache + | + Map(decode) + | + Zip + | | + ImageFolder ImageFolder + + """ + logger.info("Test cache failure 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + dsz = ds.zip((ds1, ds2)) + decode_op = c_vision.Decode() + dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache) + dsz = dsz.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in dsz.create_dict_iterator(): + num_iter += 1 + assert "ZipOp is currently not supported as a descendant operator under a cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_failure2 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_failure3(): + """ + Test batch under cache (failure) + + repeat + | + Cache + | + Map(resize) + | + Batch + | + ImageFolder + """ + logger.info("Test cache failure 3") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + ds1 = ds1.batch(2) + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) + ds1 = ds1.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + assert "Unexpected error. 
Expect positive row id: -1" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_failure3 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_failure4(): + """ + Test filter under cache (failure) + + repeat + | + Cache + | + Map(decode) + | + Filter + | + ImageFolder + + """ + logger.info("Test cache failure 4") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"]) + + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + assert "FilterOp is currently not supported as a descendant operator under a cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_failure4 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_failure5(): + """ + Test Map with non-deterministic TensorOps under cache (failure) + + repeat + | + Cache + | + Map(decode, randomCrop) + | + ImageFolder + + """ + logger.info("Test cache failure 5") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + data = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) + decode_op = c_vision.Decode() + + data = data.map(input_columns=["image"], operations=decode_op) + data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache) + data = data.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in data.create_dict_iterator(): + num_iter += 1 + assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_failure5 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_failure6(): + """ + Test no-cache-supporting leaf ops with Map under cache (failure) + + repeat + | + Cache + | + Map(resize) + | + Coco + + """ + logger.info("Test cache failure 6") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + data = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) + resize_op = c_vision.Resize((224, 224)) + + data = data.map(input_columns=["image"], operations=resize_op, cache=some_cache) + data = data.repeat(4) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in data.create_dict_iterator(): + num_iter += 1 + assert "There is currently no support for CocoOp under cache" in str(e.value) + + assert num_iter == 0 + 
logger.info('test_cache_failure6 Ended.\n') + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_parameter_check(): + """ + Test illegal parameters for DatasetCache + """ + + logger.info("Test cache map parameter check") + + with pytest.raises(ValueError) as info: + ds.DatasetCache(session_id=-1, size=0, spilling=True) + assert "Input is not within the required interval" in str(info.value) + + with pytest.raises(TypeError) as info: + ds.DatasetCache(session_id="1", size=0, spilling=True) + assert "Argument session_id with value 1 is not of type (,)" in str(info.value) + + with pytest.raises(TypeError) as info: + ds.DatasetCache(session_id=None, size=0, spilling=True) + assert "Argument session_id with value None is not of type (,)" in str(info.value) + + with pytest.raises(ValueError) as info: + ds.DatasetCache(session_id=1, size=-1, spilling=True) + assert "Input is not within the required interval" in str(info.value) + + with pytest.raises(TypeError) as info: + ds.DatasetCache(session_id=1, size="1", spilling=True) + assert "Argument size with value 1 is not of type (,)" in str(info.value) + + with pytest.raises(TypeError) as info: + ds.DatasetCache(session_id=1, size=None, spilling=True) + assert "Argument size with value None is not of type (,)" in str(info.value) + + with pytest.raises(TypeError) as info: + ds.DatasetCache(session_id=1, size=0, spilling="illegal") + assert "Argument spilling with value illegal is not of type (,)" in str(info.value) + + with pytest.raises(RuntimeError) as err: + ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal") + assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value) + + with pytest.raises(RuntimeError) as err: + ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2") + assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value) + + with pytest.raises(TypeError) as info: + ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal") + assert "incompatible constructor arguments" in str(info.value) + + with pytest.raises(TypeError) as info: + ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052") + assert "incompatible constructor arguments" in str(info.value) + + with pytest.raises(RuntimeError) as err: + ds.DatasetCache(session_id=1, size=0, spilling=True, port=0) + assert "Unexpected error. port must be positive" in str(err.value) + + with pytest.raises(RuntimeError) as err: + ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536) + assert "Unexpected error. 
illegal port number" in str(err.value) + + with pytest.raises(TypeError) as err: + ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True) + assert "Argument cache with value True is not of type" in str(err.value) + + logger.info("test_cache_map_parameter_check Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_running_twice1(): + """ + Executing the same pipeline for twice (from python), with cache injected after map + + Repeat + | + Cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map running twice 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + + logger.info("test_cache_map_running_twice1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_running_twice2(): + """ + Executing the same pipeline for twice (from shell), with cache injected after leaf + + Repeat + | + Map(decode) + | + Cache + | + ImageFolder + """ + + logger.info("Test cache map running twice 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_running_twice2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_extra_small_size1(): + """ + Test running pipeline with cache of extra small size and spilling true + + Repeat + | + Map(decode) + | + Cache + | + ImageFolder + """ + + logger.info("Test cache map extra small size 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_extra_small_size1 Ended.\n") + + 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_extra_small_size2(): + """ + Test running pipeline with cache of extra small size and spilling false + + Repeat + | + Cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map extra small size 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_extra_small_size2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_no_image(): + """ + Test cache with no dataset existing in the path + + Repeat + | + Map(decode) + | + Cache + | + ImageFolder + """ + + logger.info("Test cache map no image") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=NO_IMAGE_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + with pytest.raises(RuntimeError): + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + assert num_iter == 0 + logger.info("test_cache_map_no_image Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_parallel_pipeline1(shard): + """ + Test running two parallel pipelines (sharing cache) with cache injected after leaf op + + Repeat + | + Map(decode) + | + Cache + | + ImageFolder + """ + + logger.info("Test cache map parallel pipeline 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 4 + logger.info("test_cache_map_parallel_pipeline1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_parallel_pipeline2(shard): + """ + Test running two parallel pipelines (sharing cache) with cache injected after map op + + Repeat + | + Cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map parallel pipeline 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + 
raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard)) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 4 + logger.info("test_cache_map_parallel_pipeline2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_parallel_workers(): + """ + Test cache with num_parallel_workers > 1 set for map op and leaf op + + Repeat + | + cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map parallel workers") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_parallel_workers Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_server_workers_1(): + """ + start cache server with --workers 1 and then test cache function + + Repeat + | + cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map server workers 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_server_workers_1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_server_workers_100(): + """ + start cache server with --workers 100 and then test cache function + + Repeat + | + Map(decode) + | + cache + | + ImageFolder + """ + + logger.info("Test cache map server workers 100") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = 
ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_server_workers_100 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_num_connections_1(): + """ + Test setting num_connections=1 in DatasetCache + + Repeat + | + cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map num_connections 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_num_connections_1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_num_connections_100(): + """ + Test setting num_connections=100 in DatasetCache + + Repeat + | + Map(decode) + | + cache + | + ImageFolder + """ + + logger.info("Test cache map num_connections 100") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_num_connections_100 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_prefetch_size_1(): + """ + Test setting prefetch_size=1 in DatasetCache + + Repeat + | + cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map prefetch_size 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_prefetch_size_1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_prefetch_size_100(): + """ + Test setting prefetch_size=100 in DatasetCache + + Repeat + | + 
Map(decode) + | + cache + | + ImageFolder + """ + + logger.info("Test cache map prefetch_size 100") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_prefetch_size_100 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_to_device(): + """ + Test cache with to_device + + DeviceQueue + | + EpochCtrl + | + Repeat + | + Map(decode) + | + cache + | + ImageFolder + """ + + logger.info("Test cache map to_device") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + ds1 = ds1.to_device() + ds1.send() + + logger.info("test_cache_map_to_device Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_epoch_ctrl1(): + """ + Test using two-loops method to run several epochs + + Map(decode) + | + cache + | + ImageFolder + """ + + logger.info("Test cache map epoch ctrl1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + + num_epoch = 5 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _ in iter1: + row_count += 1 + logger.info("Number of data in ds1: {} ".format(row_count)) + assert row_count == 2 + epoch_count += 1 + assert epoch_count == num_epoch + logger.info("test_cache_map_epoch_ctrl1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_epoch_ctrl2(): + """ + Test using two-loops method with infinite epochs + + cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map epoch ctrl2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, 
cache=some_cache) + + num_epoch = 5 + # iter1 will always assume there is a next epoch and never shutdown + iter1 = ds1.create_dict_iterator() + + epoch_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _ in iter1: + row_count += 1 + logger.info("Number of data in ds1: {} ".format(row_count)) + assert row_count == 2 + epoch_count += 1 + assert epoch_count == num_epoch + + # manually stop the iterator + iter1.stop() + logger.info("test_cache_map_epoch_ctrl2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_epoch_ctrl3(): + """ + Test using two-loops method with infinite epochs over repeat + + repeat + | + Map(decode) + | + cache + | + ImageFolder + """ + + logger.info("Test cache map epoch ctrl3") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(2) + + num_epoch = 5 + # iter1 will always assume there is a next epoch and never shutdown + iter1 = ds1.create_dict_iterator() + + epoch_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _ in iter1: + row_count += 1 + logger.info("Number of data in ds1: {} ".format(row_count)) + assert row_count == 4 + epoch_count += 1 + assert epoch_count == num_epoch + + # reply on garbage collector to destroy iter1 + + logger.info("test_cache_map_epoch_ctrl3 Ended.\n") + + if __name__ == '__main__': test_cache_map_basic1() - logger.info("test_cache_map_basic1 success.") test_cache_map_basic2() - logger.info("test_cache_map_basic2 success.") test_cache_map_basic3() - logger.info("test_cache_map_basic3 success.") test_cache_map_basic4() - logger.info("test_cache_map_basic3 success.") test_cache_map_failure1() - logger.info("test_cache_map_failure1 success.") + test_cache_map_failure2() + test_cache_map_failure3() + test_cache_map_failure4() diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index df268594441..52b3424d651 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -16,6 +16,7 @@ Testing cache operator with non-mappable datasets """ import os +import itertools import pytest import mindspore.common.dtype as mstype import mindspore.dataset as ds @@ -25,8 +26,20 @@ from mindspore import log as logger DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" +DATA_DIR2 = ["../data/dataset/testTextTFRecord/text.tfrecord"] +SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json" + +DATA_DIR3 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] +SCHEMA_DIR3 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" + +DATA_DIR4 = "../data/dataset/testImageNetData/train/" + GENERATE_GOLDEN = False + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def 
test_cache_nomap_basic1(): """ @@ -34,6 +47,10 @@ def test_cache_nomap_basic1(): """ logger.info("Test cache nomap basic 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") schema = ds.Schema() schema.add_column('image', de_type=mstype.uint8, @@ -41,7 +58,7 @@ def test_cache_nomap_basic1(): schema.add_column('label', de_type=mstype.uint8, shape=[1]) # create a cache. arbitrary session_id for now - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # User-created sampler here ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache) @@ -64,6 +81,10 @@ def test_cache_nomap_basic2(): """ logger.info("Test cache nomap basic 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") schema = ds.Schema() schema.add_column('image', de_type=mstype.uint8, @@ -71,7 +92,7 @@ def test_cache_nomap_basic2(): schema.add_column('label', de_type=mstype.uint8, shape=[1]) # create a cache. arbitrary session_id for now - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # sampler arg not given directly, however any of these args will auto-generate an appropriate sampler: # num_samples, shuffle, num_shards, shard_id @@ -104,8 +125,12 @@ def test_cache_nomap_basic3(): """ logger.info("Test cache nomap basic 3") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) decode_op = c_vision.Decode() ds1 = ds1.map(operations=decode_op, input_columns=["image"]) @@ -148,9 +173,13 @@ def test_cache_nomap_basic4(): """ logger.info("Test cache nomap basic 4") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache # in the picture. This causes a shuffle-injection over the TF. For clarify, this test will # explicitly give the global option, even though it's the default in python. 
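The shuffle note above is the one subtle option in these leaf-cache pipelines: when shuffle is not set and no cache is attached, a global shuffle is injected over the TFRecord leaf, so the cached tests spell out shuffle=ds.Shuffle.GLOBAL to keep the intended behaviour explicit. A small sketch of that explicit form, assuming the module-level ds, DATA_DIR and SCHEMA_DIR of this file and an exported SESSION_ID:

import os

session_id = int(os.environ["SESSION_ID"])
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

# Ask for the global shuffle explicitly (it is also the python default) and
# attach the cache directly to the TFRecord leaf; this file has 3 records.
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"],
                         shuffle=ds.Shuffle.GLOBAL, cache=some_cache)

assert sum(1 for _ in ds1.create_dict_iterator(num_epochs=1)) == 3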
@@ -192,9 +221,13 @@ def test_cache_nomap_basic5(): """ logger.info("Test cache nomap basic 5") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache) decode_op = c_vision.Decode() ds1 = ds1.map(operations=decode_op, input_columns=["image"]) @@ -227,9 +260,13 @@ def test_cache_nomap_basic6(): """ logger.info("Test cache nomap basic 6") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # With only 3 records shard into 3, we expect only 1 record returned for this shard # However, the sharding will be done by the sampler, not by the tf record leaf node @@ -267,10 +304,14 @@ def test_cache_nomap_basic7(): """ logger.info("Test cache nomap basic 7") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) - ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache) decode_op = c_vision.Decode() ds1 = ds1.map(operations=decode_op, input_columns=["image"]) @@ -285,6 +326,34 @@ def test_cache_nomap_basic7(): logger.info("test_cache_nomap_basic7 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_basic8(): + """ + Test cache as root node + + cache + | + TFReader + """ + logger.info("Test cache basic 4") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + logger.info("get data from dataset") + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 3 + logger.info('test_cache_basic3 Ended.\n') + + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_allowed_share1(): """ @@ -298,10 +367,14 @@ def test_cache_nomap_allowed_share1(): """ logger.info("Test cache nomap allowed share 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") ds.config.set_seed(1) # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True, prefetch_size=32) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=32) ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, 
columns_list=["image"], shuffle=False, cache=some_cache) ds1 = ds1.repeat(4) @@ -336,10 +409,14 @@ def test_cache_nomap_allowed_share2(): """ logger.info("Test cache nomap allowed share 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") ds.config.set_seed(1) # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) decode_op = c_vision.Decode() ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) @@ -376,8 +453,12 @@ def test_cache_nomap_allowed_share3(): """ logger.info("Test cache nomap allowed share 3") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"] ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache) @@ -412,9 +493,13 @@ def test_cache_nomap_allowed_share4(): """ logger.info("Test cache nomap allowed share 4") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) decode_op = c_vision.Decode() ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) @@ -451,9 +536,13 @@ def test_cache_nomap_disallowed_share1(): """ logger.info("Test cache nomap disallowed share1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) decode_op = c_vision.Decode() rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0) @@ -469,15 +558,945 @@ def test_cache_nomap_disallowed_share1(): logger.info("Number of data in ds1: {} ".format(num_iter)) assert num_iter == 3 - try: + with pytest.raises(RuntimeError) as e: sum([1 for _ in ds2]) - except RuntimeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "Attempt to re-use a cache for a different tree!" in str(e) + assert "Attempt to re-use a cache for a different tree!" 
in str(e.value) logger.info("test_cache_nomap_disallowed_share1 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_running_twice1(): + """ + Executing the same pipeline for twice (from python), with cache injected after map + + Repeat + | + Cache + | + Map(decode) + | + TFRecord + """ + + logger.info("Test cache nomap running twice 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + + logger.info("test_cache_nomap_running_twice1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_running_twice2(): + """ + Executing the same pipeline for twice (from shell), with cache injected after leaf + + Repeat + | + Map(decode) + | + Cache + | + TFRecord + """ + + logger.info("Test cache nomap running twice 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_running_twice2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_extra_small_size1(): + """ + Test running pipeline with cache of extra small size and spilling true + + Repeat + | + Map(decode) + | + Cache + | + TFRecord + """ + + logger.info("Test cache nomap extra small size 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_extra_small_size1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_extra_small_size2(): + """ + Test running pipeline with cache of extra 
small size and spilling false (failure) + + Repeat + | + Cache + | + Map(decode) + | + TFRecord + """ + + logger.info("Test cache nomap extra small size 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + with pytest.raises(RuntimeError) as e: + sum([1 for _ in ds1]) + assert "Out of memory" in str(e.value) + logger.info("test_cache_nomap_extra_small_size2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_parallel_pipeline1(shard): + """ + Test running two parallel pipelines (sharing cache) with cache injected after leaf op + + Repeat + | + Map(decode) + | + cache + | + TFReader + """ + + logger.info("Test cache nomap parallel pipeline 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 4 + logger.info("test_cache_nomap_parallel_pipeline1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_parallel_pipeline2(shard): + """ + Test running two parallel pipelines (sharing cache) with cache injected after map op + + Repeat + | + cache + | + Map(decode) + | + TFReader + """ + + logger.info("Test cache nomap parallel pipeline 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard)) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 4 + logger.info("test_cache_nomap_parallel_pipeline2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_parallel_workers(): + """ + Test cache with num_parallel_workers > 1 set for map op and leaf op + + Repeat + | + Map(decode) + | + cache + | + TFReader + """ + + logger.info("Test cache nomap parallel workers") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + some_cache = 
ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_parallel_workers Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_server_workers_1(): + """ + start cache server with --workers 1 and then test cache function + + Repeat + | + cache + | + Map(decode) + | + TFRecord + """ + + logger.info("Test cache nomap server workers 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_server_workers_1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_server_workers_100(): + """ + start cache server with --workers 100 and then test cache function + + Repeat + | + Map(decode) + | + cache + | + TFRecord + """ + + logger.info("Test cache nomap server workers 100") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_server_workers_100 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_num_connections_1(): + """ + Test setting num_connections=1 in DatasetCache + + Repeat + | + cache + | + Map(decode) + | + TFRecord + """ + + logger.info("Test cache nomap num_connections 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in 
ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_num_connections_1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_num_connections_100(): + """ + Test setting num_connections=100 in DatasetCache + + Repeat + | + Map(decode) + | + cache + | + TFRecord + """ + + logger.info("Test cache nomap num_connections 100") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_num_connections_100 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_prefetch_size_1(): + """ + Test setting prefetch_size=1 in DatasetCache + + Repeat + | + cache + | + Map(decode) + | + TFRecord + """ + + logger.info("Test cache nomap prefetch_size 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_prefetch_size_1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_prefetch_size_100(): + """ + Test setting prefetch_size=100 in DatasetCache + + Repeat + | + Map(decode) + | + cache + | + TFRecord + """ + + logger.info("Test cache nomap prefetch_size 100") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_prefetch_size_100 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_to_device(): + """ + Test cache with to_device + + DeviceQueue + | + EpochCtrl + | + Repeat + | + Map(decode) + | + cache + | + TFReader + """ + + logger.info("Test cache nomap to_device") + if "SESSION_ID" in 
os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + ds1 = ds1.to_device() + ds1.send() + + logger.info("test_cache_nomap_to_device Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_session_destroy(): + """ + Test executing cache_admin -d while the pipeline is running + + Repeat + | + Cache + | + RandomDataset + """ + + logger.info("Test cache nomap session destroy") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, + shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) + schema.add_column('label', de_type=mstype.uint8, shape=[1]) + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # User-created sampler here + ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache) + ds1 = ds1.repeat() + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + assert "Unexpected error" in str(e.value) + + logger.info("test_cache_nomap_session_destroy Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_server_stop(): + """ + Test executing cache_admin --stop while the pipeline is running + + Repeat + | + Cache + | + RandomDataset + """ + + logger.info("Test cache nomap server stop") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, + shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) + schema.add_column('label', de_type=mstype.uint8, shape=[1]) + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # User-created sampler here + ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache) + ds1 = ds1.repeat() + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + assert "Network error. Cache server is unreachable. Make sure the server is running." 
in str(e.value) + + logger.info("test_cache_nomap_server_stop Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_epoch_ctrl1(): + """ + Test using two-loops method to run several epochs + + Map(decode) + | + cache + | + TFRecord + """ + + logger.info("Test cache nomap epoch ctrl1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + + num_epoch = 5 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _ in iter1: + row_count += 1 + logger.info("Number of data in ds1: {} ".format(row_count)) + assert row_count == 3 + epoch_count += 1 + assert epoch_count == num_epoch + logger.info("test_cache_nomap_epoch_ctrl1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_epoch_ctrl2(): + """ + Test using two-loops method with infinite epochs + + cache + | + Map(decode) + | + TFRecord + """ + + logger.info("Test cache nomap epoch ctrl2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + + num_epoch = 5 + # iter1 will always assume there is a next epoch and never shutdown + iter1 = ds1.create_dict_iterator() + + epoch_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _ in iter1: + row_count += 1 + logger.info("Number of data in ds1: {} ".format(row_count)) + assert row_count == 3 + epoch_count += 1 + assert epoch_count == num_epoch + + # manually stop the iterator + iter1.stop() + logger.info("test_cache_nomap_epoch_ctrl2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_epoch_ctrl3(): + """ + Test using two-loops method with infinite epochs over repeat + + repeat + | + Map(decode) + | + cache + | + TFRecord + """ + + logger.info("Test cache nomap epoch ctrl3") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(2) + + num_epoch = 5 + # iter1 will always assume there is a next epoch and never shutdown + iter1 = ds1.create_dict_iterator() + + epoch_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _ in iter1: + row_count += 1 + logger.info("Number of data in ds1: {} ".format(row_count)) + assert row_count == 6 + epoch_count += 1 + assert epoch_count == 
num_epoch + + # reply on garbage collector to destroy iter1 + + logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_multiple_cache1(): + """ + Test multiple cache in the same python script + + cache cache + | | + Map(decode) Map(decode) + | | + TFRecord(train) TFRecord(eval) + """ + + logger.info("Test cache nomap multiple cache 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 12 records in it + train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3) + decode_op = c_vision.Decode() + train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) + + # This dataset has 3 records in it only + eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache) + + num_epoch = 5 + train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch) + eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in train_iter]) == 12 + assert sum([1 for _ in eval_iter]) == 3 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_multiple_cache1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_multiple_cache2(): + """ + Test multiple cache in the same python script + + cache + | + Map(decode) cache + | | + TFRecord(image) TFRecord(text) + """ + + logger.info("Test cache nomap multiple cache 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + text_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) + + # This dataset has 3 records in it only + text_dataset = ds.TFRecordDataset(DATA_DIR2, SCHEMA_DIR2, cache=text_cache) + + num_epoch = 5 + image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch) + text_iter = text_dataset.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) + + epoch_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _, _ in itertools.zip_longest(image_iter, text_iter): + row_count += 1 + assert row_count == 3 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_multiple_cache2 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_multiple_cache3(): + """ + Test multiple cache in the same python script + + cache cache + | | + Map(decode) Map(decode) + | | + TFRecord ImageFolder + """ + + logger.info("Test cache nomap multiple cache 3") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise 
RuntimeError("Testcase requires SESSION_ID environment variable") + + tf_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache) + + # This DATA_DIR only has 2 images in it + image_dataset = ds.ImageFolderDataset(dataset_dir=DATA_DIR4) + image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) + + num_epoch = 5 + tf_iter = tf_dataset.create_dict_iterator(num_epochs=num_epoch) + image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in tf_iter]) == 3 + assert sum([1 for _ in image_iter]) == 2 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_multiple_cache3 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_multiple_cache_train(): + """ + Test multiple cache in different python scripts. This test case is going to run concurrently with + test_cache_nomap_multiple_cache_eval. + + cache + | + Map(decode) + | + TFRecord(train) + """ + + logger.info("Test cache nomap multiple cache train") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 12 records in it + train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3) + decode_op = c_vision.Decode() + train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) + + num_epoch = 5 + train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in train_iter]) == 12 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_multiple_cache_train Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_multiple_cache_eval(): + """ + Test multiple cache in different python scripts. This test case is going to run concurrently with + test_cache_nomap_multiple_cache_train. + + cache + | + Map(decode) + | + TFRecord(eval) + """ + + logger.info("Test cache nomap multiple cache eval") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset only has 3 records in it + eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache) + + num_epoch = 5 + eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in eval_iter]) == 3 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_multiple_cache_eval Ended.\n") + + if __name__ == '__main__': test_cache_nomap_basic1() test_cache_nomap_basic2()