!5085 Enable hostname parameter for DatasetCache

Merge pull request !5085 from lixiachen/cache_hostname
This commit is contained in:
mindspore-ci-bot 2020-09-02 04:56:06 +08:00 committed by Gitee
commit ecfe728963
9 changed files with 107 additions and 41 deletions

View File

@ -23,11 +23,13 @@ namespace dataset {
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def( .def(
py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) { py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
std::optional<int32_t> port, int32_t prefetch_sz) {
std::shared_ptr<CacheClient> cc; std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder; CacheClient::Builder builder;
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize( builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz);
prefetch_sz); if (hostname) builder.SetHostname(hostname.value());
if (port) builder.SetPort(port.value());
THROW_IF_ERROR(builder.Build(&cc)); THROW_IF_ERROR(builder.Build(&cc));
return cc; return cc;
})) }))

View File

@ -19,10 +19,37 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "mindspore/core/utils/log_adapter.h"
#include "minddata/dataset/util/system_pool.h" #include "minddata/dataset/util/system_pool.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
ConfigManager::ConfigManager()
: rows_per_buffer_(kCfgRowsPerBuffer),
num_parallel_workers_(kCfgParallelWorkers),
worker_connector_size_(kCfgWorkerConnectorSize),
op_connector_size_(kCfgOpConnectorSize),
seed_(kCfgDefaultSeed),
monitor_sampling_interval_(kCfgMonitorSamplingInterval),
callback_timout_(kCfgCallbackTimeout),
cache_host_(kCfgDefaultCacheHost),
cache_port_(kCfgDefaultCachePort) {
auto env_cache_host = std::getenv("MS_CACHE_HOST");
auto env_cache_port = std::getenv("MS_CACHE_PORT");
if (env_cache_host) {
cache_host_ = env_cache_host;
}
if (env_cache_port) {
char *end = nullptr;
cache_port_ = strtol(env_cache_port, &end, 10);
if (*end != '\0') {
MS_LOG(WARNING) << "\nCache port from env variable MS_CACHE_PORT is invalid, back to use default "
<< kCfgDefaultCachePort << std::endl;
cache_port_ = kCfgDefaultCachePort;
}
}
}
// A print method typically used for debugging // A print method typically used for debugging
void ConfigManager::Print(std::ostream &out) const { void ConfigManager::Print(std::ostream &out) const {
// Don't show the test/internal ones. Only display the main ones here. // Don't show the test/internal ones. Only display the main ones here.
@ -42,6 +69,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
set_op_connector_size(j.value("opConnectorSize", op_connector_size_)); set_op_connector_size(j.value("opConnectorSize", op_connector_size_));
set_seed(j.value("seed", seed_)); set_seed(j.value("seed", seed_));
set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
set_cache_host(j.value("cacheHost", cache_host_));
set_cache_port(j.value("cachePort", cache_port_));
return Status::OK(); return Status::OK();
} }
@ -91,5 +120,8 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s
void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; }
void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; }
void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -41,7 +41,7 @@ namespace dataset {
// those values. // those values.
class ConfigManager { class ConfigManager {
public: public:
ConfigManager() = default; ConfigManager();
// destructor // destructor
~ConfigManager() = default; ~ConfigManager() = default;
@ -89,6 +89,14 @@ class ConfigManager {
// @return The internal worker-to-master connector queue size // @return The internal worker-to-master connector queue size
int32_t worker_connector_size() const { return worker_connector_size_; } int32_t worker_connector_size() const { return worker_connector_size_; }
// getter function
// @return The hostname of cache server
std::string cache_host() const { return cache_host_; }
// getter function
// @return The port of cache server
int32_t cache_port() const { return cache_port_; }
// setter function // setter function
// @param rows_per_buffer - The setting to apply to the config // @param rows_per_buffer - The setting to apply to the config
void set_rows_per_buffer(int32_t rows_per_buffer); void set_rows_per_buffer(int32_t rows_per_buffer);
@ -105,6 +113,14 @@ class ConfigManager {
// @param connector_size - The setting to apply to the config // @param connector_size - The setting to apply to the config
void set_op_connector_size(int32_t connector_size); void set_op_connector_size(int32_t connector_size);
// setter function
// @param cache_host - The hostname of cache server
void set_cache_host(std::string cache_host);
// setter function
// @param cache_port - The port of cache server
void set_cache_port(int32_t cache_port);
uint32_t seed() const; uint32_t seed() const;
// setter function // setter function
@ -128,13 +144,15 @@ class ConfigManager {
int32_t callback_timeout() const { return callback_timout_; } int32_t callback_timeout() const { return callback_timout_; }
private: private:
int32_t rows_per_buffer_{kCfgRowsPerBuffer}; int32_t rows_per_buffer_;
int32_t num_parallel_workers_{kCfgParallelWorkers}; int32_t num_parallel_workers_;
int32_t worker_connector_size_{kCfgWorkerConnectorSize}; int32_t worker_connector_size_;
int32_t op_connector_size_{kCfgOpConnectorSize}; int32_t op_connector_size_;
uint32_t seed_{kCfgDefaultSeed}; uint32_t seed_;
uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval}; uint32_t monitor_sampling_interval_;
uint32_t callback_timout_{kCfgCallbackTimeout}; uint32_t callback_timout_;
std::string cache_host_;
int32_t cache_port_;
// Private helper function that takes a nlohmann json format and populates the settings // Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info // @param j - The json nlohmann json info

View File

@ -69,6 +69,8 @@ constexpr uint32_t kCfgOpConnectorSize = 16;
constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
constexpr uint32_t kCfgMonitorSamplingInterval = 10; constexpr uint32_t kCfgMonitorSamplingInterval = 10;
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
constexpr int32_t kCfgDefaultCachePort = 50052;
constexpr char kCfgDefaultCacheHost[] = "127.0.0.1";
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255; constexpr uint8_t kCVInvalidType = 255;

View File

@ -25,21 +25,21 @@
#include "minddata/dataset/engine/cache/cache_request.h" #include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/util/path.h" #include "minddata/dataset/util/path.h"
#include "minddata/dataset/core/constants.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1";
const char CacheAdminArgHandler::kServerBinary[] = "cache_server"; const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp"; const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";
CacheAdminArgHandler::CacheAdminArgHandler() CacheAdminArgHandler::CacheAdminArgHandler()
: port_(kDefaultPort), : port_(kCfgDefaultCachePort),
session_id_(0), session_id_(0),
num_workers_(kDefaultNumWorkers), num_workers_(kDefaultNumWorkers),
shm_mem_sz_(kDefaultSharedMemorySizeInGB), shm_mem_sz_(kDefaultSharedMemorySizeInGB),
log_level_(kDefaultLogLevel), log_level_(kDefaultLogLevel),
hostname_(kDefaultHost), hostname_(kCfgDefaultCacheHost),
spill_dir_(kDefaultSpillDir), spill_dir_(kDefaultSpillDir),
command_id_(CommandId::kCmdUnknown) { command_id_(CommandId::kCmdUnknown) {
// Initialize the command mappings // Initialize the command mappings
@ -376,6 +376,8 @@ Status CacheAdminArgHandler::StopServer() {
RETURN_IF_NOT_OK(comm.ServiceStart()); RETURN_IF_NOT_OK(comm.ServiceStart());
auto rq = std::make_shared<ShutdownRequest>(); auto rq = std::make_shared<ShutdownRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq)); RETURN_IF_NOT_OK(comm.HandleRequest(rq));
// We will ignore the rc because if the shutdown is successful, the server will not reply back.
(void)rq->Wait();
return Status::OK(); return Status::OK();
} }

View File

@ -29,11 +29,9 @@ namespace dataset {
class CacheAdminArgHandler { class CacheAdminArgHandler {
public: public:
static constexpr int32_t kDefaultPort = 50052;
static constexpr int32_t kDefaultNumWorkers = 32; static constexpr int32_t kDefaultNumWorkers = 32;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1; static constexpr int32_t kDefaultLogLevel = 1;
static const char kDefaultHost[];
static const char kServerBinary[]; static const char kServerBinary[];
static const char kDefaultSpillDir[]; static const char kDefaultSpillDir[];

View File

@ -23,6 +23,35 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
CacheClient::Builder::Builder()
: session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
hostname_ = cfg->cache_host();
port_ = cfg->cache_port();
num_workers_ = cfg->num_parallel_workers();
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
}
Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) {
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_IF_NOT_OK(SanityCheck());
*out =
std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_);
return Status::OK();
}
Status CacheClient::Builder::SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited");
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number");
CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1",
"now cache client has to be on the same host with cache server");
return Status::OK();
}
// Constructor // Constructor
CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,

View File

@ -44,13 +44,7 @@ class CacheClient {
/// \brief A builder to help creating a CacheClient object /// \brief A builder to help creating a CacheClient object
class Builder { class Builder {
public: public:
Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) { Builder();
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
hostname_ = "127.0.0.1";
port_ = 50052;
num_workers_ = cfg->num_parallel_workers();
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
}
~Builder() = default; ~Builder() = default;
@ -119,22 +113,9 @@ class CacheClient {
int32_t getNumWorkers() const { return num_workers_; } int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; } int32_t getPrefetchSize() const { return prefetch_size_; }
Status SanityCheck() { Status SanityCheck();
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited");
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
return Status::OK();
}
Status Build(std::shared_ptr<CacheClient> *out) { Status Build(std::shared_ptr<CacheClient> *out);
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_IF_NOT_OK(SanityCheck());
*out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_,
prefetch_size_);
return Status::OK();
}
private: private:
session_id_type session_id_; session_id_type session_id_;

View File

@ -20,24 +20,25 @@ from mindspore._c_dataengine import CacheClient
from ..core.validator_helpers import type_check, check_uint32, check_uint64 from ..core.validator_helpers import type_check, check_uint32, check_uint64
class DatasetCache: class DatasetCache:
""" """
A client to interface with tensor caching service A client to interface with tensor caching service
""" """
def __init__(self, session_id=None, size=0, spilling=False, port=50052, prefetch_size=20): def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20):
check_uint32(session_id, "session_id") check_uint32(session_id, "session_id")
check_uint64(size, "size") check_uint64(size, "size")
type_check(spilling, (bool,), "spilling") type_check(spilling, (bool,), "spilling")
check_uint32(port, "port")
check_uint32(prefetch_size, "prefetch size") check_uint32(prefetch_size, "prefetch size")
self.session_id = session_id self.session_id = session_id
self.size = size self.size = size
self.spilling = spilling self.spilling = spilling
self.hostname = hostname
self.port = port self.port = port
self.prefetch_size = prefetch_size self.prefetch_size = prefetch_size
self.cache_client = CacheClient(session_id, size, spilling, port, prefetch_size) self.cache_client = CacheClient(session_id, size, spilling, hostname, port, prefetch_size)
def GetStat(self): def GetStat(self):
return self.cache_client.GetStat() return self.cache_client.GetStat()
@ -51,6 +52,7 @@ class DatasetCache:
new_cache.session_id = copy.deepcopy(self.session_id, memodict) new_cache.session_id = copy.deepcopy(self.session_id, memodict)
new_cache.spilling = copy.deepcopy(self.spilling, memodict) new_cache.spilling = copy.deepcopy(self.spilling, memodict)
new_cache.size = copy.deepcopy(self.size, memodict) new_cache.size = copy.deepcopy(self.size, memodict)
new_cache.hostname = copy.deepcopy(self.hostname, memodict)
new_cache.port = copy.deepcopy(self.port, memodict) new_cache.port = copy.deepcopy(self.port, memodict)
new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
new_cache.cache_client = self.cache_client new_cache.cache_client = self.cache_client