forked from mindspore-Ecosystem/mindspore
!5085 Enable hostname parameter for DatasetCache
Merge pull request !5085 from lixiachen/cache_hostname
This commit is contained in:
commit
ecfe728963
|
@ -23,11 +23,13 @@ namespace dataset {
|
|||
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
|
||||
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
|
||||
.def(
|
||||
py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) {
|
||||
py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
|
||||
std::optional<int32_t> port, int32_t prefetch_sz) {
|
||||
std::shared_ptr<CacheClient> cc;
|
||||
CacheClient::Builder builder;
|
||||
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize(
|
||||
prefetch_sz);
|
||||
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz);
|
||||
if (hostname) builder.SetHostname(hostname.value());
|
||||
if (port) builder.SetPort(port.value());
|
||||
THROW_IF_ERROR(builder.Build(&cc));
|
||||
return cc;
|
||||
}))
|
||||
|
|
|
@ -19,10 +19,37 @@
|
|||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "minddata/dataset/util/system_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor.
// Every setting starts from its compile-time default (kCfg* constants). The cache server
// address may then be overridden through the MS_CACHE_HOST and MS_CACHE_PORT environment variables.
ConfigManager::ConfigManager()
    : rows_per_buffer_(kCfgRowsPerBuffer),
      num_parallel_workers_(kCfgParallelWorkers),
      worker_connector_size_(kCfgWorkerConnectorSize),
      op_connector_size_(kCfgOpConnectorSize),
      seed_(kCfgDefaultSeed),
      monitor_sampling_interval_(kCfgMonitorSamplingInterval),
      callback_timout_(kCfgCallbackTimeout),
      cache_host_(kCfgDefaultCacheHost),
      cache_port_(kCfgDefaultCachePort) {
  const char *env_cache_host = std::getenv("MS_CACHE_HOST");
  const char *env_cache_port = std::getenv("MS_CACHE_PORT");
  if (env_cache_host != nullptr) {
    cache_host_ = env_cache_host;
  }
  if (env_cache_port != nullptr) {
    char *end = nullptr;
    // Keep the full-width strtol result so an out-of-range value is detected rather than
    // silently truncated when narrowed to int32_t.
    const auto parsed_port = strtol(env_cache_port, &end, 10);
    // The value is accepted only if at least one digit was consumed (rejects an empty string),
    // nothing trails the number, and it lies in the port range that
    // CacheClient::Builder::SanityCheck enforces (1..65535).
    const bool is_valid =
      (end != env_cache_port) && (*end == '\0') && (parsed_port > 0) && (parsed_port <= 65535);
    if (is_valid) {
      cache_port_ = static_cast<int32_t>(parsed_port);
    } else {
      MS_LOG(WARNING) << "\nCache port from env variable MS_CACHE_PORT is invalid, back to use default "
                      << kCfgDefaultCachePort << std::endl;
      cache_port_ = kCfgDefaultCachePort;
    }
  }
}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void ConfigManager::Print(std::ostream &out) const {
|
||||
// Don't show the test/internal ones. Only display the main ones here.
|
||||
|
@ -42,6 +69,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
|
|||
set_op_connector_size(j.value("opConnectorSize", op_connector_size_));
|
||||
set_seed(j.value("seed", seed_));
|
||||
set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
|
||||
set_cache_host(j.value("cacheHost", cache_host_));
|
||||
set_cache_port(j.value("cachePort", cache_port_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -91,5 +120,8 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s
|
|||
|
||||
// Setter: stores the given callback timeout in the config.
void ConfigManager::set_callback_timeout(uint32_t timeout) {
  callback_timout_ = timeout;
}
|
||||
|
||||
// Setter: replaces the stored cache server hostname with the given one.
void ConfigManager::set_cache_host(std::string cache_host) {
  cache_host_ = cache_host;
}
|
||||
|
||||
// Setter: replaces the stored cache server port with the given one (no range validation here).
void ConfigManager::set_cache_port(int32_t cache_port) {
  cache_port_ = cache_port;
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,7 +41,7 @@ namespace dataset {
|
|||
// those values.
|
||||
class ConfigManager {
|
||||
public:
|
||||
ConfigManager() = default;
|
||||
ConfigManager();
|
||||
|
||||
// destructor
|
||||
~ConfigManager() = default;
|
||||
|
@ -89,6 +89,14 @@ class ConfigManager {
|
|||
// @return The internal worker-to-master connector queue size
|
||||
int32_t worker_connector_size() const { return worker_connector_size_; }
|
||||
|
||||
// getter function
|
||||
// @return The hostname of cache server
|
||||
// Returns the stored hostname by value, i.e. the caller receives its own copy of the string.
std::string cache_host() const { return cache_host_; }
|
||||
|
||||
// getter function
|
||||
// @return The port of cache server
|
||||
int32_t cache_port() const { return cache_port_; }
|
||||
|
||||
// setter function
|
||||
// @param rows_per_buffer - The setting to apply to the config
|
||||
void set_rows_per_buffer(int32_t rows_per_buffer);
|
||||
|
@ -105,6 +113,14 @@ class ConfigManager {
|
|||
// @param connector_size - The setting to apply to the config
|
||||
void set_op_connector_size(int32_t connector_size);
|
||||
|
||||
// setter function
|
||||
// @param cache_host - The hostname of cache server
|
||||
void set_cache_host(std::string cache_host);
|
||||
|
||||
// setter function
|
||||
// @param cache_port - The port of cache server
|
||||
void set_cache_port(int32_t cache_port);
|
||||
|
||||
uint32_t seed() const;
|
||||
|
||||
// setter function
|
||||
|
@ -128,13 +144,15 @@ class ConfigManager {
|
|||
// Returns the stored callback timeout.
// NOTE(review): callback_timout_ is declared uint32_t and set_callback_timeout() takes a uint32_t,
// while this getter returns int32_t - confirm the signed return type is intended.
int32_t callback_timeout() const { return callback_timout_; }
|
||||
|
||||
private:
|
||||
int32_t rows_per_buffer_{kCfgRowsPerBuffer};
|
||||
int32_t num_parallel_workers_{kCfgParallelWorkers};
|
||||
int32_t worker_connector_size_{kCfgWorkerConnectorSize};
|
||||
int32_t op_connector_size_{kCfgOpConnectorSize};
|
||||
uint32_t seed_{kCfgDefaultSeed};
|
||||
uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval};
|
||||
uint32_t callback_timout_{kCfgCallbackTimeout};
|
||||
int32_t rows_per_buffer_;
|
||||
int32_t num_parallel_workers_;
|
||||
int32_t worker_connector_size_;
|
||||
int32_t op_connector_size_;
|
||||
uint32_t seed_;
|
||||
uint32_t monitor_sampling_interval_;
|
||||
uint32_t callback_timout_;
|
||||
std::string cache_host_;
|
||||
int32_t cache_port_;
|
||||
|
||||
// Private helper function that takes a nlohmann json format and populates the settings
|
||||
// @param j - The json nlohmann json info
|
||||
|
|
|
@ -69,6 +69,8 @@ constexpr uint32_t kCfgOpConnectorSize = 16;
|
|||
constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
|
||||
constexpr uint32_t kCfgMonitorSamplingInterval = 10;
|
||||
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
|
||||
constexpr int32_t kCfgDefaultCachePort = 50052;
|
||||
constexpr char kCfgDefaultCacheHost[] = "127.0.0.1";
|
||||
|
||||
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
|
||||
constexpr uint8_t kCVInvalidType = 255;
|
||||
|
|
|
@ -25,21 +25,21 @@
|
|||
#include "minddata/dataset/engine/cache/cache_request.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1";
|
||||
const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
|
||||
const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";
|
||||
|
||||
CacheAdminArgHandler::CacheAdminArgHandler()
|
||||
: port_(kDefaultPort),
|
||||
: port_(kCfgDefaultCachePort),
|
||||
session_id_(0),
|
||||
num_workers_(kDefaultNumWorkers),
|
||||
shm_mem_sz_(kDefaultSharedMemorySizeInGB),
|
||||
log_level_(kDefaultLogLevel),
|
||||
hostname_(kDefaultHost),
|
||||
hostname_(kCfgDefaultCacheHost),
|
||||
spill_dir_(kDefaultSpillDir),
|
||||
command_id_(CommandId::kCmdUnknown) {
|
||||
// Initialize the command mappings
|
||||
|
@ -376,6 +376,8 @@ Status CacheAdminArgHandler::StopServer() {
|
|||
RETURN_IF_NOT_OK(comm.ServiceStart());
|
||||
auto rq = std::make_shared<ShutdownRequest>();
|
||||
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
|
||||
// We will ignore the rc because if the shutdown is successful, the server will not reply back.
|
||||
(void)rq->Wait();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -29,11 +29,9 @@ namespace dataset {
|
|||
|
||||
class CacheAdminArgHandler {
|
||||
public:
|
||||
static constexpr int32_t kDefaultPort = 50052;
|
||||
static constexpr int32_t kDefaultNumWorkers = 32;
|
||||
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
|
||||
static constexpr int32_t kDefaultLogLevel = 1;
|
||||
static const char kDefaultHost[];
|
||||
static const char kServerBinary[];
|
||||
static const char kDefaultSpillDir[];
|
||||
|
||||
|
|
|
@ -23,6 +23,35 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor.
// All fields are first zero/empty-initialized; the connection and worker settings are then
// taken from the global ConfigManager so a default-built client follows the current configuration.
CacheClient::Builder::Builder()
    : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) {
  const auto config = GlobalContext::config_manager();
  // Cache server address as currently configured.
  hostname_ = config->cache_host();
  port_ = config->cache_port();
  // One rpc worker per configured parallel worker.
  num_workers_ = config->num_parallel_workers();
  // rows_per_buf is too small (1 by default), so a fixed prefetch size is used instead.
  prefetch_size_ = 20;
}
|
||||
|
||||
// Creates a CacheClient from the settings collected by this builder.
// @param out - Receives the newly created client; checked for null via RETURN_UNEXPECTED_IF_NULL
// @return Status::OK() on success; otherwise the status propagated from the null check or SanityCheck()
Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  // Validate the settings before any client object is created.
  RETURN_IF_NOT_OK(SanityCheck());
  auto client =
    std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_);
  *out = client;
  return Status::OK();
}
|
||||
|
||||
// Validates the builder settings before a CacheClient is created.
// Each CHECK_FAIL_RETURN_UNEXPECTED (macro defined elsewhere) is expected to return early with the
// given message when its condition does not hold.
// @return Status::OK() when every check passes
Status CacheClient::Builder::SanityCheck() {
  CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
  CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited)");
  CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
  CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
  CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
  // Port must be within 1..65535.
  CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
  CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number");
  // Only the literal loopback address is accepted for now; any other hostname is rejected.
  CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1",
                               "now cache client has to be on the same host with cache server");
  return Status::OK();
}
|
||||
|
||||
// Constructor
|
||||
CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
|
||||
|
|
|
@ -44,13 +44,7 @@ class CacheClient {
|
|||
/// \brief A builder to help creating a CacheClient object
|
||||
class Builder {
|
||||
public:
|
||||
Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
hostname_ = "127.0.0.1";
|
||||
port_ = 50052;
|
||||
num_workers_ = cfg->num_parallel_workers();
|
||||
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
|
||||
}
|
||||
Builder();
|
||||
|
||||
~Builder() = default;
|
||||
|
||||
|
@ -119,22 +113,9 @@ class CacheClient {
|
|||
int32_t getNumWorkers() const { return num_workers_; }
|
||||
int32_t getPrefetchSize() const { return prefetch_size_; }
|
||||
|
||||
Status SanityCheck() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
|
||||
return Status::OK();
|
||||
}
|
||||
Status SanityCheck();
|
||||
|
||||
Status Build(std::shared_ptr<CacheClient> *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_,
|
||||
prefetch_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
Status Build(std::shared_ptr<CacheClient> *out);
|
||||
|
||||
private:
|
||||
session_id_type session_id_;
|
||||
|
|
|
@ -20,24 +20,25 @@ from mindspore._c_dataengine import CacheClient
|
|||
|
||||
from ..core.validator_helpers import type_check, check_uint32, check_uint64
|
||||
|
||||
|
||||
class DatasetCache:
|
||||
"""
|
||||
A client to interface with tensor caching service
|
||||
"""
|
||||
|
||||
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20):
    """
    Validate the arguments and create the underlying CacheClient.

    Args:
        session_id (int): Cache session id; validated as uint32, so it must be supplied.
        size (int): Cache memory size; validated as uint64 (default 0).
        spilling (bool): Spill flag passed to the cache client (default False).
        hostname (str, optional): Hostname of the cache server. When None, it is passed through
            unset and the C++ builder keeps its own configured hostname.
        port (int, optional): Port of the cache server. When None, it is passed through unset
            and the C++ builder keeps its own configured port.
        prefetch_size (int): Prefetch size; validated as uint32 (default 20).
    """
    check_uint32(session_id, "session_id")
    check_uint64(size, "size")
    type_check(spilling, (bool,), "spilling")
    # hostname and port are optional: only validate them when the caller actually supplied a value,
    # otherwise the default None would be rejected by the validators.
    if hostname is not None:
        type_check(hostname, (str,), "hostname")
    if port is not None:
        check_uint32(port, "port")
    check_uint32(prefetch_size, "prefetch size")

    self.session_id = session_id
    self.size = size
    self.spilling = spilling
    self.hostname = hostname
    self.port = port
    self.prefetch_size = prefetch_size
    self.cache_client = CacheClient(session_id, size, spilling, hostname, port, prefetch_size)
|
||||
|
||||
def GetStat(self):
    """Forward to the wrapped cache client's GetStat() and return its result unchanged."""
    client = self.cache_client
    return client.GetStat()
|
||||
|
@ -51,6 +52,7 @@ class DatasetCache:
|
|||
new_cache.session_id = copy.deepcopy(self.session_id, memodict)
|
||||
new_cache.spilling = copy.deepcopy(self.spilling, memodict)
|
||||
new_cache.size = copy.deepcopy(self.size, memodict)
|
||||
new_cache.hostname = copy.deepcopy(self.hostname, memodict)
|
||||
new_cache.port = copy.deepcopy(self.port, memodict)
|
||||
new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
|
||||
new_cache.cache_client = self.cache_client
|
||||
|
|
Loading…
Reference in New Issue