diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc index 018611332af..0506c9e47d0 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc @@ -23,11 +23,13 @@ namespace dataset { PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { (void)py::class_>(*m, "CacheClient") .def( - py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) { + py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional hostname, + std::optional port, int32_t prefetch_sz) { std::shared_ptr cc; CacheClient::Builder builder; - builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize( - prefetch_sz); + builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz); + if (hostname) builder.SetHostname(hostname.value()); + if (port) builder.SetPort(port.value()); THROW_IF_ERROR(builder.Build(&cc)); return cc; })) diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc index f505a6b187e..24f2261ad1e 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc @@ -19,10 +19,37 @@ #include #include +#include "mindspore/core/utils/log_adapter.h" #include "minddata/dataset/util/system_pool.h" namespace mindspore { namespace dataset { +ConfigManager::ConfigManager() + : rows_per_buffer_(kCfgRowsPerBuffer), + num_parallel_workers_(kCfgParallelWorkers), + worker_connector_size_(kCfgWorkerConnectorSize), + op_connector_size_(kCfgOpConnectorSize), + seed_(kCfgDefaultSeed), + monitor_sampling_interval_(kCfgMonitorSamplingInterval), + callback_timout_(kCfgCallbackTimeout), + cache_host_(kCfgDefaultCacheHost), + cache_port_(kCfgDefaultCachePort) { + auto env_cache_host = std::getenv("MS_CACHE_HOST"); + auto env_cache_port = std::getenv("MS_CACHE_PORT"); + if (env_cache_host) { + cache_host_ = env_cache_host; + } + if (env_cache_port) { + char *end = nullptr; + cache_port_ = strtol(env_cache_port, &end, 10); + if (*end != '\0') { + MS_LOG(WARNING) << "\nCache port from env variable MS_CACHE_PORT is invalid, back to use default " + << kCfgDefaultCachePort << std::endl; + cache_port_ = kCfgDefaultCachePort; + } + } +} + // A print method typically used for debugging void ConfigManager::Print(std::ostream &out) const { // Don't show the test/internal ones. Only display the main ones here. @@ -42,6 +69,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) { set_op_connector_size(j.value("opConnectorSize", op_connector_size_)); set_seed(j.value("seed", seed_)); set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); + set_cache_host(j.value("cacheHost", cache_host_)); + set_cache_port(j.value("cachePort", cache_port_)); return Status::OK(); } @@ -91,5 +120,8 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } +void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; } + +void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.h b/mindspore/ccsrc/minddata/dataset/core/config_manager.h index eb154b8f440..ff35826708b 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.h +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.h @@ -41,7 +41,7 @@ namespace dataset { // those values. class ConfigManager { public: - ConfigManager() = default; + ConfigManager(); // destructor ~ConfigManager() = default; @@ -89,6 +89,14 @@ class ConfigManager { // @return The internal worker-to-master connector queue size int32_t worker_connector_size() const { return worker_connector_size_; } + // getter function + // @return The hostname of cache server + std::string cache_host() const { return cache_host_; } + + // getter function + // @return The port of cache server + int32_t cache_port() const { return cache_port_; } + // setter function // @param rows_per_buffer - The setting to apply to the config void set_rows_per_buffer(int32_t rows_per_buffer); @@ -105,6 +113,14 @@ class ConfigManager { // @param connector_size - The setting to apply to the config void set_op_connector_size(int32_t connector_size); + // setter function + // @param cache_host - The hostname of cache server + void set_cache_host(std::string cache_host); + + // setter function + // @param cache_port - The port of cache server + void set_cache_port(int32_t cache_port); + uint32_t seed() const; // setter function @@ -128,13 +144,15 @@ class ConfigManager { int32_t callback_timeout() const { return callback_timout_; } private: - int32_t rows_per_buffer_{kCfgRowsPerBuffer}; - int32_t num_parallel_workers_{kCfgParallelWorkers}; - int32_t worker_connector_size_{kCfgWorkerConnectorSize}; - int32_t op_connector_size_{kCfgOpConnectorSize}; - uint32_t seed_{kCfgDefaultSeed}; - uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval}; - uint32_t callback_timout_{kCfgCallbackTimeout}; + int32_t rows_per_buffer_; + int32_t num_parallel_workers_; + int32_t worker_connector_size_; + int32_t op_connector_size_; + uint32_t seed_; + uint32_t monitor_sampling_interval_; + uint32_t callback_timout_; + std::string cache_host_; + int32_t cache_port_; // Private helper function that takes a nlohmann json format and populates the settings // @param j - The json nlohmann json info diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h index 8b7911e5f6e..dc966c3844b 100644 --- a/mindspore/ccsrc/minddata/dataset/core/constants.h +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -69,6 +69,8 @@ constexpr uint32_t kCfgOpConnectorSize = 16; constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; constexpr uint32_t kCfgMonitorSamplingInterval = 10; constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds +constexpr int32_t kCfgDefaultCachePort = 50052; +constexpr char kCfgDefaultCacheHost[] = "127.0.0.1"; // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) constexpr uint8_t kCVInvalidType = 255; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc index 892f7842ef3..4075216a7e2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc @@ -25,21 +25,21 @@ #include "minddata/dataset/engine/cache/cache_request.h" #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/util/path.h" +#include "minddata/dataset/core/constants.h" namespace mindspore { namespace dataset { -const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1"; const char CacheAdminArgHandler::kServerBinary[] = "cache_server"; const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp"; CacheAdminArgHandler::CacheAdminArgHandler() - : port_(kDefaultPort), + : port_(kCfgDefaultCachePort), session_id_(0), num_workers_(kDefaultNumWorkers), shm_mem_sz_(kDefaultSharedMemorySizeInGB), log_level_(kDefaultLogLevel), - hostname_(kDefaultHost), + hostname_(kCfgDefaultCacheHost), spill_dir_(kDefaultSpillDir), command_id_(CommandId::kCmdUnknown) { // Initialize the command mappings @@ -376,6 +376,8 @@ Status CacheAdminArgHandler::StopServer() { RETURN_IF_NOT_OK(comm.ServiceStart()); auto rq = std::make_shared(); RETURN_IF_NOT_OK(comm.HandleRequest(rq)); + // We will ignore the rc because if the shutdown is successful, the server will not reply back. + (void)rq->Wait(); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h index e48916c482a..587ea15eb08 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h @@ -29,11 +29,9 @@ namespace dataset { class CacheAdminArgHandler { public: - static constexpr int32_t kDefaultPort = 50052; static constexpr int32_t kDefaultNumWorkers = 32; static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; static constexpr int32_t kDefaultLogLevel = 1; - static const char kDefaultHost[]; static const char kServerBinary[]; static const char kDefaultSpillDir[]; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc index 428e0f4eebd..b237a0f294f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc @@ -23,6 +23,35 @@ namespace mindspore { namespace dataset { +CacheClient::Builder::Builder() + : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) { + std::shared_ptr cfg = GlobalContext::config_manager(); + hostname_ = cfg->cache_host(); + port_ = cfg->cache_port(); + num_workers_ = cfg->num_parallel_workers(); + prefetch_size_ = 20; // rows_per_buf is too small (1 by default). +} + +Status CacheClient::Builder::Build(std::shared_ptr *out) { + RETURN_UNEXPECTED_IF_NULL(out); + RETURN_IF_NOT_OK(SanityCheck()); + *out = + std::make_shared(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_); + return Status::OK(); +} + +Status CacheClient::Builder::SanityCheck() { + CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); + CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); + CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number"); + CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1", + "now cache client has to be on the same host with cache server"); + return Status::OK(); +} // Constructor CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h index 9461dcbd533..7fc7a47816d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h @@ -44,13 +44,7 @@ class CacheClient { /// \brief A builder to help creating a CacheClient object class Builder { public: - Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) { - std::shared_ptr cfg = GlobalContext::config_manager(); - hostname_ = "127.0.0.1"; - port_ = 50052; - num_workers_ = cfg->num_parallel_workers(); - prefetch_size_ = 20; // rows_per_buf is too small (1 by default). - } + Builder(); ~Builder() = default; @@ -119,22 +113,9 @@ class CacheClient { int32_t getNumWorkers() const { return num_workers_; } int32_t getPrefetchSize() const { return prefetch_size_; } - Status SanityCheck() { - CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); - CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); - CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); - CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); - CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); - return Status::OK(); - } + Status SanityCheck(); - Status Build(std::shared_ptr *out) { - RETURN_UNEXPECTED_IF_NULL(out); - RETURN_IF_NOT_OK(SanityCheck()); - *out = std::make_shared(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, - prefetch_size_); - return Status::OK(); - } + Status Build(std::shared_ptr *out); private: session_id_type session_id_; diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py index 32a9829349a..6139f9e083c 100644 --- a/mindspore/dataset/engine/cache_client.py +++ b/mindspore/dataset/engine/cache_client.py @@ -20,24 +20,25 @@ from mindspore._c_dataengine import CacheClient from ..core.validator_helpers import type_check, check_uint32, check_uint64 + class DatasetCache: """ A client to interface with tensor caching service """ - def __init__(self, session_id=None, size=0, spilling=False, port=50052, prefetch_size=20): + def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20): check_uint32(session_id, "session_id") check_uint64(size, "size") type_check(spilling, (bool,), "spilling") - check_uint32(port, "port") check_uint32(prefetch_size, "prefetch size") self.session_id = session_id self.size = size self.spilling = spilling + self.hostname = hostname self.port = port self.prefetch_size = prefetch_size - self.cache_client = CacheClient(session_id, size, spilling, port, prefetch_size) + self.cache_client = CacheClient(session_id, size, spilling, hostname, port, prefetch_size) def GetStat(self): return self.cache_client.GetStat() @@ -51,6 +52,7 @@ class DatasetCache: new_cache.session_id = copy.deepcopy(self.session_id, memodict) new_cache.spilling = copy.deepcopy(self.spilling, memodict) new_cache.size = copy.deepcopy(self.size, memodict) + new_cache.hostname = copy.deepcopy(self.hostname, memodict) new_cache.port = copy.deepcopy(self.port, memodict) new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) new_cache.cache_client = self.cache_client