forked from mindspore-Ecosystem/mindspore
Rebase up to 88ded11f59
This commit is contained in:
parent
88ded11f59
commit
983827ec5c
|
@ -36,6 +36,7 @@ include(CPack)
|
|||
set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries")
|
||||
set(INSTALL_PY_DIR ".")
|
||||
set(INSTALL_BASE_DIR ".")
|
||||
set(INSTALL_BIN_DIR "bin")
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
set(INSTALL_LIB_DIR ".")
|
||||
|
@ -78,7 +79,14 @@ if (ENABLE_MINDDATA)
|
|||
DESTINATION ${INSTALL_BASE_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
install(
|
||||
TARGETS cache_admin cache_server
|
||||
OPTIONAL
|
||||
DESTINATION ${INSTALL_BIN_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif()
|
||||
file(GLOB_RECURSE OPENCV_LIB_LIST
|
||||
${opencv_LIBPATH}/libopencv_core*
|
||||
${opencv_LIBPATH}/libopencv_imgcodecs*
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <optional>
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
|
||||
|
@ -22,14 +23,16 @@ namespace dataset {
|
|||
|
||||
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
|
||||
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
|
||||
.def(
|
||||
py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
|
||||
std::optional<int32_t> port, int32_t prefetch_sz) {
|
||||
.def(py::init([](session_id_type id, uint64_t mem_sz, bool spill,
|
||||
std::optional<std::string> hostname, std::optional<int32_t> port,
|
||||
std::optional<int32_t> num_connections, std::optional<int32_t> prefetch_sz) {
|
||||
std::shared_ptr<CacheClient> cc;
|
||||
CacheClient::Builder builder;
|
||||
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz);
|
||||
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill);
|
||||
if (hostname) builder.SetHostname(hostname.value());
|
||||
if (port) builder.SetPort(port.value());
|
||||
if (num_connections) builder.SetNumConnections(num_connections.value());
|
||||
if (prefetch_sz) builder.SetPrefetchSize(prefetch_sz.value());
|
||||
THROW_IF_ERROR(builder.Build(&cc));
|
||||
return cc;
|
||||
}))
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "minddata/dataset/util/system_pool.h"
|
||||
|
@ -33,7 +34,9 @@ ConfigManager::ConfigManager()
|
|||
monitor_sampling_interval_(kCfgMonitorSamplingInterval),
|
||||
callback_timout_(kCfgCallbackTimeout),
|
||||
cache_host_(kCfgDefaultCacheHost),
|
||||
cache_port_(kCfgDefaultCachePort) {
|
||||
cache_port_(kCfgDefaultCachePort),
|
||||
num_connections_(kDftNumConnections),
|
||||
prefetch_size_(kDftPrefetchSize) {
|
||||
auto env_cache_host = std::getenv("MS_CACHE_HOST");
|
||||
auto env_cache_port = std::getenv("MS_CACHE_PORT");
|
||||
if (env_cache_host != nullptr) {
|
||||
|
@ -71,6 +74,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
|
|||
set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
|
||||
set_cache_host(j.value("cacheHost", cache_host_));
|
||||
set_cache_port(j.value("cachePort", cache_port_));
|
||||
set_num_connections(j.value("numConnections", num_connections_));
|
||||
set_prefetch_size(j.value("prefetchSize", prefetch_size_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -120,8 +125,12 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s
|
|||
|
||||
void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; }
|
||||
|
||||
void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; }
|
||||
void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = std::move(cache_host); }
|
||||
|
||||
void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; }
|
||||
|
||||
void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; }
|
||||
|
||||
void ConfigManager::set_prefetch_size(int32_t prefetch_size) { prefetch_size_ = prefetch_size; }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -97,6 +97,14 @@ class ConfigManager {
|
|||
// @return The port of cache server
|
||||
int32_t cache_port() const { return cache_port_; }
|
||||
|
||||
/// getter function
|
||||
/// \return Number of tcp/ip connection
|
||||
int32_t num_connections() const { return num_connections_; }
|
||||
|
||||
/// getter function
|
||||
/// \return Prefetch size
|
||||
int32_t prefetch_size() const { return prefetch_size_; }
|
||||
|
||||
// setter function
|
||||
// @param rows_per_buffer - The setting to apply to the config
|
||||
void set_rows_per_buffer(int32_t rows_per_buffer);
|
||||
|
@ -121,6 +129,14 @@ class ConfigManager {
|
|||
// @param cache_port - The port of cache server
|
||||
void set_cache_port(int32_t cache_port);
|
||||
|
||||
/// setter function
|
||||
/// \param num_connections
|
||||
void set_num_connections(int32_t num_connections);
|
||||
|
||||
/// setter function
|
||||
/// \param prefetch_size
|
||||
void set_prefetch_size(int32_t prefetch_size);
|
||||
|
||||
uint32_t seed() const;
|
||||
|
||||
// setter function
|
||||
|
@ -153,6 +169,8 @@ class ConfigManager {
|
|||
uint32_t callback_timout_;
|
||||
std::string cache_host_;
|
||||
int32_t cache_port_;
|
||||
int32_t num_connections_;
|
||||
int32_t prefetch_size_;
|
||||
|
||||
// Private helper function that takes a nlohmann json format and populates the settings
|
||||
// @param j - The json nlohmann json info
|
||||
|
|
|
@ -71,6 +71,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
|
|||
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
|
||||
constexpr int32_t kCfgDefaultCachePort = 50052;
|
||||
constexpr char kCfgDefaultCacheHost[] = "127.0.0.1";
|
||||
constexpr int32_t kDftPrefetchSize = 20;
|
||||
constexpr int32_t kDftNumConnections = 12;
|
||||
|
||||
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
|
||||
constexpr uint8_t kCVInvalidType = 255;
|
||||
|
|
|
@ -79,6 +79,14 @@ class TensorRow {
|
|||
|
||||
const vector_type &getRow() const { return row_; }
|
||||
|
||||
int64_t SizeInBytes() const {
|
||||
size_t sz = 0;
|
||||
for (auto &it : row_) {
|
||||
sz += it->SizeInBytes();
|
||||
}
|
||||
return sz;
|
||||
}
|
||||
|
||||
// Wrapper functions to support vector operations
|
||||
void emplace_back(value_type t) { row_.emplace_back(t); }
|
||||
|
||||
|
|
|
@ -12,7 +12,9 @@ add_library(engine-cache-client OBJECT
|
|||
|
||||
if (ENABLE_CACHE)
|
||||
ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
|
||||
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
|
||||
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS}
|
||||
cache_grpc_client.cc
|
||||
cache_ipc.cc)
|
||||
|
||||
add_library(engine-cache-server OBJECT
|
||||
${CACHE_GRPC_SRCS}
|
||||
|
|
|
@ -37,12 +37,17 @@ int main(int argc, char **argv) {
|
|||
warningMsg += "WARNING:\n";
|
||||
warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
|
||||
warningMsg += " purposes at this time.\n";
|
||||
auto env_enable_cache = std::getenv("MS_ENABLE_CACHE");
|
||||
if (env_enable_cache == nullptr || strcmp(env_enable_cache, "TRUE") != 0) {
|
||||
// temporary disable cache feature in the current release
|
||||
warningMsg += "This command is currently disabled. Quitting.\n";
|
||||
std::cerr << warningMsg << std::endl;
|
||||
return 0;
|
||||
}
|
||||
warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n";
|
||||
|
||||
// A warning message until the code is mature enough.
|
||||
std::cerr << warningMsg << std::endl;
|
||||
// temporary disable cache feature in the current release
|
||||
return 0;
|
||||
|
||||
if (argc == 1) {
|
||||
args.Help();
|
||||
|
|
|
@ -19,9 +19,11 @@
|
|||
#include <sys/wait.h>
|
||||
#include <unistd.h>
|
||||
#include <cerrno>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/cache/cache_request.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
|
@ -39,6 +41,7 @@ CacheAdminArgHandler::CacheAdminArgHandler()
|
|||
num_workers_(kDefaultNumWorkers),
|
||||
shm_mem_sz_(kDefaultSharedMemorySizeInGB),
|
||||
log_level_(kDefaultLogLevel),
|
||||
memory_cap_ratio_(kMemoryCapRatio),
|
||||
hostname_(kCfgDefaultCacheHost),
|
||||
spill_dir_(kDefaultSpillDir),
|
||||
command_id_(CommandId::kCmdUnknown) {
|
||||
|
@ -62,6 +65,9 @@ CacheAdminArgHandler::CacheAdminArgHandler()
|
|||
arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize;
|
||||
arg_map_["-l"] = ArgValue::kArgLogLevel;
|
||||
arg_map_["--minloglevel"] = ArgValue::kArgLogLevel;
|
||||
arg_map_["-r"] = ArgValue::kArgMemoryCapRatio;
|
||||
arg_map_["--memory_cap_ratio"] = ArgValue::kArgMemoryCapRatio;
|
||||
arg_map_["--list_sessions"] = ArgValue::kArgListSessions;
|
||||
// Initialize argument tracker with false values
|
||||
for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) {
|
||||
ArgValue currAV = static_cast<ArgValue>(i);
|
||||
|
@ -69,6 +75,8 @@ CacheAdminArgHandler::CacheAdminArgHandler()
|
|||
}
|
||||
}
|
||||
|
||||
CacheAdminArgHandler::~CacheAdminArgHandler() = default;
|
||||
|
||||
Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
|
||||
CommandId command_id) {
|
||||
// Detect if the user tried to provide this argument more than once
|
||||
|
@ -102,7 +110,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
|
|||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
}
|
||||
|
||||
// Now, attempt to convert the value into it's string format for output
|
||||
// Now, attempt to convert the value into it's numeric format for output
|
||||
try {
|
||||
*out_arg = std::stoul(value_as_string);
|
||||
} catch (const std::exception &e) {
|
||||
|
@ -140,7 +148,13 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
|
|||
// If there is no argument to get, such as the --start command, then out_arg will be a nullptr.
|
||||
if (out_arg != nullptr) {
|
||||
// Fetch the argument from the arg stream into a string
|
||||
if (arg_stream->rdbuf()->in_avail() != 0) {
|
||||
*arg_stream >> *out_arg;
|
||||
} else {
|
||||
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
|
||||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
}
|
||||
|
||||
if (out_arg->empty()) {
|
||||
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
|
||||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
|
@ -150,12 +164,62 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream,
|
||||
CommandId command_id) {
|
||||
// Detect if the user tried to provide this argument more than once
|
||||
ArgValue selected_arg = arg_map_[option];
|
||||
if (used_args_[selected_arg]) {
|
||||
std::string err_msg = "The " + option + " argument was given more than once.";
|
||||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
}
|
||||
|
||||
// Flag that this arg is used now
|
||||
used_args_[selected_arg] = true;
|
||||
|
||||
// Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument.
|
||||
// Other options are actual commands, for example "--start".
|
||||
// If this option is also a command, make sure there has not been multiple commands given before assigning it.
|
||||
if (command_id != CommandId::kCmdUnknown) {
|
||||
if (command_id_ != CommandId::kCmdUnknown) {
|
||||
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
|
||||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
} else {
|
||||
command_id_ = command_id;
|
||||
}
|
||||
}
|
||||
|
||||
std::string value_as_string;
|
||||
|
||||
// Fetch the argument from the arg stream into a string
|
||||
*arg_stream >> value_as_string;
|
||||
if (value_as_string.empty()) {
|
||||
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
|
||||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
}
|
||||
|
||||
// Now, attempt to convert the value into it's string format for output
|
||||
try {
|
||||
*out_arg = std::stof(value_as_string, nullptr);
|
||||
} catch (const std::exception &e) {
|
||||
std::string err_msg = "Invalid numeric value: " + value_as_string;
|
||||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
|
||||
std::string tok;
|
||||
while (*arg_stream >> tok) {
|
||||
switch (arg_map_[tok]) {
|
||||
case ArgValue::kArgHost: {
|
||||
RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream));
|
||||
// Temporary sanity check. We only support localhost for now
|
||||
if (hostname_ != std::string(kCfgDefaultCacheHost)) {
|
||||
std::string err_msg =
|
||||
"Invalid host interface: " + hostname_ + ". Current limitation, only 127.0.0.1 can be used.";
|
||||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ArgValue::kArgPort: {
|
||||
|
@ -203,6 +267,14 @@ Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
|
|||
RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream));
|
||||
break;
|
||||
}
|
||||
case ArgValue::kArgMemoryCapRatio: {
|
||||
RETURN_IF_NOT_OK(AssignArg(tok, &memory_cap_ratio_, arg_stream));
|
||||
break;
|
||||
}
|
||||
case ArgValue::kArgListSessions: {
|
||||
RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdListSessions));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// Save space delimited trailing arguments
|
||||
trailing_args_ += (" " + tok);
|
||||
|
@ -232,9 +304,12 @@ Status CacheAdminArgHandler::Validate() {
|
|||
}
|
||||
|
||||
// Additional checks here
|
||||
if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be positive value.");
|
||||
if (num_workers_ < 1 || num_workers_ > 100)
|
||||
return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100.");
|
||||
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
|
||||
// port range check?
|
||||
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
|
||||
return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1");
|
||||
if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kSyntaxError, "Port must be in range (1025..65535).");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -245,12 +320,9 @@ Status CacheAdminArgHandler::RunCommand() {
|
|||
Help();
|
||||
break;
|
||||
}
|
||||
case CommandId::kCmdStart: {
|
||||
RETURN_IF_NOT_OK(StartServer());
|
||||
break;
|
||||
}
|
||||
case CommandId::kCmdStart:
|
||||
case CommandId::kCmdStop: {
|
||||
RETURN_IF_NOT_OK(StopServer());
|
||||
RETURN_IF_NOT_OK(StartStopServer(command_id_));
|
||||
break;
|
||||
}
|
||||
case CommandId::kCmdGenerateSession: {
|
||||
|
@ -259,7 +331,7 @@ Status CacheAdminArgHandler::RunCommand() {
|
|||
auto rq = std::make_shared<GenerateSessionIdRequest>();
|
||||
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
|
||||
RETURN_IF_NOT_OK(rq->Wait());
|
||||
std::cout << rq->GetSessionId() << std::endl;
|
||||
std::cout << "Session: " << rq->GetSessionId() << std::endl;
|
||||
break;
|
||||
}
|
||||
case CommandId::kCmdDestroySession: {
|
||||
|
@ -273,6 +345,39 @@ Status CacheAdminArgHandler::RunCommand() {
|
|||
std::cout << "Drop session successful" << std::endl;
|
||||
break;
|
||||
}
|
||||
case CommandId::kCmdListSessions: {
|
||||
CacheClientGreeter comm(hostname_, port_, 1);
|
||||
RETURN_IF_NOT_OK(comm.ServiceStart());
|
||||
auto rq = std::make_shared<ListSessionsRequest>();
|
||||
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
|
||||
RETURN_IF_NOT_OK(rq->Wait());
|
||||
std::vector<SessionCacheInfo> session_info = rq->GetSessionCacheInfo();
|
||||
if (!session_info.empty()) {
|
||||
std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached"
|
||||
<< std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl;
|
||||
for (auto curr_session : session_info) {
|
||||
std::string cache_id;
|
||||
std::string stat_mem_cached;
|
||||
std::string stat_disk_cached;
|
||||
std::string stat_avg_cached;
|
||||
int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF);
|
||||
cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc);
|
||||
stat_mem_cached =
|
||||
(curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached);
|
||||
stat_disk_cached =
|
||||
(curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached);
|
||||
stat_avg_cached =
|
||||
(curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz);
|
||||
|
||||
std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12)
|
||||
<< stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached
|
||||
<< std::endl;
|
||||
}
|
||||
} else {
|
||||
std::cout << "No active sessions." << std::endl;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid cache admin command id.");
|
||||
break;
|
||||
|
@ -282,7 +387,7 @@ Status CacheAdminArgHandler::RunCommand() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheAdminArgHandler::StartServer() {
|
||||
Status CacheAdminArgHandler::StartStopServer(CommandId command_id) {
|
||||
// There currently does not exist any "install path" or method to identify which path the installed binaries will
|
||||
// exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
|
||||
// cache_admin binary (this process).
|
||||
|
@ -324,7 +429,10 @@ Status CacheAdminArgHandler::StartServer() {
|
|||
close(fd[1]);
|
||||
dup2(fd[0], 0);
|
||||
close(fd[0]);
|
||||
wait(nullptr);
|
||||
int status;
|
||||
if (waitpid(pid, &status, 0) == -1) {
|
||||
RETURN_STATUS_UNEXPECTED("waitpid fails. errno = " + std::to_string(errno));
|
||||
}
|
||||
std::string msg;
|
||||
const int32_t buf_sz = 1024;
|
||||
msg.resize(buf_sz);
|
||||
|
@ -335,6 +443,13 @@ Status CacheAdminArgHandler::StartServer() {
|
|||
}
|
||||
msg.resize(n);
|
||||
std::cout << msg << std::endl;
|
||||
if (WIFEXITED(status)) {
|
||||
auto exit_status = WEXITSTATUS(status);
|
||||
if (exit_status) {
|
||||
std::string errMsg = "Child exit status " + std::to_string(exit_status);
|
||||
return Status(StatusCode::kUnexpectedError, errMsg);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
// Child here ...
|
||||
|
@ -350,19 +465,29 @@ Status CacheAdminArgHandler::StartServer() {
|
|||
std::string shared_memory_string = std::to_string(shm_mem_sz_);
|
||||
std::string minloglevel_string = std::to_string(log_level_);
|
||||
std::string daemonize_string = "true";
|
||||
std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_);
|
||||
|
||||
char *argv[8];
|
||||
argv[0] = cache_server_binary.data(); // First arg is usually the binary name
|
||||
char *argv[9];
|
||||
if (command_id == CommandId::kCmdStart) {
|
||||
argv[0] = cache_server_binary.data();
|
||||
argv[1] = spill_dir_.data();
|
||||
argv[2] = workers_string.data();
|
||||
argv[3] = port_string.data();
|
||||
argv[4] = shared_memory_string.data();
|
||||
argv[5] = minloglevel_string.data();
|
||||
argv[6] = daemonize_string.data();
|
||||
argv[7] = nullptr;
|
||||
argv[7] = memory_cap_ratio_string.data();
|
||||
argv[8] = nullptr;
|
||||
} else {
|
||||
// We are doing a --stop. Change the name to '-' and we also need the port number.
|
||||
// The rest we don't need.
|
||||
argv[0] = std::string("-").data();
|
||||
argv[1] = port_string.data();
|
||||
argv[2] = nullptr;
|
||||
}
|
||||
|
||||
// Now exec the binary
|
||||
execv(argv[0], argv);
|
||||
execv(cache_server_binary.data(), argv);
|
||||
// If the exec was successful, this line will never be reached due to process image being replaced.
|
||||
// ..unless exec failed.
|
||||
std::string err_msg = "Failed to exec cache server: " + cache_server_binary;
|
||||
|
@ -371,16 +496,6 @@ Status CacheAdminArgHandler::StartServer() {
|
|||
}
|
||||
}
|
||||
|
||||
Status CacheAdminArgHandler::StopServer() {
|
||||
CacheClientGreeter comm(hostname_, port_, 1);
|
||||
RETURN_IF_NOT_OK(comm.ServiceStart());
|
||||
auto rq = std::make_shared<ShutdownRequest>();
|
||||
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
|
||||
// We will ignore the rc because if the shutdown is successful, the server will not reply back.
|
||||
(void)rq->Wait();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CacheAdminArgHandler::Help() {
|
||||
std::cerr << "Syntax:\n";
|
||||
std::cerr << " cache_admin [--start | --stop]\n";
|
||||
|
@ -390,8 +505,12 @@ void CacheAdminArgHandler::Help() {
|
|||
std::cerr << " [ [-d | --destroy_session] <session id> ]\n";
|
||||
std::cerr << " [ [-w | --workers] <number of workers> ]\n";
|
||||
std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n";
|
||||
std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
|
||||
std::cerr << " [ [-l | --minloglevel] <log level> ]\n";
|
||||
std::cerr << " [ --list_sessions ]\n";
|
||||
// Do not expose these option to the user via help or documentation, but the options do exist to aid with
|
||||
// development and tuning.
|
||||
// std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
|
||||
// std::cerr << " [ [-r | --memory_cap_ratio] <float percent value>]\n";
|
||||
std::cerr << " [--help]" << std::endl;
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -32,6 +32,7 @@ class CacheAdminArgHandler {
|
|||
static constexpr int32_t kDefaultNumWorkers = 32;
|
||||
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
|
||||
static constexpr int32_t kDefaultLogLevel = 1;
|
||||
static constexpr float kMemoryCapRatio = 0.8;
|
||||
static const char kServerBinary[];
|
||||
static const char kDefaultSpillDir[];
|
||||
|
||||
|
@ -42,12 +43,13 @@ class CacheAdminArgHandler {
|
|||
kCmdStop = 2,
|
||||
kCmdGenerateSession = 3,
|
||||
kCmdDestroySession = 4,
|
||||
kCmdListSessions = 5,
|
||||
kCmdUnknown = 32767
|
||||
};
|
||||
|
||||
CacheAdminArgHandler();
|
||||
|
||||
~CacheAdminArgHandler() = default;
|
||||
virtual ~CacheAdminArgHandler();
|
||||
|
||||
Status ParseArgStream(std::stringstream *arg_stream);
|
||||
|
||||
|
@ -70,12 +72,12 @@ class CacheAdminArgHandler {
|
|||
kArgNumWorkers = 9,
|
||||
kArgSharedMemorySize = 10,
|
||||
kArgLogLevel = 11,
|
||||
kArgNumArgs = 12 // Must be the last position to provide a count
|
||||
kArgMemoryCapRatio = 12,
|
||||
kArgListSessions = 13,
|
||||
kArgNumArgs = 14 // Must be the last position to provide a count
|
||||
};
|
||||
|
||||
Status StartServer();
|
||||
|
||||
Status StopServer();
|
||||
Status StartStopServer(CommandId);
|
||||
|
||||
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
|
||||
CommandId command_id = CommandId::kCmdUnknown);
|
||||
|
@ -83,6 +85,9 @@ class CacheAdminArgHandler {
|
|||
Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
|
||||
CommandId command_id = CommandId::kCmdUnknown);
|
||||
|
||||
Status AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream,
|
||||
CommandId command_id = CommandId::kCmdUnknown);
|
||||
|
||||
Status Validate();
|
||||
|
||||
CommandId command_id_;
|
||||
|
@ -90,6 +95,7 @@ class CacheAdminArgHandler {
|
|||
int32_t num_workers_;
|
||||
int32_t shm_mem_sz_;
|
||||
int32_t log_level_;
|
||||
float memory_cap_ratio_;
|
||||
session_id_type session_id_;
|
||||
std::string hostname_;
|
||||
std::string spill_dir_;
|
||||
|
|
|
@ -17,27 +17,19 @@
|
|||
#include "minddata/dataset/util/path.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
|
||||
: ptr_(nullptr), val_in_GB_(val_in_GB), port_(port), shmid_(-1) {}
|
||||
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
|
||||
shmdt(this->ptr_);
|
||||
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) {
|
||||
// We create the shared memory and we will destroy it. All other client just detach only.
|
||||
shm_.RemoveResourcesOnExit();
|
||||
}
|
||||
this->ptr_ = nullptr;
|
||||
if (shmid_ != -1) {
|
||||
shmctl(shmid_, IPC_RMID, nullptr);
|
||||
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
|
||||
// Also remove the path we use to generate ftok.
|
||||
Path p(PortToUnixSocketPath(port_));
|
||||
(void)p.Remove();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
|
||||
size_t val_in_GB) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
|
||||
if (ba == nullptr) {
|
||||
return Status(StatusCode::kOutOfMemory);
|
||||
|
@ -46,26 +38,13 @@ Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryAr
|
|||
// the destructor of *out to deal.
|
||||
(*out).reset(ba);
|
||||
// Generate the ftok using a combination of port.
|
||||
int err;
|
||||
auto shm_key = PortToFtok(port, &err);
|
||||
if (shm_key == (key_t)-1) {
|
||||
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
|
||||
SharedMemory::shm_key_t shm_key;
|
||||
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
|
||||
ba->shm_.SetPublicKey(shm_key);
|
||||
// Value is in GB. Convert into bytes.
|
||||
int64_t sz = val_in_GB * 1073741824L;
|
||||
ba->shmid_ = shmget(shm_key, sz, IPC_CREAT | IPC_EXCL | access_mode);
|
||||
if (ba->shmid_) {
|
||||
ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
|
||||
if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
ba->impl_ = std::make_unique<ArenaImpl>(ba->ptr_, sz);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
#endif
|
||||
RETURN_IF_NOT_OK(ba->shm_.Create(sz));
|
||||
ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <string>
|
||||
#include "minddata/dataset/util/arena.h"
|
||||
#include "minddata/dataset/engine/cache/cache_common.h"
|
||||
#include "minddata/dataset/engine/cache/cache_ipc.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// This is a derived class of Arena but resides in shared memory
|
||||
|
@ -73,10 +74,9 @@ class CachedSharedMemoryArena : public MemoryPool {
|
|||
|
||||
private:
|
||||
mutable std::mutex mux_;
|
||||
void *ptr_;
|
||||
int32_t val_in_GB_;
|
||||
int32_t port_;
|
||||
int shmid_;
|
||||
SharedMemory shm_;
|
||||
std::unique_ptr<ArenaImpl> impl_;
|
||||
/// Private constructor. Not to be called directly.
|
||||
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
|
||||
|
|
|
@ -24,26 +24,26 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CacheClient::Builder::Builder()
|
||||
: session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) {
|
||||
: session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_connections_(0), prefetch_size_(0) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
hostname_ = cfg->cache_host();
|
||||
port_ = cfg->cache_port();
|
||||
num_workers_ = cfg->num_parallel_workers();
|
||||
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
|
||||
num_connections_ = cfg->num_connections(); // number of async tcp/ip connections
|
||||
prefetch_size_ = cfg->prefetch_size(); // prefetch size
|
||||
}
|
||||
|
||||
Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*out =
|
||||
std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_);
|
||||
*out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_connections_,
|
||||
prefetch_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClient::Builder::SanityCheck() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
|
||||
|
@ -55,26 +55,32 @@ Status CacheClient::Builder::SanityCheck() {
|
|||
|
||||
// Constructor
|
||||
CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
|
||||
int32_t port, int32_t num_workers, int32_t prefetch_size)
|
||||
int32_t port, int32_t num_connections, int32_t prefetch_size)
|
||||
: server_connection_id_(0),
|
||||
cache_mem_sz_(cache_mem_sz),
|
||||
spill_(spill),
|
||||
local_bypass_(false),
|
||||
hostname_(std::move(hostname)),
|
||||
port_(port),
|
||||
num_workers_(num_workers),
|
||||
prefetch_size_(prefetch_size) {
|
||||
num_connections_(num_connections),
|
||||
prefetch_size_(prefetch_size),
|
||||
fetch_all_keys_(true) {
|
||||
cinfo_.set_session_id(session_id);
|
||||
comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
|
||||
comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_connections_);
|
||||
}
|
||||
|
||||
CacheClient::~CacheClient() {
|
||||
cache_miss_keys_wp_.Set();
|
||||
(void)comm_->ServiceStop();
|
||||
}
|
||||
|
||||
// print method for display cache details
|
||||
void CacheClient::Print(std::ostream &out) const {
|
||||
out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc()
|
||||
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz()
|
||||
<< "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname()
|
||||
<< "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers()
|
||||
<< "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha
|
||||
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << GetCacheMemSz()
|
||||
<< "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << GetHostname()
|
||||
<< "\n Port: " << GetPort() << "\n Number of rpc workers: " << GetNumConnections()
|
||||
<< "\n Prefetch size: " << GetPrefetchSize() << "\n Local client support: " << std::boolalpha
|
||||
<< SupportLocalClient();
|
||||
}
|
||||
|
||||
|
@ -199,14 +205,6 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClient::PurgeCache() {
|
||||
UniqueLock lck(&mux_);
|
||||
auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
|
||||
RETURN_IF_NOT_OK(PushRequest(rq));
|
||||
RETURN_IF_NOT_OK(rq->Wait());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClient::DestroyCache() {
|
||||
UniqueLock lck(&mux_);
|
||||
auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
|
||||
|
@ -253,5 +251,71 @@ Status CacheClient::BuildPhaseDone() const {
|
|||
}
|
||||
|
||||
Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
|
||||
|
||||
void CacheClient::ServerRunningOutOfResources() {
|
||||
bool expected = true;
|
||||
if (fetch_all_keys_.compare_exchange_strong(expected, false)) {
|
||||
Status rc;
|
||||
// Server runs out of memory or disk space to cache any more rows.
|
||||
// First of all, we will turn off the locking.
|
||||
auto toggle_write_mode_rq = std::make_shared<ToggleWriteModeRequest>(server_connection_id_, false);
|
||||
rc = PushRequest(toggle_write_mode_rq);
|
||||
if (rc.IsError()) {
|
||||
return;
|
||||
}
|
||||
// Wait until we can toggle the state of the server to non-locking
|
||||
rc = toggle_write_mode_rq->Wait();
|
||||
if (rc.IsError()) {
|
||||
return;
|
||||
}
|
||||
// Now we get a list of all the keys not cached at the server so
|
||||
// we can filter out at the prefetch level.
|
||||
auto cache_miss_rq = std::make_shared<GetCacheMissKeysRequest>(server_connection_id_);
|
||||
rc = PushRequest(cache_miss_rq);
|
||||
if (rc.IsError()) {
|
||||
return;
|
||||
}
|
||||
rc = cache_miss_rq->Wait();
|
||||
if (rc.IsError()) {
|
||||
return;
|
||||
}
|
||||
// We will get back a vector of row id between [min,max] that are absent in the server.
|
||||
auto &row_id_buf = cache_miss_rq->reply_.result();
|
||||
auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
|
||||
std::vector<row_id_type> row_ids;
|
||||
auto sz = p->row_id()->size();
|
||||
row_ids.reserve(sz);
|
||||
for (auto i = 0; i < sz; ++i) {
|
||||
row_ids.push_back(p->row_id()->Get(i));
|
||||
}
|
||||
cache_miss_keys_ = std::make_unique<CacheMissKeys>(row_ids);
|
||||
// We are all set.
|
||||
cache_miss_keys_wp_.Set();
|
||||
}
|
||||
}
|
||||
|
||||
CacheClient::CacheMissKeys::CacheMissKeys(const std::vector<row_id_type> &v) {
|
||||
auto it = v.begin();
|
||||
min_ = *it;
|
||||
++it;
|
||||
max_ = *it;
|
||||
++it;
|
||||
while (it != v.end()) {
|
||||
gap_.insert(*it);
|
||||
++it;
|
||||
}
|
||||
MS_LOG(WARNING) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size();
|
||||
}
|
||||
|
||||
bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
|
||||
if (key > max_ || key < min_) {
|
||||
return true;
|
||||
} else if (key == min_ || key == max_) {
|
||||
return false;
|
||||
} else {
|
||||
auto it = gap_.find(key);
|
||||
return it != gap_.end();
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,8 +16,13 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
@ -31,6 +36,8 @@
|
|||
#endif
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/util/lock.h"
|
||||
#include "minddata/dataset/util/cond_var.h"
|
||||
#include "minddata/dataset/util/queue_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -89,10 +96,10 @@ class CacheClient {
|
|||
}
|
||||
|
||||
/// Setter function to set number of async rpc workers
|
||||
/// \param num_workers
|
||||
/// \param num_connections
|
||||
/// \return Builder object itself
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
num_workers_ = num_workers;
|
||||
Builder &SetNumConnections(int32_t num_connections) {
|
||||
num_connections_ = num_connections;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -105,13 +112,13 @@ class CacheClient {
|
|||
}
|
||||
|
||||
/// Getter functions
|
||||
session_id_type getSessionId() const { return session_id_; }
|
||||
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
|
||||
session_id_type GetSessionId() const { return session_id_; }
|
||||
uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
|
||||
bool isSpill() const { return spill_; }
|
||||
const std::string &getHostname() const { return hostname_; }
|
||||
int32_t getPort() const { return port_; }
|
||||
int32_t getNumWorkers() const { return num_workers_; }
|
||||
int32_t getPrefetchSize() const { return prefetch_size_; }
|
||||
int32_t GetPort() const { return port_; }
|
||||
int32_t GetNumConnections() const { return num_connections_; }
|
||||
int32_t GetPrefetchSize() const { return prefetch_size_; }
|
||||
|
||||
Status SanityCheck();
|
||||
|
||||
|
@ -123,7 +130,7 @@ class CacheClient {
|
|||
bool spill_;
|
||||
std::string hostname_;
|
||||
int32_t port_;
|
||||
int32_t num_workers_;
|
||||
int32_t num_connections_;
|
||||
int32_t prefetch_size_;
|
||||
};
|
||||
|
||||
|
@ -132,10 +139,10 @@ class CacheClient {
|
|||
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
|
||||
/// \param spill Spill to disk if out of memory
|
||||
CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
|
||||
int32_t num_workers, int32_t prefetch_size);
|
||||
int32_t num_connections, int32_t prefetch_size);
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheClient() { (void)comm_->ServiceStop(); }
|
||||
~CacheClient();
|
||||
|
||||
/// \brief Send a TensorRow to the cache server
|
||||
/// \param[in] row
|
||||
|
@ -161,10 +168,6 @@ class CacheClient {
|
|||
/// \return Status object
|
||||
Status CreateCache(uint32_t tree_crc, bool generate_id);
|
||||
|
||||
/// \brief Purge a cache. Cache can be reused after reset.
|
||||
/// \return Status object
|
||||
Status PurgeCache();
|
||||
|
||||
/// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
|
||||
/// \return Status object
|
||||
Status DestroyCache();
|
||||
|
@ -218,12 +221,31 @@ class CacheClient {
|
|||
|
||||
/// Getter functions
|
||||
session_id_type session_id() const { return cinfo_.session_id(); }
|
||||
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
|
||||
uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
|
||||
bool isSpill() const { return spill_; }
|
||||
const std::string &getHostname() const { return hostname_; }
|
||||
int32_t getPort() const { return port_; }
|
||||
int32_t getNumWorkers() const { return num_workers_; }
|
||||
int32_t getPrefetchSize() const { return prefetch_size_; }
|
||||
const std::string &GetHostname() const { return hostname_; }
|
||||
int32_t GetPort() const { return port_; }
|
||||
int32_t GetNumConnections() const { return num_connections_; }
|
||||
int32_t GetPrefetchSize() const { return prefetch_size_; }
|
||||
|
||||
/// MergeOp will notify us when the server can't cache any more rows.
|
||||
/// We will stop any attempt to fetch any rows that are most likely
|
||||
/// not present at the server.
|
||||
void ServerRunningOutOfResources();
|
||||
|
||||
/// \brief Check if a row is 100% cache miss at the server by checking the local information
|
||||
/// \param key row id to be test
|
||||
/// \return true if not at the server
|
||||
bool KeyIsCacheMiss(row_id_type key) {
|
||||
if (cache_miss_keys_) {
|
||||
// Make sure it is fully built even though the pointer is not null
|
||||
Status rc = cache_miss_keys_wp_.Wait();
|
||||
if (rc.IsOk()) {
|
||||
return cache_miss_keys_->KeyIsCacheMiss(key);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
mutable RWLock mux_;
|
||||
|
@ -240,9 +262,27 @@ class CacheClient {
|
|||
bool local_bypass_;
|
||||
std::string hostname_;
|
||||
int32_t port_;
|
||||
int32_t num_workers_;
|
||||
int32_t num_connections_;
|
||||
int32_t prefetch_size_;
|
||||
mutable std::shared_ptr<CacheClientGreeter> comm_;
|
||||
std::atomic<bool> fetch_all_keys_;
|
||||
WaitPost cache_miss_keys_wp_;
|
||||
/// A structure shared by all the prefetchers to know what keys are missing at the server.
|
||||
class CacheMissKeys {
|
||||
public:
|
||||
explicit CacheMissKeys(const std::vector<row_id_type> &v);
|
||||
~CacheMissKeys() = default;
|
||||
/// This checks if a key is missing.
|
||||
/// \param key
|
||||
/// \return true if definitely a key miss
|
||||
bool KeyIsCacheMiss(row_id_type key);
|
||||
|
||||
private:
|
||||
row_id_type min_;
|
||||
row_id_type max_;
|
||||
std::set<row_id_type> gap_;
|
||||
};
|
||||
std::unique_ptr<CacheMissKeys> cache_miss_keys_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,13 +25,6 @@
|
|||
#define CACHE_LOCAL_CLIENT 1
|
||||
#endif
|
||||
|
||||
#ifdef CACHE_LOCAL_CLIENT
|
||||
#include <sys/types.h>
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/shm.h>
|
||||
#else
|
||||
typedef int key_t;
|
||||
#endif
|
||||
#ifdef ENABLE_CACHE
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#endif
|
||||
|
@ -54,6 +47,8 @@ constexpr static uint32_t kLocalClientSupport = 1;
|
|||
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
|
||||
/// inline in the protobuf. This also implies kLocalClientSupport is also true.
|
||||
constexpr static uint32_t kDataIsInSharedMemory = 2;
|
||||
/// \brief Size of each message used in message queue.
|
||||
constexpr static int32_t kSharedMessageSize = 2048;
|
||||
|
||||
/// \brief Convert a Status object into a protobuf
|
||||
/// \param rc[in] Status object
|
||||
|
@ -62,29 +57,10 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
|
|||
reply->set_rc(static_cast<int32_t>(rc.get_code()));
|
||||
reply->set_msg(rc.ToString());
|
||||
}
|
||||
|
||||
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
|
||||
/// \param port
|
||||
/// \return unix socket url
|
||||
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
|
||||
|
||||
/// \brief Generate a shared memory key using the tcp/ip port.
|
||||
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
|
||||
/// \note Caller must check the return value. -1 means ftok failed.
|
||||
/// \param[in] port
|
||||
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
|
||||
/// \return key
|
||||
inline key_t PortToFtok(int port, int *err) {
|
||||
key_t shmkey = -1;
|
||||
#ifdef CACHE_LOCAL_CLIENT
|
||||
const std::string unix_path = PortToUnixSocketPath(port);
|
||||
shmkey = ftok(unix_path.data(), 'a');
|
||||
if (err != nullptr && shmkey == (key_t)-1) {
|
||||
*err = errno;
|
||||
}
|
||||
#endif
|
||||
return shmkey;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
|
||||
|
|
|
@ -17,34 +17,10 @@
|
|||
#include <chrono>
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
|
||||
std::unique_ptr<CacheClientRequestTag> &&tag) {
|
||||
// If there is anything extra we need to do before we send.
|
||||
RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
|
||||
// One minute timeout
|
||||
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
|
||||
tag->ctx_.set_deadline(deadline);
|
||||
tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
|
||||
tag->rpc_->StartCall();
|
||||
// Last step is we release the ownership and transfer it to the completion queue.
|
||||
// The memory will be released by WorkerEntry or by the destructor when we drain the queue
|
||||
auto ccReqTag = tag.release();
|
||||
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
|
||||
ccReqTag); // inject this object into the completion queue
|
||||
return Status::OK();
|
||||
}
|
||||
CacheClientGreeter::~CacheClientGreeter() { (void)ServiceStop(); }
|
||||
|
||||
CacheClientGreeter::~CacheClientGreeter() {
|
||||
(void)ServiceStop();
|
||||
// Detach from shared memory if any
|
||||
if (shmat_addr_ != nullptr) {
|
||||
shmdt(shmat_addr_);
|
||||
shmat_addr_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers)
|
||||
: num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) {
|
||||
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections)
|
||||
: num_connections_(num_connections), request_cnt_(0) {
|
||||
grpc::ChannelArguments args;
|
||||
// We need to bump up the message size to unlimited. The default receiving
|
||||
// message limit is 4MB which is not big enough.
|
||||
|
@ -68,21 +44,11 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port
|
|||
Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) {
|
||||
*local_bypass = false;
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
int err;
|
||||
shm_key_ = PortToFtok(port, &err);
|
||||
if (shm_key_ == (key_t)-1) {
|
||||
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
SharedMemory::shm_key_t shm_key;
|
||||
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
|
||||
// Attach to the shared memory
|
||||
shm_id_ = shmget(shm_key_, 0, 0);
|
||||
if (shm_id_ == -1) {
|
||||
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
shmat_addr_ = shmat(shm_id_, nullptr, 0);
|
||||
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
mem_.SetPublicKey(shm_key);
|
||||
RETURN_IF_NOT_OK(mem_.Attach());
|
||||
*local_bypass = true;
|
||||
#endif
|
||||
return Status::OK();
|
||||
|
@ -90,7 +56,7 @@ Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass
|
|||
|
||||
Status CacheClientGreeter::DoServiceStart() {
|
||||
RETURN_IF_NOT_OK(vg_.ServiceStart());
|
||||
RETURN_IF_NOT_OK(DispatchWorkers(num_workers_));
|
||||
RETURN_IF_NOT_OK(DispatchWorkers(num_connections_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -100,19 +66,40 @@ Status CacheClientGreeter::DoServiceStop() {
|
|||
// Shutdown the TaskGroup.
|
||||
vg_.interrupt_all();
|
||||
vg_.join_all(Task::WaitFlag::kNonBlocking);
|
||||
// Drain the queue
|
||||
// Drain the queue. We know how many requests we send out
|
||||
while (!req_.empty()) {
|
||||
bool success;
|
||||
void *tag;
|
||||
while (cq_.Next(&tag, &success)) {
|
||||
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
|
||||
delete r;
|
||||
req_.erase(r->seqNo_);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
|
||||
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq));
|
||||
return tag->MakeCall(stub_.get(), &cq_, std::move(tag));
|
||||
// If there is anything extra we need to do before we send.
|
||||
RETURN_IF_NOT_OK(rq->Prepare());
|
||||
auto seqNo = request_cnt_.fetch_add(1);
|
||||
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq), seqNo);
|
||||
// One minute timeout
|
||||
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
|
||||
tag->ctx_.set_deadline(deadline);
|
||||
tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_);
|
||||
tag->rpc_->StartCall();
|
||||
auto ccReqTag = tag.get();
|
||||
// Insert it into the map.
|
||||
{
|
||||
std::unique_lock<std::mutex> lck(mux_);
|
||||
auto r = req_.emplace(seqNo, std::move(tag));
|
||||
if (!r.second) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__);
|
||||
}
|
||||
}
|
||||
// Last step is to tag the request.
|
||||
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, ccReqTag);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheClientGreeter::WorkerEntry() {
|
||||
|
@ -129,15 +116,26 @@ Status CacheClientGreeter::WorkerEntry() {
|
|||
auto &rc = rq->rc_;
|
||||
if (!rc.ok()) {
|
||||
auto error_code = rq->rc_.error_code();
|
||||
std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
|
||||
Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
std::string err_msg;
|
||||
if (error_code == grpc::StatusCode::UNAVAILABLE) {
|
||||
err_msg =
|
||||
"Cache server is unreachable. Make sure the server is running. GRPC Code" + std::to_string(error_code);
|
||||
} else {
|
||||
err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
|
||||
}
|
||||
Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg);
|
||||
Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
|
||||
}
|
||||
// Notify the waiting thread.
|
||||
rq->Notify();
|
||||
}
|
||||
{
|
||||
// We can now free the memory
|
||||
delete rq;
|
||||
std::unique_lock<std::mutex> lck(mux_);
|
||||
auto seqNo = rq->seqNo_;
|
||||
auto n = req_.erase(seqNo);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(n == 1, "Sequence " + std::to_string(seqNo) + " not found");
|
||||
}
|
||||
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
|
||||
// If we are interrupted, exit. Otherwise wait again.
|
||||
RETURN_IF_INTERRUPTED();
|
||||
|
|
|
@ -16,10 +16,14 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/cache/cache_common.h"
|
||||
#include "minddata/dataset/engine/cache/cache_ipc.h"
|
||||
#include "minddata/dataset/util/service.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
namespace mindspore {
|
||||
|
@ -34,16 +38,10 @@ namespace dataset {
|
|||
class CacheClientRequestTag {
|
||||
public:
|
||||
friend class CacheClientGreeter;
|
||||
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
|
||||
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq, int64_t seqNo)
|
||||
: base_rq_(std::move(rq)), seqNo_(seqNo) {}
|
||||
~CacheClientRequestTag() = default;
|
||||
|
||||
/// \brief Make a RPC call
|
||||
/// \param stub from CacheClientGreeter
|
||||
/// \param cq from CacheClientGreeter
|
||||
/// \return Status object
|
||||
static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
|
||||
std::unique_ptr<CacheClientRequestTag> &&tag);
|
||||
|
||||
/// \brief Notify the client that a result has come back from the server
|
||||
void Notify() { base_rq_->wp_.Set(); }
|
||||
|
||||
|
@ -52,6 +50,7 @@ class CacheClientRequestTag {
|
|||
grpc::Status rc_;
|
||||
grpc::ClientContext ctx_;
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
|
||||
int64_t seqNo_;
|
||||
};
|
||||
|
||||
/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC
|
||||
|
@ -60,7 +59,7 @@ class CacheClientGreeter : public Service {
|
|||
friend class CacheClient;
|
||||
|
||||
public:
|
||||
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
|
||||
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections);
|
||||
~CacheClientGreeter();
|
||||
|
||||
/// Override base Service class
|
||||
|
@ -85,17 +84,18 @@ class CacheClientGreeter : public Service {
|
|||
|
||||
/// \brief This returns where we attach to the shared memory.
|
||||
/// \return Base address of the shared memory.
|
||||
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
|
||||
const void *SharedMemoryBaseAddr() const { return mem_.SharedMemoryBaseAddr(); }
|
||||
|
||||
private:
|
||||
std::shared_ptr<grpc::Channel> channel_;
|
||||
std::unique_ptr<CacheServerGreeter::Stub> stub_;
|
||||
grpc::CompletionQueue cq_;
|
||||
TaskGroup vg_;
|
||||
int32_t num_workers_;
|
||||
key_t shm_key_;
|
||||
int32_t shm_id_;
|
||||
void *shmat_addr_;
|
||||
int32_t num_connections_;
|
||||
std::atomic<int64_t> request_cnt_;
|
||||
mutable std::mutex mux_;
|
||||
std::map<int64_t, std::unique_ptr<CacheClientRequestTag>> req_;
|
||||
SharedMemory mem_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,53 +47,10 @@ void CacheServerGreeterImpl::Shutdown() {
|
|||
|
||||
CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); }
|
||||
|
||||
Status CacheServerGreeterImpl::IpcResourceCleanup() {
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
int err;
|
||||
auto shm_key = PortToFtok(port_, &err);
|
||||
// We are expecting the unix path doesn't exist.
|
||||
if (shm_key == (key_t)-1) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Attach to the shared memory
|
||||
auto shm_id = shmget(shm_key, 0, 0);
|
||||
if (shm_id == -1) {
|
||||
return Status::OK();
|
||||
}
|
||||
struct shmid_ds ds {};
|
||||
auto inx = shmctl(shm_id, IPC_STAT, &ds);
|
||||
if (inx == -1) {
|
||||
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id);
|
||||
errMsg += "\nPlesae remove it manually using ipcrm -m command";
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
if (ds.shm_nattch == 0) {
|
||||
// Stale shared memory from last time.
|
||||
// Remove both the memory and the socket path
|
||||
inx = shmctl(shm_id, IPC_RMID, nullptr);
|
||||
if (inx == -1) {
|
||||
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id);
|
||||
errMsg += ". Errno :" + std::to_string(errno);
|
||||
errMsg += "\nPlesae remove it manually using ipcrm -m command";
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
Path p(unix_socket_);
|
||||
(void)p.Remove();
|
||||
} else {
|
||||
// Server is already up.
|
||||
MS_LOG(ERROR) << "Cache server is already up and running";
|
||||
// We return a duplicate error. The main() will intercept
|
||||
// and output a proper message
|
||||
return Status(StatusCode::kDuplicateKey);
|
||||
}
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServerGreeterImpl::Run() {
|
||||
// To listen on all interfaces, use 0.0.0.0
|
||||
// Use 127.0.0.1 if just locally on the same machine.
|
||||
std::string host("0.0.0.0"); // listen on all interfaces.
|
||||
// Future, allow the user to choose listening interface. For now, default to localhost
|
||||
std::string host("127.0.0.1");
|
||||
std::string server_address = host + ":" + std::to_string(port_);
|
||||
grpc::ServerBuilder builder;
|
||||
// Default message size for gRPC is 4MB. Increase it to 2g-1
|
||||
|
@ -101,9 +58,6 @@ Status CacheServerGreeterImpl::Run() {
|
|||
int port_tcpip = 0;
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
int port_local = 0;
|
||||
// Check if we need to do clean up on the shared memory if the server
|
||||
// came down unexpectedly like SEGV
|
||||
RETURN_IF_NOT_OK(IpcResourceCleanup());
|
||||
// We also optimize on local clients on the same machine using unix socket
|
||||
builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
|
||||
#endif
|
||||
|
|
|
@ -41,7 +41,7 @@ class CacheServerRequest : public BaseRequest {
|
|||
st_(STATE::CREATE),
|
||||
responder_(&ctx_) {}
|
||||
|
||||
~CacheServerRequest() = default;
|
||||
~CacheServerRequest() override = default;
|
||||
|
||||
/// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this
|
||||
/// functor will translate each protobuf into some form understood by by CacheService class.
|
||||
|
@ -87,8 +87,6 @@ class CacheServerGreeterImpl final {
|
|||
|
||||
void Shutdown();
|
||||
|
||||
Status IpcResourceCleanup();
|
||||
|
||||
private:
|
||||
int32_t port_;
|
||||
size_t shm_pool_sz_in_gb_;
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/cache/cache_ipc.h"
|
||||
#include <sys/stat.h>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status PortToFtok(int port, SharedMemory::shm_key_t *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
key_t shmkey = -1;
|
||||
const std::string unix_path = PortToUnixSocketPath(port);
|
||||
shmkey = ftok(unix_path.data(), 'a');
|
||||
if (shmkey == (key_t)-1) {
|
||||
std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno);
|
||||
return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg);
|
||||
}
|
||||
*out = shmkey;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
SharedMessage::~SharedMessage() {
|
||||
// Only remove the queue if we are asked to.
|
||||
if (remove_ipc_on_exit_ && msg_qid_ != -1) {
|
||||
// Remove the message que and never mind about the return code.
|
||||
(void)msgctl(msg_qid_, IPC_RMID, nullptr);
|
||||
msg_qid_ = -1;
|
||||
}
|
||||
}
|
||||
|
||||
Status SharedMessage::Create() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ == -1, "Message queue already created");
|
||||
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
|
||||
msg_qid_ = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode);
|
||||
if (msg_qid_ == -1) {
|
||||
std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SharedMessage::SendStatus(const Status &rc) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
|
||||
StatusMsgBuf msg{
|
||||
1,
|
||||
};
|
||||
msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
|
||||
auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err));
|
||||
msg.body.status.err_msg[rc.ToString().size()] = '\0';
|
||||
err = msgsnd(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), IPC_NOWAIT);
|
||||
if (err == -1) {
|
||||
std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SharedMessage::ReceiveStatus(Status *rc) {
|
||||
RETURN_UNEXPECTED_IF_NULL(rc);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
|
||||
struct StatusMsgBuf msg {};
|
||||
auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR);
|
||||
if (err == -1) {
|
||||
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
|
||||
Status rc_recv(static_cast<StatusCode>(msg.body.status.err_code), msg.body.status.err_msg);
|
||||
*rc = std::move(rc_recv);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
SharedMemory::~SharedMemory() {
|
||||
if (shmat_addr_) {
|
||||
(void)Detach();
|
||||
}
|
||||
if (remove_ipc_on_exit_ && shm_id_ != -1) {
|
||||
// Remove the shared memory and never mind about the return code.
|
||||
Status rc = Destroy();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << rc.ToString();
|
||||
}
|
||||
}
|
||||
shm_id_ = -1;
|
||||
shmat_addr_ = nullptr;
|
||||
}
|
||||
|
||||
Status SharedMemory::Create(int64_t sz) {
|
||||
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
|
||||
shm_id_ = shmget(shm_key_, sz, IPC_CREAT | IPC_EXCL | access_mode);
|
||||
if (shm_id_ == -1) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
|
||||
} else {
|
||||
shmat_addr_ = shmat(shm_id_, nullptr, 0);
|
||||
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SharedMemory::Attach() {
|
||||
shm_id_ = shmget(shm_key_, 0, 0);
|
||||
if (shm_id_ == -1) {
|
||||
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
shmat_addr_ = shmat(shm_id_, nullptr, 0);
|
||||
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SharedMemory::Detach() {
|
||||
if (shmat_addr_) {
|
||||
auto err = shmdt(shmat_addr_);
|
||||
if (err == -1) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory detach failed. Errno " + std::to_string(errno));
|
||||
}
|
||||
}
|
||||
shmat_addr_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SharedMemory::Destroy() {
|
||||
// Remove the shared memory and never mind about the return code.
|
||||
auto err = shmctl(shm_id_, IPC_RMID, nullptr);
|
||||
if (err == -1) {
|
||||
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_);
|
||||
errMsg += ". Errno :" + std::to_string(errno);
|
||||
errMsg += "\nPlesae remove it manually using ipcrm -m command";
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SharedMemory::GetNumAttached(int32_t *num) {
|
||||
RETURN_UNEXPECTED_IF_NULL(num);
|
||||
struct shmid_ds ds {};
|
||||
auto err = shmctl(shm_id_, IPC_STAT, &ds);
|
||||
if (err == -1) {
|
||||
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id_);
|
||||
errMsg += "\nPlease remove it manually using ipcrm -m command";
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
*num = ds.shm_nattch;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,207 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/shm.h>
|
||||
#include <sys/msg.h>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/cache/cache_common.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// A message queue structure between the parent and the child process
|
||||
struct StatusMsgBuf {
|
||||
int64_t mtype;
|
||||
union {
|
||||
char mtext[1];
|
||||
struct {
|
||||
int32_t err_code;
|
||||
char err_msg[kSharedMessageSize];
|
||||
} status;
|
||||
} body;
|
||||
};
|
||||
|
||||
class BaseIPC {
|
||||
public:
|
||||
BaseIPC() : remove_ipc_on_exit_(false) {}
|
||||
virtual ~BaseIPC() {}
|
||||
/// Indicate if we should remove the ipc resource on exit. Usually this is done by parent process.
|
||||
void RemoveResourcesOnExit() { remove_ipc_on_exit_ = true; }
|
||||
/// Copy constructors
|
||||
BaseIPC(const BaseIPC &rhs) : remove_ipc_on_exit_(false) {}
|
||||
BaseIPC &operator=(const BaseIPC &rhs) {
|
||||
if (&rhs != this) {
|
||||
remove_ipc_on_exit_ = false;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
/// Move constructors
|
||||
BaseIPC(BaseIPC &&rhs) noexcept : remove_ipc_on_exit_(rhs.remove_ipc_on_exit_) { rhs.remove_ipc_on_exit_ = false; }
|
||||
BaseIPC &operator=(BaseIPC &&rhs) noexcept {
|
||||
if (&rhs != this) {
|
||||
remove_ipc_on_exit_ = rhs.remove_ipc_on_exit_;
|
||||
rhs.remove_ipc_on_exit_ = false;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool remove_ipc_on_exit_;
|
||||
};
|
||||
|
||||
/// \brief This wraps a shared message for the communication between processes. It is used primarily
|
||||
/// for starting and stopping a server.
|
||||
class SharedMessage : public BaseIPC {
|
||||
public:
|
||||
using queue_id_t = int;
|
||||
SharedMessage() : msg_qid_(-1) {}
|
||||
explicit SharedMessage(queue_id_t qid) : msg_qid_(qid) {}
|
||||
~SharedMessage() override;
|
||||
|
||||
/// Copy constructors
|
||||
SharedMessage(const SharedMessage &rhs) : BaseIPC(rhs), msg_qid_(rhs.msg_qid_) {}
|
||||
SharedMessage &operator=(const SharedMessage &rhs) {
|
||||
if (&rhs != this) {
|
||||
msg_qid_ = rhs.msg_qid_;
|
||||
BaseIPC::operator=(rhs);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
/// Move constructors
|
||||
SharedMessage(SharedMessage &&rhs) noexcept : BaseIPC(std::move(rhs)) {
|
||||
msg_qid_ = rhs.msg_qid_;
|
||||
rhs.msg_qid_ = -1;
|
||||
}
|
||||
SharedMessage &operator=(SharedMessage &&rhs) noexcept {
|
||||
if (&rhs != this) {
|
||||
msg_qid_ = rhs.msg_qid_;
|
||||
rhs.msg_qid_ = -1;
|
||||
BaseIPC::operator=(std::move(rhs));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Return the private id
|
||||
queue_id_t GetMsgQueueId() const { return msg_qid_; }
|
||||
|
||||
/// \brief Create a private message queue
|
||||
Status Create();
|
||||
|
||||
/// Send a Status object
|
||||
Status SendStatus(const Status &rc);
|
||||
|
||||
/// Retrieve a Status object
|
||||
Status ReceiveStatus(Status *rc);
|
||||
|
||||
private:
|
||||
queue_id_t msg_qid_;
|
||||
};
|
||||
|
||||
/// \brief This wraps a shared memory for the communication between processes. It is used primarily
|
||||
/// for transporting large tensor rows.
|
||||
class SharedMemory : public BaseIPC {
|
||||
public:
|
||||
using shm_key_t = int;
|
||||
using shm_id_t = int;
|
||||
SharedMemory() : shm_id_(-1), shm_key_(-1), shmat_addr_(nullptr) {}
|
||||
explicit SharedMemory(shm_key_t public_key) : shm_id_(-1), shm_key_(public_key), shmat_addr_(nullptr) {}
|
||||
~SharedMemory() override;
|
||||
/// Copy constructors
|
||||
SharedMemory(const SharedMemory &rhs)
|
||||
: BaseIPC(rhs), shm_id_(rhs.shm_id_), shm_key_(rhs.shm_key_), shmat_addr_(rhs.shmat_addr_) {}
|
||||
SharedMemory &operator=(const SharedMemory &rhs) {
|
||||
if (&rhs != this) {
|
||||
shm_id_ = rhs.shm_id_;
|
||||
shm_key_ = rhs.shm_key_;
|
||||
shmat_addr_ = rhs.shmat_addr_;
|
||||
BaseIPC::operator=(rhs);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
/// Move constructors
|
||||
SharedMemory(SharedMemory &&rhs) noexcept : BaseIPC(std::move(rhs)) {
|
||||
shm_id_ = rhs.shm_id_;
|
||||
shm_key_ = rhs.shm_key_;
|
||||
shmat_addr_ = rhs.shmat_addr_;
|
||||
rhs.shm_id_ = -1;
|
||||
rhs.shm_key_ = -1;
|
||||
rhs.shmat_addr_ = nullptr;
|
||||
}
|
||||
SharedMemory &operator=(SharedMemory &&rhs) noexcept {
|
||||
if (&rhs != this) {
|
||||
shm_id_ = rhs.shm_id_;
|
||||
shm_key_ = rhs.shm_key_;
|
||||
shmat_addr_ = rhs.shmat_addr_;
|
||||
rhs.shm_id_ = -1;
|
||||
rhs.shm_key_ = -1;
|
||||
rhs.shmat_addr_ = nullptr;
|
||||
BaseIPC::operator=(std::move(rhs));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
/// \brief Set the public key
|
||||
void SetPublicKey(key_t public_key) { shm_key_ = public_key; }
|
||||
|
||||
/// \brief This returns where we attach to the shared memory.
|
||||
/// \return Base address of the shared memory.
|
||||
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
|
||||
void *SharedMemoryBaseAddr() { return shmat_addr_; }
|
||||
|
||||
/// \brief Attach to shared memory
|
||||
/// \return Status object
|
||||
Status Attach();
|
||||
|
||||
/// Detach from shared memory
|
||||
/// \return Status object
|
||||
Status Detach();
|
||||
|
||||
/// Create shared memory
|
||||
/// \return Status object
|
||||
Status Create(int64_t sz);
|
||||
|
||||
/// Destroy shared memory
|
||||
/// \return Status object
|
||||
Status Destroy();
|
||||
|
||||
/// \brief Return the shared memory id
|
||||
shm_id_t GetSharedMemoryId() const { return shm_id_; }
|
||||
|
||||
/// \brief Get number of processes attached to the shared memory
|
||||
/// \return Status object
|
||||
Status GetNumAttached(int32_t *num);
|
||||
|
||||
private:
|
||||
shm_id_t shm_id_;
|
||||
shm_key_t shm_key_;
|
||||
void *shmat_addr_;
|
||||
};
|
||||
|
||||
/// \brief Generate a shared memory key using the tcp/ip port.
|
||||
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
|
||||
/// \note Caller must check the return value. -1 means ftok failed.
|
||||
/// \param[in] port
|
||||
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
|
||||
/// \return key
|
||||
Status PortToFtok(int port, SharedMemory::shm_key_t *);
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_
|
|
@ -21,24 +21,136 @@
|
|||
#include <glog/logging.h>
|
||||
#endif
|
||||
#include <cstdlib>
|
||||
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include "minddata/dataset/engine/cache/cache_common.h"
|
||||
#include "minddata/dataset/engine/cache/cache_ipc.h"
|
||||
namespace ds = mindspore::dataset;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
ds::Status rc;
|
||||
ds::CacheServer::Builder builder;
|
||||
|
||||
// This executable is not to be called directly, and should be invoked by cache_admin executable.
|
||||
if (argc != 7) {
|
||||
rc = ds::Status(ds::StatusCode::kSyntaxError);
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return static_cast<int>(rc.get_code());
|
||||
/// Send a synchronous command to the local server using tcp/ip.
|
||||
/// We aren't using any client code because this binary is not necessarily linked with the client library.
|
||||
/// So just using grpc call directly.
|
||||
/// \param port tcp/ip port to use
|
||||
/// \param type Type of command.
|
||||
/// \param out grpc result
|
||||
/// \return Status object
|
||||
ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply,
|
||||
grpc::Status *out) {
|
||||
if (rq == nullptr) {
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null");
|
||||
}
|
||||
if (reply == nullptr) {
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null");
|
||||
}
|
||||
if (out == nullptr) {
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null");
|
||||
}
|
||||
const std::string hostname = "127.0.0.1";
|
||||
auto unix_socket = ds::PortToUnixSocketPath(port);
|
||||
#if CACHE_LOCAL_CLIENT
|
||||
const std::string target = "unix://" + unix_socket;
|
||||
#else
|
||||
const std::string target = hostname + ":" + std::to_string(port);
|
||||
#endif
|
||||
try {
|
||||
rq->set_type(static_cast<int16_t>(type));
|
||||
grpc::ChannelArguments args;
|
||||
grpc::ClientContext ctx;
|
||||
grpc::CompletionQueue cq;
|
||||
// Standard async rpc call
|
||||
std::shared_ptr<grpc::Channel> channel =
|
||||
grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
|
||||
std::unique_ptr<ds::CacheServerGreeter::Stub> stub = ds::CacheServerGreeter::NewStub(channel);
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReader<ds::CacheReply>> rpc =
|
||||
stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq);
|
||||
rpc->StartCall();
|
||||
// We need to pass a tag. But since this is the only request in the completion queue and so we
|
||||
// just pass a dummy
|
||||
int64_t dummy;
|
||||
void *tag;
|
||||
bool success;
|
||||
rpc->Finish(reply, out, &dummy);
|
||||
// Now we wait on the completion queue synchronously.
|
||||
auto r = cq.Next(&tag, &success);
|
||||
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
|
||||
if (!success || tag != &dummy) {
|
||||
std::string errMsg = "Unexpected programming error ";
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
if (out->ok()) {
|
||||
return ds::Status(static_cast<ds::StatusCode>(reply->rc()), reply->msg());
|
||||
} else {
|
||||
auto error_code = out->error_code();
|
||||
std::string errMsg = out->error_message() + ". GRPC Code " + std::to_string(error_code);
|
||||
return ds::Status(ds::StatusCode::kNetWorkError, errMsg);
|
||||
}
|
||||
} else {
|
||||
std::string errMsg = "Unexpected queue rc = " + std::to_string(r);
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
/// Stop the server
|
||||
/// \param argv
|
||||
/// \return Status object
|
||||
ds::Status StopServer(int argc, char **argv) {
|
||||
ds::Status rc;
|
||||
ds::CacheServer::Builder builder;
|
||||
std::string errMsg;
|
||||
if (argc != 2) {
|
||||
return ds::Status(ds::StatusCode::kSyntaxError);
|
||||
}
|
||||
int32_t port = strtol(argv[1], nullptr, 10);
|
||||
// We will go through the builder to do some snaity check. We only need the port number
|
||||
// to shut down the server. Null the root directory as we don't trigger the sanity code to write out anything
|
||||
// to the spill directory.
|
||||
builder.SetPort(port).SetRootDirectory("");
|
||||
// Part of the sanity check is check the shared memory. If the server is up and running, we expect
|
||||
// the return code is kDuplicate.
|
||||
rc = builder.SanityCheck();
|
||||
if (rc.IsOk()) {
|
||||
errMsg = "Server is not up or has been shutdown already.";
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, errMsg);
|
||||
} else if (rc.get_code() != ds::StatusCode::kDuplicateKey) {
|
||||
// Not OK, and no duplicate, just return the rc whatever it is.
|
||||
return rc;
|
||||
} else {
|
||||
// Now we get some work to do. We will send a tcp/ip request to the given port.
|
||||
// This binary is not linked with client side of code, so we will just call grpc directly.
|
||||
ds::CacheRequest rq;
|
||||
ds::CacheReply reply;
|
||||
grpc::Status grpc_rc;
|
||||
rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc);
|
||||
// The request is like a self destruct message, the server will not send anything back and
|
||||
// shutdown all incoming request. So we should expect some unexpected network error if
|
||||
// all goes well and we expect to GRPC code 14.
|
||||
auto err_code = grpc_rc.error_code();
|
||||
if (rc.get_code() != ds::StatusCode::kNetWorkError || err_code != grpc::StatusCode::UNAVAILABLE) {
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__);
|
||||
}
|
||||
}
|
||||
return ds::Status::OK();
|
||||
}
|
||||
|
||||
/// Start the server
|
||||
/// \param argv
|
||||
/// \return Status object
|
||||
ds::Status StartServer(int argc, char **argv) {
|
||||
ds::Status rc;
|
||||
ds::CacheServer::Builder builder;
|
||||
if (argc != 8) {
|
||||
return ds::Status(ds::StatusCode::kSyntaxError);
|
||||
}
|
||||
|
||||
int32_t port = strtol(argv[3], nullptr, 10);
|
||||
builder.SetRootDirectory(argv[1])
|
||||
.SetNumWorkers(strtol(argv[2], nullptr, 10))
|
||||
.SetPort(strtol(argv[3], nullptr, 10))
|
||||
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10));
|
||||
.SetPort(port)
|
||||
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10))
|
||||
.SetMemoryCapRatio(strtof(argv[7], nullptr));
|
||||
|
||||
#ifdef USE_GLOG
|
||||
FLAGS_minloglevel = strtol(argv[5], nullptr, 10);
|
||||
|
@ -52,36 +164,42 @@ int main(int argc, char **argv) {
|
|||
// is called. This is a standard procedure for daemonize a process on unix.
|
||||
if (chdir("/") == -1) {
|
||||
std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
|
||||
std::cerr << errMsg << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Simple check of the parameters before we move on.
|
||||
rc = builder.SanityCheck();
|
||||
if (rc.IsError()) {
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return static_cast<int>(rc.get_code());
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
|
||||
// A message queue for communication between parent and child (if we fork).
|
||||
ds::SharedMessage msg;
|
||||
if (daemonize) {
|
||||
#ifdef USE_GLOG
|
||||
FLAGS_log_dir = "/tmp";
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
#endif
|
||||
|
||||
if (daemonize) {
|
||||
// fork the child process to become the daemon
|
||||
rc = msg.Create();
|
||||
if (rc.IsError()) {
|
||||
return rc;
|
||||
}
|
||||
pid_t pid = fork();
|
||||
// failed to fork
|
||||
if (pid < 0) {
|
||||
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
|
||||
std::cerr << err_msg << std::endl;
|
||||
return errno;
|
||||
std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno);
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else if (pid > 0) {
|
||||
// Parent
|
||||
// Parent and will be responsible for remove the queue on exit.
|
||||
msg.RemoveResourcesOnExit();
|
||||
// Sleep one second and we attach to the msg que
|
||||
std::this_thread::sleep_for(std::chrono::seconds(1));
|
||||
ds::Status child_rc;
|
||||
rc = msg.ReceiveStatus(&child_rc);
|
||||
if (rc.IsError()) {
|
||||
return rc;
|
||||
}
|
||||
if (child_rc.IsError()) {
|
||||
return child_rc;
|
||||
}
|
||||
std::cerr << "cache server daemon process has been created as process id: " << pid
|
||||
<< "\nCheck log file for any start up error" << std::endl;
|
||||
signal(SIGCHLD, SIG_IGN); // ignore sig child signal.
|
||||
return 0;
|
||||
return ds::Status::OK();
|
||||
} else {
|
||||
// Child process will continue from here if daemonize and parent has already exited.
|
||||
// If we are running in the foreground, none of the code in block below will be run.
|
||||
|
@ -89,8 +207,8 @@ int main(int argc, char **argv) {
|
|||
umask(0);
|
||||
sid = setsid();
|
||||
if (sid < 0) {
|
||||
MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno);
|
||||
return errno;
|
||||
std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno);
|
||||
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
close(0);
|
||||
close(1);
|
||||
|
@ -100,22 +218,36 @@ int main(int argc, char **argv) {
|
|||
|
||||
// Dump the summary
|
||||
MS_LOG(INFO) << builder << std::endl;
|
||||
// Create the instance with some sanity checks built in
|
||||
rc = builder.Build();
|
||||
if (rc.IsOk()) {
|
||||
// If all goes well, kick off the threads. Loop forever and never return unless error.
|
||||
ds::CacheServer &cs = ds::CacheServer::GetInstance();
|
||||
// Kick off the threads. Loop forever and never return unless error.
|
||||
rc = cs.Run();
|
||||
if (rc.get_code() == ds::StatusCode::kDuplicateKey) {
|
||||
std::string errMsg = "Server is already started";
|
||||
MS_LOG(ERROR) << errMsg;
|
||||
std::cerr << errMsg << std::endl;
|
||||
return 0;
|
||||
rc = cs.Run(msg.GetMsgQueueId());
|
||||
} else if (daemonize) {
|
||||
// If we didn't pass the sanity check to at least create the instance, use
|
||||
// the message queue to return the error message if this is the child daemon.
|
||||
return msg.SendStatus(rc);
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
ds::Status rc;
|
||||
ds::CacheServer::Builder builder;
|
||||
|
||||
// This executable is not to be called directly, and should be invoked by cache_admin executable.
|
||||
if (strcmp(argv[0], "-") == 0) {
|
||||
rc = StopServer(argc, argv);
|
||||
} else {
|
||||
rc = StartServer(argc, argv);
|
||||
}
|
||||
// Check result
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << rc.ToString();
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return static_cast<int>(rc.get_code());
|
||||
auto errCode = rc.get_code();
|
||||
auto errMsg = rc.ToString();
|
||||
std::cerr << errMsg << std::endl;
|
||||
return static_cast<int>(errCode);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -250,5 +250,27 @@ Status GetStatRequest::PostReply() {
|
|||
stat_.cache_service_state = msg->state();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ListSessionsRequest::PostReply() {
|
||||
auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
|
||||
auto session_vector = msg->sessions();
|
||||
for (auto i = 0; i < session_vector->size(); ++i) {
|
||||
SessionCacheInfo current_info;
|
||||
CacheServiceStat stats;
|
||||
auto current_session_info = session_vector->Get(i);
|
||||
current_info.session_id = current_session_info->session_id();
|
||||
current_info.connection_id = current_session_info->connection_id();
|
||||
stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
|
||||
stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
|
||||
stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
|
||||
stats.min_row_id = current_session_info->stats()->min_row_id();
|
||||
stats.max_row_id = current_session_info->stats()->max_row_id();
|
||||
stats.cache_service_state = current_session_info->stats()->state();
|
||||
current_info.stats = stats; // fixed length struct. = operator is safe
|
||||
session_info_list_.push_back(current_info);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,6 +46,13 @@ struct CacheServiceStat {
|
|||
int8_t cache_service_state;
|
||||
};
|
||||
|
||||
/// \brief Info structure ListSessionsRequest
|
||||
struct SessionCacheInfo {
|
||||
session_id_type session_id;
|
||||
connection_id_type connection_id;
|
||||
CacheServiceStat stats;
|
||||
};
|
||||
|
||||
/// \brief CacheClient communicates with CacheServer using Requests.
|
||||
class BaseRequest {
|
||||
public:
|
||||
|
@ -54,7 +61,7 @@ class BaseRequest {
|
|||
kCacheRow = 0,
|
||||
kBatchFetchRows = 1,
|
||||
kCreateCache = 2,
|
||||
kPurgeCache = 3,
|
||||
kGetCacheMissKeys = 3,
|
||||
kDestroyCache = 4,
|
||||
kGetStat = 5,
|
||||
kCacheSchema = 6,
|
||||
|
@ -65,6 +72,9 @@ class BaseRequest {
|
|||
kAllocateSharedBlock = 11,
|
||||
kFreeSharedBlock = 12,
|
||||
kStopService = 13,
|
||||
kHeartBeat = 14,
|
||||
kToggleWriteMode = 15,
|
||||
kListSessions = 16,
|
||||
// Add new request before it.
|
||||
kRequestUnknown = 32767
|
||||
};
|
||||
|
@ -73,6 +83,7 @@ class BaseRequest {
|
|||
friend class CacheServerRequest;
|
||||
friend class CacheClientGreeter;
|
||||
friend class CacheClientRequestTag;
|
||||
friend class CacheClient;
|
||||
|
||||
/// \brief Base class of a cache server request
|
||||
/// \param type Type of the request
|
||||
|
@ -119,7 +130,7 @@ class FreeSharedBlockRequest : public BaseRequest {
|
|||
rq_.set_connection_id(connection_id);
|
||||
rq_.add_buf_data(std::to_string(addr));
|
||||
}
|
||||
~FreeSharedBlockRequest() = default;
|
||||
~FreeSharedBlockRequest() override = default;
|
||||
};
|
||||
|
||||
/// \brief Request to cache a single TensorRow
|
||||
|
@ -136,7 +147,7 @@ class CacheRowRequest : public BaseRequest {
|
|||
rq_.set_connection_id(connection_id);
|
||||
rq_.add_buf_data(cookie);
|
||||
}
|
||||
~CacheRowRequest() = default;
|
||||
~CacheRowRequest() override = default;
|
||||
|
||||
/// \brief Serialize a TensorRow for streaming to the cache server
|
||||
/// \param row TensorRow
|
||||
|
@ -183,7 +194,7 @@ class BatchFetchRequest : public BaseRequest {
|
|||
friend class CacheServer;
|
||||
friend class CacheService;
|
||||
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass);
|
||||
~BatchFetchRequest() = default;
|
||||
~BatchFetchRequest() override = default;
|
||||
Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);
|
||||
|
||||
private:
|
||||
|
@ -203,7 +214,7 @@ class CreateCacheRequest : public BaseRequest {
|
|||
/// \param flag Attributes of the cache.
|
||||
explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
|
||||
CreateCacheFlag flag = CreateCacheFlag::kNone);
|
||||
~CreateCacheRequest() = default;
|
||||
~CreateCacheRequest() override = default;
|
||||
void ParseResult(connection_id_type *id, std::string *out) {
|
||||
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
|
||||
*id = p->connection_id();
|
||||
|
@ -218,14 +229,15 @@ class CreateCacheRequest : public BaseRequest {
|
|||
CreateCacheFlag flag_;
|
||||
};
|
||||
|
||||
/// \brief Request to purge a cache.
|
||||
class PurgeCacheRequest : public BaseRequest {
|
||||
/// \brief Request to get all the keys not present at the server.
|
||||
/// \note Only applicable to mappable case
|
||||
class GetCacheMissKeysRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) {
|
||||
explicit GetCacheMissKeysRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetCacheMissKeys) {
|
||||
rq_.set_connection_id(connection_id);
|
||||
}
|
||||
~PurgeCacheRequest() = default;
|
||||
~GetCacheMissKeysRequest() override = default;
|
||||
};
|
||||
|
||||
/// \brief Request to destroy a cache
|
||||
|
@ -235,7 +247,7 @@ class DestroyCacheRequest : public BaseRequest {
|
|||
explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) {
|
||||
rq_.set_connection_id(connection_id);
|
||||
}
|
||||
~DestroyCacheRequest() = default;
|
||||
~DestroyCacheRequest() override = default;
|
||||
};
|
||||
|
||||
/// \brief Obtain the statistics of the current connection
|
||||
|
@ -247,7 +259,7 @@ class GetStatRequest : public BaseRequest {
|
|||
rq_.set_connection_id(connection_id);
|
||||
}
|
||||
|
||||
~GetStatRequest() = default;
|
||||
~GetStatRequest() override = default;
|
||||
|
||||
/// \brief Override base function to process the result.
|
||||
Status PostReply() override;
|
||||
|
@ -269,7 +281,7 @@ class CacheSchemaRequest : public BaseRequest {
|
|||
explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) {
|
||||
rq_.set_connection_id(connection_id);
|
||||
}
|
||||
~CacheSchemaRequest() = default;
|
||||
~CacheSchemaRequest() override = default;
|
||||
|
||||
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
|
||||
};
|
||||
|
@ -281,7 +293,7 @@ class FetchSchemaRequest : public BaseRequest {
|
|||
explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) {
|
||||
rq_.set_connection_id(connection_id);
|
||||
}
|
||||
~FetchSchemaRequest() = default;
|
||||
~FetchSchemaRequest() override = default;
|
||||
|
||||
Status PostReply() override;
|
||||
|
||||
|
@ -300,7 +312,7 @@ class BuildPhaseDoneRequest : public BaseRequest {
|
|||
rq_.set_connection_id(connection_id);
|
||||
rq_.add_buf_data(cookie_);
|
||||
}
|
||||
~BuildPhaseDoneRequest() = default;
|
||||
~BuildPhaseDoneRequest() override = default;
|
||||
|
||||
private:
|
||||
std::string cookie_;
|
||||
|
@ -313,7 +325,7 @@ class DropSessionRequest : public BaseRequest {
|
|||
explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) {
|
||||
rq_.mutable_connection_info()->operator=(cinfo);
|
||||
}
|
||||
~DropSessionRequest() = default;
|
||||
~DropSessionRequest() override = default;
|
||||
};
|
||||
|
||||
class GenerateSessionIdRequest : public BaseRequest {
|
||||
|
@ -325,11 +337,36 @@ class GenerateSessionIdRequest : public BaseRequest {
|
|||
rq_.set_connection_id(0);
|
||||
}
|
||||
|
||||
~GenerateSessionIdRequest() = default;
|
||||
~GenerateSessionIdRequest() override = default;
|
||||
|
||||
session_id_type GetSessionId() { return atoi(reply_.result().data()); }
|
||||
};
|
||||
|
||||
class ListSessionsRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
ListSessionsRequest() : BaseRequest(RequestType::kListSessions) {
|
||||
// This request is not specific to any cache or session
|
||||
rq_.set_connection_id(0);
|
||||
}
|
||||
|
||||
~ListSessionsRequest() override = default;
|
||||
|
||||
/// \brief Override base function to process the result.
|
||||
Status PostReply() override;
|
||||
|
||||
void GetSessionCacheInfo(std::vector<SessionCacheInfo> *info) {
|
||||
if (info != nullptr) {
|
||||
(*info) = session_info_list_;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<SessionCacheInfo> GetSessionCacheInfo() { return session_info_list_; }
|
||||
|
||||
private:
|
||||
std::vector<SessionCacheInfo> session_info_list_;
|
||||
};
|
||||
|
||||
class AllocateSharedBlockRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
|
@ -338,7 +375,7 @@ class AllocateSharedBlockRequest : public BaseRequest {
|
|||
rq_.set_connection_id(connection_id);
|
||||
rq_.add_buf_data(std::to_string(requestedSz));
|
||||
}
|
||||
~AllocateSharedBlockRequest() = default;
|
||||
~AllocateSharedBlockRequest() override = default;
|
||||
|
||||
/// \brief On return from the server, we get the (relative) address where
|
||||
/// the free block is located.
|
||||
|
@ -349,11 +386,15 @@ class AllocateSharedBlockRequest : public BaseRequest {
|
|||
}
|
||||
};
|
||||
|
||||
class ShutdownRequest : public BaseRequest {
|
||||
class ToggleWriteModeRequest : public BaseRequest {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
ShutdownRequest() : BaseRequest(RequestType::kStopService) {}
|
||||
~ShutdownRequest() = default;
|
||||
explicit ToggleWriteModeRequest(connection_id_type connection_id, bool on_off)
|
||||
: BaseRequest(RequestType::kToggleWriteMode) {
|
||||
rq_.set_connection_id(connection_id);
|
||||
rq_.add_buf_data(on_off ? "on" : "off");
|
||||
}
|
||||
~ToggleWriteModeRequest() override = default;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <functional>
|
||||
#include <limits>
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/engine/cache/cache_ipc.h"
|
||||
#include "minddata/dataset/engine/cache/cache_service.h"
|
||||
#include "minddata/dataset/engine/cache/cache_request.h"
|
||||
#include "minddata/dataset/util/bit.h"
|
||||
|
@ -107,6 +108,8 @@ Status CacheServer::DoServiceStop() {
|
|||
// First stop all the threads.
|
||||
RETURN_IF_NOT_OK(vg_.ServiceStop());
|
||||
// Clean up all the caches if any.
|
||||
// Dump a message how much memory we have consumed in total.
|
||||
MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes.";
|
||||
UniqueLock lck(&rwLock_);
|
||||
auto it = all_caches_.begin();
|
||||
while (it != all_caches_.end()) {
|
||||
|
@ -121,7 +124,6 @@ Status CacheServer::DoServiceStop() {
|
|||
}
|
||||
|
||||
CacheService *CacheServer::GetService(connection_id_type id) const {
|
||||
SharedLock lck(&rwLock_);
|
||||
auto it = all_caches_.find(id);
|
||||
if (it != all_caches_.end()) {
|
||||
return it->second.get();
|
||||
|
@ -134,6 +136,16 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
|
|||
std::string cookie;
|
||||
auto session_id = rq->connection_info().session_id();
|
||||
auto crc = rq->connection_info().crc();
|
||||
|
||||
// Before allowing the creation, make sure the session had already been created by the user
|
||||
// Our intention is to add this cache to the active sessions list so leave the list locked during
|
||||
// this entire function.
|
||||
UniqueLock lock(&sessions_lock_);
|
||||
auto session_it = active_sessions_.find(session_id);
|
||||
if (session_it == active_sessions_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!");
|
||||
}
|
||||
|
||||
// We concat both numbers to form the internal connection id.
|
||||
auto connection_id = GetConnectionID(session_id, crc);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache");
|
||||
|
@ -172,10 +184,15 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
|
|||
} catch (const std::bad_alloc &e) {
|
||||
return Status(StatusCode::kOutOfMemory);
|
||||
}
|
||||
// Add the cache into the active session tracking.
|
||||
// We have already validated that the session exists and that this is a new cache created.
|
||||
session_it->second.insert(connection_id);
|
||||
|
||||
} else {
|
||||
duplicate = true;
|
||||
MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
|
||||
}
|
||||
|
||||
off_cookie = fbb.CreateString(cookie);
|
||||
CreateCacheReplyMsgBuilder bld(fbb);
|
||||
bld.add_connection_id(connection_id);
|
||||
|
@ -183,19 +200,18 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
|
|||
auto off = bld.Finish();
|
||||
fbb.Finish(off);
|
||||
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
|
||||
// Track the history of all the sessions that we have created so far.
|
||||
history_sessions_.insert(session_id);
|
||||
// We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
|
||||
// treat it as OK.
|
||||
return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) {
|
||||
Status CacheServer::DestroyCache(CacheRequest *rq) {
|
||||
// We need a strong lock to protect the map.
|
||||
UniqueLock lck(&rwLock_);
|
||||
auto id = rq->connection_id();
|
||||
CacheService *cs = GetService(id);
|
||||
// it is already destroyed. Ignore it.
|
||||
if (cs != nullptr) {
|
||||
auto id = rq->connection_id();
|
||||
MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
|
||||
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
|
||||
auto n = all_caches_.erase(id);
|
||||
|
@ -204,11 +220,34 @@ Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) {
|
|||
MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
|
||||
}
|
||||
}
|
||||
|
||||
// Now that this cache is removed, we need to also remove it's connection id from active session tracking
|
||||
auto session_id = GetSessionID(id);
|
||||
UniqueLock sess_lck(&sessions_lock_);
|
||||
|
||||
auto it = active_sessions_.find(session_id);
|
||||
if (it == active_sessions_.end()) {
|
||||
// The session was not found in the active sessions
|
||||
RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!");
|
||||
}
|
||||
|
||||
auto connection_it = it->second.find(id);
|
||||
if (connection_it == it->second.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!");
|
||||
}
|
||||
|
||||
// remove that connection id from the set
|
||||
it->second.erase(connection_it);
|
||||
MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
|
||||
Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
|
@ -236,8 +275,11 @@ inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
|
||||
Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
auto shared_pool = comm_layer_->GetSharedMemoryPool();
|
||||
auto *base = shared_pool->SharedMemoryBaseAddr();
|
||||
// Ensure we got 3 pieces of data coming in
|
||||
|
@ -270,8 +312,11 @@ Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply
|
|||
return rc;
|
||||
}
|
||||
|
||||
Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
|
||||
Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
|
@ -325,8 +370,11 @@ Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheRepl
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
|
||||
Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
|
@ -338,8 +386,8 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
|
|||
bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
|
||||
bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
|
||||
bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
|
||||
bld.add_max_row_id(svc_stat.max_);
|
||||
bld.add_min_row_id(svc_stat.min_);
|
||||
bld.add_max_row_id(svc_stat.stat_.max_key);
|
||||
bld.add_min_row_id(svc_stat.stat_.min_key);
|
||||
bld.add_state(svc_stat.state_);
|
||||
auto offset = bld.Finish();
|
||||
fbb.Finish(offset);
|
||||
|
@ -348,8 +396,11 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status CacheSchema(CacheService *cs, CacheRequest *rq) {
|
||||
Status CacheServer::CacheSchema(CacheRequest *rq) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
|
@ -361,8 +412,11 @@ inline Status CacheSchema(CacheService *cs, CacheRequest *rq) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
|
||||
Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
|
@ -377,8 +431,11 @@ inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply)
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) {
|
||||
Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
|
@ -396,15 +453,24 @@ inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::PurgeCache(CacheService *cs) {
|
||||
Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
// If shutdown in progress, ignore the command.
|
||||
if (global_shutdown_) {
|
||||
return Status::OK();
|
||||
}
|
||||
// it is already purged. Ignore it.
|
||||
if (cs != nullptr) {
|
||||
RETURN_IF_NOT_OK(cs->Purge());
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else {
|
||||
std::vector<row_id_type> gap;
|
||||
RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto off_t = fbb.CreateVector(gap);
|
||||
TensorRowIdsBuilder bld(fbb);
|
||||
bld.add_row_id(off_t);
|
||||
auto off = bld.Finish();
|
||||
fbb.Finish(off);
|
||||
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -414,6 +480,72 @@ inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *re
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
|
||||
auto connection_id = rq->connection_id();
|
||||
// Hold the shared lock to prevent the cache from being dropped.
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else {
|
||||
// First piece of data is the on/off flag
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");
|
||||
const auto &action = rq->buf_data(0);
|
||||
bool on_off = false;
|
||||
if (strcmp(action.data(), "on") == 0) {
|
||||
on_off = true;
|
||||
} else if (strcmp(action.data(), "off") == 0) {
|
||||
on_off = false;
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Unknown request: " + action);
|
||||
}
|
||||
RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::ListSessions(CacheReply *reply) {
|
||||
SharedLock lck(&sessions_lock_);
|
||||
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector;
|
||||
for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) {
|
||||
session_id_type current_session_id = it->first;
|
||||
// Loop over each cache inside this session
|
||||
if (!it->second.empty()) {
|
||||
for (auto current_conn_id : it->second) {
|
||||
CacheService *cs = GetService(current_conn_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions.";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else {
|
||||
CacheService::ServiceStat svc_stat;
|
||||
RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
|
||||
auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached,
|
||||
svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key,
|
||||
svc_stat.stat_.max_key, svc_stat.state_);
|
||||
auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats);
|
||||
session_msgs_vector.push_back(current_session_info);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If there is no cache created yet, assign a connection id of 0 along with empty stats
|
||||
auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0);
|
||||
auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats);
|
||||
session_msgs_vector.push_back(current_session_info);
|
||||
}
|
||||
}
|
||||
auto session_msgs = fbb.CreateVector(session_msgs_vector);
|
||||
ListSessionsMsgBuilder s_builder(fbb);
|
||||
s_builder.add_sessions(session_msgs);
|
||||
auto offset = s_builder.Finish();
|
||||
fbb.Finish(offset);
|
||||
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief This is the main loop the cache server thread(s) are running.
|
||||
/// Each thread will pop a request and send the result back to the client using grpc
|
||||
/// \return
|
||||
|
@ -426,12 +558,6 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
|
|||
RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
|
||||
auto &rq = cache_req->rq_;
|
||||
auto &reply = cache_req->reply_;
|
||||
CacheService *cs = nullptr;
|
||||
// Request comes in roughly two sets. One set is at the cache level with a connection id.
|
||||
// The other set is working at a high level and without a connection id
|
||||
if (!rq.has_connection_info()) {
|
||||
cs = GetService(rq.connection_id());
|
||||
}
|
||||
// Except for creating a new session, we expect cs is not null.
|
||||
switch (cache_req->type_) {
|
||||
case BaseRequest::RequestType::kCacheRow: {
|
||||
|
@ -439,42 +565,42 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
|
|||
// call the appropriate method.
|
||||
auto flag = rq.flag();
|
||||
if (BitTest(flag, kDataIsInSharedMemory)) {
|
||||
cache_req->rc_ = FastCacheRow(cs, &rq, &reply);
|
||||
cache_req->rc_ = FastCacheRow(&rq, &reply);
|
||||
} else {
|
||||
cache_req->rc_ = CacheRow(cs, &rq, &reply);
|
||||
cache_req->rc_ = CacheRow(&rq, &reply);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kBatchFetchRows: {
|
||||
cache_req->rc_ = BatchFetchRows(cs, &rq, &reply);
|
||||
cache_req->rc_ = BatchFetchRows(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kCreateCache: {
|
||||
cache_req->rc_ = CreateService(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kPurgeCache: {
|
||||
cache_req->rc_ = PurgeCache(cs);
|
||||
case BaseRequest::RequestType::kGetCacheMissKeys: {
|
||||
cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kDestroyCache: {
|
||||
cache_req->rc_ = DestroyCache(cs, &rq);
|
||||
cache_req->rc_ = DestroyCache(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kGetStat: {
|
||||
cache_req->rc_ = GetStat(cs, &rq, &reply);
|
||||
cache_req->rc_ = GetStat(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kCacheSchema: {
|
||||
cache_req->rc_ = CacheSchema(cs, &rq);
|
||||
cache_req->rc_ = CacheSchema(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kFetchSchema: {
|
||||
cache_req->rc_ = FetchSchema(cs, &rq, &reply);
|
||||
cache_req->rc_ = FetchSchema(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kBuildPhaseDone: {
|
||||
cache_req->rc_ = BuildPhaseDone(cs, &rq);
|
||||
cache_req->rc_ = BuildPhaseDone(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kDropSession: {
|
||||
|
@ -498,6 +624,18 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
|
|||
cache_req->rc_ = GlobalShutdown();
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kHeartBeat: {
|
||||
cache_req->rc_ = Status::OK();
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kToggleWriteMode: {
|
||||
cache_req->rc_ = ToggleWriteMode(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kListSessions: {
|
||||
cache_req->rc_ = ListSessions(&reply);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
std::string errMsg("Unknown request type : ");
|
||||
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
|
||||
|
@ -526,18 +664,32 @@ session_id_type CacheServer::GetSessionID(connection_id_type connection_id) cons
|
|||
}
|
||||
|
||||
CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
|
||||
int32_t shared_meory_sz_in_gb)
|
||||
int32_t shared_meory_sz_in_gb, float memory_cap_ratio)
|
||||
: top_(spill_path),
|
||||
num_workers_(num_workers),
|
||||
port_(port),
|
||||
shared_memory_sz_in_gb_(shared_meory_sz_in_gb),
|
||||
global_shutdown_(false) {}
|
||||
global_shutdown_(false),
|
||||
memory_cap_ratio_(memory_cap_ratio),
|
||||
cur_mem_usage_(0) {
|
||||
memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_;
|
||||
}
|
||||
|
||||
Status CacheServer::Run() {
|
||||
RETURN_IF_NOT_OK(ServiceStart());
|
||||
Status CacheServer::Run(int msg_qid) {
|
||||
Status rc = ServiceStart();
|
||||
// If there is a message que, return the status now before we call join_all which will never return
|
||||
if (msg_qid != -1) {
|
||||
SharedMessage msg(msg_qid);
|
||||
RETURN_IF_NOT_OK(msg.SendStatus(rc));
|
||||
}
|
||||
if (rc.IsError()) {
|
||||
return rc;
|
||||
}
|
||||
// This is called by the main function and we shouldn't exit. Otherwise the main thread
|
||||
// will just shutdown. So we will call some function that never return unless error.
|
||||
// One good case will be simply to wait for all threads to return.
|
||||
// note that after we have sent the initial status using the msg_qid, parent process will exit and
|
||||
// remove it. So we can't use it again.
|
||||
RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -567,32 +719,51 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
|
|||
Status CacheServer::DestroySession(CacheRequest *rq) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
|
||||
auto drop_session_id = rq->connection_info().session_id();
|
||||
UniqueLock lck(&rwLock_);
|
||||
for (auto &cs : all_caches_) {
|
||||
auto connection_id = cs.first;
|
||||
auto session_id = GetSessionID(connection_id);
|
||||
// We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock.
|
||||
// So we will just manually do it.
|
||||
if (session_id == drop_session_id) {
|
||||
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
|
||||
auto n = all_caches_.erase(connection_id);
|
||||
MS_LOG(INFO) << "Destroy " << n << " copies of cache with id " << connection_id;
|
||||
|
||||
UniqueLock lck(&sessions_lock_);
|
||||
|
||||
// First validate that this session exists
|
||||
auto it = active_sessions_.find(drop_session_id);
|
||||
if (it == active_sessions_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!");
|
||||
}
|
||||
|
||||
// Iterate over the set of connection id's for this session that we're dropping and erase each one.
|
||||
{
|
||||
UniqueLock rwlck(&rwLock_);
|
||||
for (auto drop_connection_id : it->second) {
|
||||
auto cache_drop_it = all_caches_.find(drop_connection_id);
|
||||
if (cache_drop_it == all_caches_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry.");
|
||||
}
|
||||
all_caches_.erase(cache_drop_it);
|
||||
MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id;
|
||||
// **Do not bother to remove the cache connection id from the active session because we will soon remove the
|
||||
// entire session.
|
||||
}
|
||||
}
|
||||
|
||||
// Finally remove the session itself
|
||||
active_sessions_.erase(it);
|
||||
MS_LOG(INFO) << "Session destroyed with id " << drop_session_id;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
session_id_type CacheServer::GenerateSessionID() const {
|
||||
SharedLock lock(&rwLock_);
|
||||
session_id_type CacheServer::GenerateSessionID() {
|
||||
UniqueLock lock(&sessions_lock_);
|
||||
auto mt = GetRandomDevice();
|
||||
std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
|
||||
session_id_type session_id;
|
||||
bool duplicate = false;
|
||||
do {
|
||||
session_id = distribution(mt);
|
||||
auto it = history_sessions_.find(session_id);
|
||||
duplicate = (it != history_sessions_.end());
|
||||
auto it = active_sessions_.find(session_id);
|
||||
duplicate = (it != active_sessions_.end());
|
||||
} while (duplicate);
|
||||
|
||||
// Add this session to our tracking of active sessions with initialized empty set of connections.
|
||||
active_sessions_[session_id] = std::set<connection_id_type>();
|
||||
return session_id;
|
||||
}
|
||||
|
||||
|
@ -637,19 +808,59 @@ Status CacheServer::GlobalShutdown() {
|
|||
vg_.interrupt_all();
|
||||
// The next thing to do drop all the caches.
|
||||
UniqueLock lck(&rwLock_);
|
||||
for (auto &it : all_caches_) {
|
||||
auto id = it.first;
|
||||
for (auto it = all_caches_.begin(); it != all_caches_.end();) {
|
||||
auto id = it->first;
|
||||
MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
|
||||
// Wait for all outstanding work to be finished.
|
||||
auto &cs = it.second;
|
||||
auto &cs = it->second;
|
||||
UniqueLock cs_lock(&cs->rw_lock_);
|
||||
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
|
||||
(void)all_caches_.erase(id);
|
||||
it = all_caches_.erase(it);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t CacheServer::GetTotalSystemMemory() {
|
||||
auto pages = sysconf(_SC_PHYS_PAGES);
|
||||
auto page_size = sysconf(_SC_PAGE_SIZE);
|
||||
auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size);
|
||||
MS_LOG(INFO) << "Total physical RAM in bytes: " << total;
|
||||
return total;
|
||||
}
|
||||
|
||||
Status CacheServer::Builder::IpcResourceCleanup() {
|
||||
Status rc;
|
||||
SharedMemory::shm_key_t shm_key;
|
||||
auto unix_socket = PortToUnixSocketPath(port_);
|
||||
rc = PortToFtok(port_, &shm_key);
|
||||
// We are expecting the unix path doesn't exist.
|
||||
if (rc.IsError()) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Attach to the shared memory which we expect don't exist
|
||||
SharedMemory mem(shm_key);
|
||||
rc = mem.Attach();
|
||||
if (rc.IsError()) {
|
||||
return Status::OK();
|
||||
}
|
||||
int32_t num_attached;
|
||||
RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached));
|
||||
if (num_attached == 0) {
|
||||
// Stale shared memory from last time.
|
||||
// Remove both the memory and the socket path
|
||||
RETURN_IF_NOT_OK(mem.Destroy());
|
||||
Path p(unix_socket);
|
||||
(void)p.Remove();
|
||||
} else {
|
||||
// Server is already up.
|
||||
std::string errMsg = "Cache server is already up and running";
|
||||
// We return a duplicate error. The main() will intercept
|
||||
// and output a proper message
|
||||
return Status(StatusCode::kDuplicateKey, errMsg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::Builder::SanityCheck() {
|
||||
if (shared_memory_sz_in_gb_ <= 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive");
|
||||
|
@ -673,6 +884,12 @@ Status CacheServer::Builder::SanityCheck() {
|
|||
RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString());
|
||||
}
|
||||
}
|
||||
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) {
|
||||
RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1");
|
||||
}
|
||||
|
||||
// Check if the shared memory.
|
||||
RETURN_IF_NOT_OK(IpcResourceCleanup());
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_
|
||||
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
|
@ -47,15 +49,16 @@ class CacheServer : public Service {
|
|||
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
|
||||
class Builder {
|
||||
public:
|
||||
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {}
|
||||
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {}
|
||||
|
||||
~Builder() = default;
|
||||
|
||||
/// \brief Getter functions
|
||||
const std::string &getTop() const { return top_; }
|
||||
int32_t getNumWorkers() const { return num_workers_; }
|
||||
int32_t getPort() const { return port_; }
|
||||
int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }
|
||||
const std::string &GetTop() const { return top_; }
|
||||
int32_t GetNumWorkers() const { return num_workers_; }
|
||||
int32_t GetPort() const { return port_; }
|
||||
int32_t GetSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }
|
||||
float GetMemoryCapRatio() const { return memory_cap_ratio_; }
|
||||
|
||||
Builder &SetRootDirectory(std::string root) {
|
||||
top_ = std::move(root);
|
||||
|
@ -73,15 +76,20 @@ class CacheServer : public Service {
|
|||
shared_memory_sz_in_gb_ = sz;
|
||||
return *this;
|
||||
}
|
||||
Builder &SetMemoryCapRatio(float ratio) {
|
||||
memory_cap_ratio_ = ratio;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Status SanityCheck();
|
||||
|
||||
void Print(std::ostream &out) const {
|
||||
out << "Summary of the cache server configuration\n"
|
||||
<< "Spill directory: " << getTop() << "\n"
|
||||
<< "Number of parallel workers: " << getNumWorkers() << "\n"
|
||||
<< "Tcp/ip port: " << getPort() << "\n"
|
||||
<< "Shared memory size (in GB): " << getSharedMemorySzInGb();
|
||||
<< "Spill directory: " << GetTop() << "\n"
|
||||
<< "Number of parallel workers: " << GetNumWorkers() << "\n"
|
||||
<< "Tcp/ip port: " << GetPort() << "\n"
|
||||
<< "Shared memory size (in GB): " << GetSharedMemorySzInGb() << "\n"
|
||||
<< "Memory cap ratio: " << GetMemoryCapRatio();
|
||||
}
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &out, const Builder &bld) {
|
||||
|
@ -93,7 +101,8 @@ class CacheServer : public Service {
|
|||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
// We need to bring up the Task Manager by bringing up the Services singleton.
|
||||
RETURN_IF_NOT_OK(Services::CreateInstance());
|
||||
RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_));
|
||||
RETURN_IF_NOT_OK(
|
||||
CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_, memory_cap_ratio_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -102,20 +111,27 @@ class CacheServer : public Service {
|
|||
int32_t num_workers_;
|
||||
int32_t port_;
|
||||
int32_t shared_memory_sz_in_gb_;
|
||||
float memory_cap_ratio_;
|
||||
|
||||
/// \brief Sanity checks on the shared memory.
|
||||
/// \return Status object
|
||||
Status IpcResourceCleanup();
|
||||
};
|
||||
|
||||
CacheServer(const CacheServer &) = delete;
|
||||
CacheServer &operator=(const CacheServer &) = delete;
|
||||
CacheServer(CacheServer &&) = delete;
|
||||
CacheServer &operator=(CacheServer &) = delete;
|
||||
Status DoServiceStart() override;
|
||||
Status DoServiceStop() override;
|
||||
~CacheServer() { (void)ServiceStop(); }
|
||||
~CacheServer() override { (void)ServiceStop(); }
|
||||
|
||||
static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port,
|
||||
int32_t shared_memory_sz) {
|
||||
int32_t shared_memory_sz, float memory_cap_ratio) {
|
||||
std::call_once(init_instance_flag_, [&]() -> Status {
|
||||
auto &svcManager = Services::GetInstance();
|
||||
RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz));
|
||||
auto &SvcManager = Services::GetInstance();
|
||||
RETURN_IF_NOT_OK(
|
||||
SvcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz, memory_cap_ratio));
|
||||
return Status::OK();
|
||||
});
|
||||
return Status::OK();
|
||||
|
@ -133,7 +149,7 @@ class CacheServer : public Service {
|
|||
}
|
||||
|
||||
/// \\brief Kick off server threads. Never return unless error out.
|
||||
Status Run();
|
||||
Status Run(SharedMessage::queue_id_t msg_qid);
|
||||
|
||||
/// \brief Get a free tag
|
||||
/// \param q[in] pointer to a pointer to a CacheServerRequest
|
||||
|
@ -145,13 +161,35 @@ class CacheServer : public Service {
|
|||
/// \return Status object
|
||||
static Status ReturnRequestTag(CacheServerRequest *p);
|
||||
|
||||
/// \brief This returns the size (in bytes) of the physical RAM on the machine.
|
||||
/// \return the size (in bytes) of the physical RAM on the machine.
|
||||
static int64_t GetTotalSystemMemory();
|
||||
|
||||
/// \brief Internally this is how much we will try to use without exceeding the limit
|
||||
/// \return Internal cap maximum
|
||||
int64_t GetAvailableSystemMemory() { return memory_cap_; }
|
||||
|
||||
/// \brief Find out the current memory usage
|
||||
int64_t GetMemoryUsage() { return cur_mem_usage_; }
|
||||
|
||||
/// \brief This updates our current memory usage.
|
||||
enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 };
|
||||
void UpdateMemoryUsage(int64_t sz, MemUsageOp op) {
|
||||
if (op == MemUsageOp::kAllocate) {
|
||||
cur_mem_usage_ += sz;
|
||||
} else {
|
||||
cur_mem_usage_ -= sz;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static std::once_flag init_instance_flag_;
|
||||
static CacheServer *instance_;
|
||||
mutable RWLock rwLock_;
|
||||
mutable RWLock sessions_lock_;
|
||||
std::string top_;
|
||||
cache_index all_caches_;
|
||||
std::set<session_id_type> history_sessions_;
|
||||
std::map<session_id_type, std::set<connection_id_type>> active_sessions_;
|
||||
std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
|
||||
std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
|
||||
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
|
||||
|
@ -162,11 +200,15 @@ class CacheServer : public Service {
|
|||
int32_t port_;
|
||||
int32_t shared_memory_sz_in_gb_;
|
||||
std::atomic<bool> global_shutdown_;
|
||||
float memory_cap_ratio_;
|
||||
int64_t memory_cap_;
|
||||
std::atomic<int64_t> cur_mem_usage_;
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param spill_path Top directory for spilling buffers to.
|
||||
/// \param num_workers Number of threads for handling requests.
|
||||
explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb);
|
||||
explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb,
|
||||
float memory_cap_ratio);
|
||||
|
||||
/// \brief Locate a cache service from connection id.
|
||||
/// \return Pointer to cache service. Null if not found
|
||||
|
@ -179,11 +221,9 @@ class CacheServer : public Service {
|
|||
Status CreateService(CacheRequest *rq, CacheReply *reply);
|
||||
|
||||
/// \brief Destroy a cache service
|
||||
/// \param cs
|
||||
/// \param rq
|
||||
/// \return
|
||||
Status DestroyCache(CacheService *cs, CacheRequest *rq);
|
||||
Status PurgeCache(CacheService *cs);
|
||||
/// \return Status object
|
||||
Status DestroyCache(CacheRequest *rq);
|
||||
|
||||
/// \brief Entry point for all internal server threads.
|
||||
Status ServerRequest(int32_t worker_id);
|
||||
|
@ -207,7 +247,7 @@ class CacheServer : public Service {
|
|||
|
||||
/// \brief Generate a session ID for the client
|
||||
/// \return Session ID
|
||||
session_id_type GenerateSessionID() const;
|
||||
session_id_type GenerateSessionID();
|
||||
|
||||
/// \brief Handle kAllocateSharedBlock request
|
||||
/// \param rq CacheRequest
|
||||
|
@ -220,20 +260,55 @@ class CacheServer : public Service {
|
|||
/// \return Status object
|
||||
Status FreeSharedMemory(CacheRequest *rq);
|
||||
|
||||
/// \brief Handle kFastCacheRow request
|
||||
/// \brief Handle CacheRow request
|
||||
/// \note There are two different implementation depends if shared memory is used for transportation.
|
||||
/// \return Status object
|
||||
Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply);
|
||||
Status FastCacheRow(CacheRequest *rq, CacheReply *reply);
|
||||
Status CacheRow(CacheRequest *rq, CacheReply *reply);
|
||||
|
||||
/// \brief Internal function to do row batch fetch
|
||||
/// \param cs CacheService
|
||||
/// \param rq Request
|
||||
/// \param reply Reply
|
||||
/// \return
|
||||
Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply);
|
||||
/// \return Status object
|
||||
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
|
||||
|
||||
/// \brief Internal function to get statistics
|
||||
/// \param rq
|
||||
/// \param reply
|
||||
/// \return Status object
|
||||
Status GetStat(CacheRequest *rq, CacheReply *reply);
|
||||
|
||||
/// \brief Cache a schema request
|
||||
/// \param rq
|
||||
/// \return Status object
|
||||
Status CacheSchema(CacheRequest *rq);
|
||||
|
||||
/// \brief Fetch a schema request
|
||||
/// \param rq
|
||||
/// \param reply
|
||||
/// \return Status object
|
||||
Status FetchSchema(CacheRequest *rq, CacheReply *reply);
|
||||
|
||||
/// \brief Mark Build phase done (for non-mappable case)
|
||||
/// \param rq
|
||||
/// \return Status object
|
||||
Status BuildPhaseDone(CacheRequest *rq);
|
||||
|
||||
/// \brief A proper shutdown of the server
|
||||
/// \return Status object
|
||||
Status GlobalShutdown();
|
||||
|
||||
/// \brief Find keys that will be cache miss
|
||||
/// \return Status object
|
||||
Status GetCacheMissKeys(CacheRequest *rq, CacheReply *reply);
|
||||
|
||||
/// \brief Toggle write mode for a service
|
||||
Status ToggleWriteMode(CacheRequest *rq);
|
||||
|
||||
/// \brief List the sessions and their caches
|
||||
/// \param reply
|
||||
/// \return Status object
|
||||
Status ListSessions(CacheReply *reply);
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/cache/cache_service.h"
|
||||
#include "minddata/dataset/engine/cache/cache_server.h"
|
||||
#include "minddata/dataset/util/slice.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -22,42 +23,62 @@ CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool genera
|
|||
: root_(root),
|
||||
cache_mem_sz_(mem_sz),
|
||||
cp_(nullptr),
|
||||
map_(nullptr),
|
||||
next_id_(0),
|
||||
generate_id_(generate_id),
|
||||
schema_key_(-1),
|
||||
st_(generate_id ? State::kBuildPhase : State::kNone) {}
|
||||
st_(generate_id ? State::kBuildPhase : State::kNone),
|
||||
cur_mem_usage_(0),
|
||||
cur_disk_usage_(0) {}
|
||||
|
||||
CacheService::~CacheService() { (void)ServiceStop(); }
|
||||
|
||||
bool CacheService::UseArena() {
|
||||
// If fixed size, use Arena instead of the pool from global context.
|
||||
return (cache_mem_sz_ > 0);
|
||||
}
|
||||
|
||||
Status CacheService::DoServiceStart() {
|
||||
std::shared_ptr<MemoryPool> mp_;
|
||||
CacheServer &cs = CacheServer::GetInstance();
|
||||
if (UseArena()) {
|
||||
auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L;
|
||||
if (cache_mem_sz_ > avail_mem) {
|
||||
// Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway.
|
||||
MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem
|
||||
<< " MB";
|
||||
}
|
||||
// Create a fixed size arena based on the parameter.
|
||||
std::shared_ptr<Arena> arena;
|
||||
RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_));
|
||||
mp_ = std::move(arena);
|
||||
// update the global usage only.
|
||||
cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate);
|
||||
} else {
|
||||
// Unlimited size. Simply use a system pool. Another choice is CircularPool.
|
||||
mp_ = std::make_shared<SystemPool>();
|
||||
}
|
||||
// Put together a CachePool for backing up the Tensor
|
||||
cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), root_);
|
||||
cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), UseArena(), root_);
|
||||
RETURN_IF_NOT_OK(cp_->ServiceStart());
|
||||
// Set up the B+ tree as well. But use the system pool instead.
|
||||
map_ = std::make_shared<row_map>();
|
||||
// Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name.
|
||||
cookie_ = cp_->MyName();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheService::DoServiceStop() {
|
||||
if (cp_ != nullptr) {
|
||||
RETURN_IF_NOT_OK(cp_->ServiceStop());
|
||||
}
|
||||
CacheServer &cs = CacheServer::GetInstance();
|
||||
if (UseArena()) {
|
||||
cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage()
|
||||
<< " bytes.";
|
||||
cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
|
||||
SharedLock rw(&rw_lock_);
|
||||
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
|
||||
|
@ -66,6 +87,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
|
|||
// allow other to cache more rows.
|
||||
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
|
||||
}
|
||||
if (st_ == State::kNoLocking) {
|
||||
// We ignore write this request once we turn off locking on the B+ tree. So we will just
|
||||
// return out of memory from now on.
|
||||
return Status(StatusCode::kOutOfMemory);
|
||||
}
|
||||
try {
|
||||
// The first buffer is a flatbuffer which describes the rest of the buffers follow
|
||||
auto fb = buf.front();
|
||||
|
@ -86,6 +112,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
|
|||
*row_id_generated = msg->row_id();
|
||||
}
|
||||
auto size_of_this = msg->size_of_this();
|
||||
size_t total_sz = size_of_this;
|
||||
auto column_hdr = msg->column();
|
||||
// Number of tensor buffer should match the number of columns plus one.
|
||||
if (buf.size() != column_hdr->size() + 1) {
|
||||
|
@ -99,16 +126,28 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
|
|||
all_data.emplace_back(fb, size_of_this);
|
||||
for (auto i = 0; i < column_hdr->size(); ++i) {
|
||||
all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
|
||||
total_sz += msg->data_sz()->Get(i);
|
||||
}
|
||||
// Now we cache the flat buffer.
|
||||
CachePool::key_type key;
|
||||
RETURN_IF_NOT_OK(cp_->Insert(all_data, &key));
|
||||
Status rc = map_->DoInsert(*row_id_generated, key);
|
||||
// Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it.
|
||||
// Otherwise, we check how much (globally) how much we use and may simply spill to disk
|
||||
// directly.
|
||||
CacheServer &cs = CacheServer::GetInstance();
|
||||
bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory();
|
||||
Status rc = cp_->Insert(*row_id_generated, all_data, write_to_disk_directly);
|
||||
if (rc == Status(StatusCode::kDuplicateKey)) {
|
||||
MS_LOG(DEBUG) << "Ignoring duplicate key.";
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(rc);
|
||||
}
|
||||
// All good, then update the memory usage local and global (if not using arena)
|
||||
if (write_to_disk_directly) {
|
||||
cur_disk_usage_ += total_sz;
|
||||
} else {
|
||||
cur_mem_usage_ += total_sz;
|
||||
if (!UseArena()) {
|
||||
cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} catch (const std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
|
@ -123,6 +162,11 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
|
|||
// allow other to cache more rows.
|
||||
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
|
||||
}
|
||||
if (st_ == State::kNoLocking) {
|
||||
// We ignore write this request once we turn off locking on the B+ tree. So we will just
|
||||
// return out of memory from now on.
|
||||
return Status(StatusCode::kOutOfMemory);
|
||||
}
|
||||
try {
|
||||
// If we don't need to generate id, we need to find it from the buffer.
|
||||
if (generate_id_) {
|
||||
|
@ -139,20 +183,33 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
|
|||
}
|
||||
*row_id_generated = msg->row_id();
|
||||
}
|
||||
// Now we cache the flat buffer.
|
||||
CachePool::key_type key;
|
||||
RETURN_IF_NOT_OK(cp_->Insert({src}, &key));
|
||||
Status rc = map_->DoInsert(*row_id_generated, key);
|
||||
// Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it.
|
||||
// Otherwise, we check how much (globally) how much we use and may simply spill to disk
|
||||
// directly.
|
||||
auto total_sz = src.GetSize();
|
||||
CacheServer &cs = CacheServer::GetInstance();
|
||||
bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory();
|
||||
Status rc = cp_->Insert(*row_id_generated, {src}, write_to_disk_directly);
|
||||
if (rc == Status(StatusCode::kDuplicateKey)) {
|
||||
MS_LOG(DEBUG) << "Ignoring duplicate key.";
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(rc);
|
||||
}
|
||||
// All good, then update the memory usage local and global (if not using arena)
|
||||
if (write_to_disk_directly) {
|
||||
cur_disk_usage_ += total_sz;
|
||||
} else {
|
||||
cur_mem_usage_ += total_sz;
|
||||
if (!UseArena()) {
|
||||
cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} catch (const std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nCache memory size: " << cs.cache_mem_sz_;
|
||||
|
@ -164,34 +221,29 @@ std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
|
|||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); }
|
||||
Status CacheService::Purge() {
|
||||
// First we must lock exclusively. No one else can cache/restore anything.
|
||||
UniqueLock rw(&rw_lock_);
|
||||
RETURN_IF_NOT_OK(cp_->ServiceStop());
|
||||
auto new_map = std::make_shared<row_map>();
|
||||
map_.reset();
|
||||
map_ = std::move(new_map);
|
||||
next_id_ = 0;
|
||||
RETURN_IF_NOT_OK(cp_->ServiceStart());
|
||||
|
||||
Status CacheService::FindKeysMiss(std::vector<row_id_type> *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
std::unique_lock<std::mutex> lock(get_key_miss_mux_);
|
||||
if (key_miss_results_ == nullptr) {
|
||||
// Just do it once.
|
||||
key_miss_results_ = std::make_shared<std::vector<row_id_type>>();
|
||||
auto stat = cp_->GetStat(true);
|
||||
key_miss_results_->push_back(stat.min_key);
|
||||
key_miss_results_->push_back(stat.max_key);
|
||||
key_miss_results_->insert(key_miss_results_->end(), stat.gap.begin(), stat.gap.end());
|
||||
}
|
||||
out->insert(out->end(), key_miss_results_->begin(), key_miss_results_->end());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheService::GetStat(CacheService::ServiceStat *out) {
|
||||
SharedLock rw(&rw_lock_);
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
if (st_ == State::kNone || st_ == State::kFetchPhase) {
|
||||
out->stat_ = cp_->GetStat();
|
||||
out->state_ = static_cast<ServiceStat::state_type>(st_);
|
||||
auto it = map_->begin();
|
||||
if (it != map_->end()) {
|
||||
out->min_ = it.key();
|
||||
auto end_it = map_->end();
|
||||
--end_it;
|
||||
out->max_ = end_it.key();
|
||||
}
|
||||
} else {
|
||||
out->state_ = static_cast<ServiceStat::state_type>(st_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -204,19 +256,12 @@ Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vecto
|
|||
*mem_sz = (num_elements + 1) * sizeof(int64_t);
|
||||
(*out).reserve(num_elements);
|
||||
for (auto row_id : v) {
|
||||
auto r = map_->Search(row_id);
|
||||
if (r.second) {
|
||||
auto &it = r.first;
|
||||
CachePool::key_type key = it.value();
|
||||
auto sz = cp_->GetSize(key);
|
||||
if (sz == 0) {
|
||||
std::string errMsg = "Key not found: ";
|
||||
errMsg += std::to_string(key);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
(*out).emplace_back(key, sz);
|
||||
auto sz = cp_->GetSize(row_id);
|
||||
if (sz > 0) {
|
||||
(*out).emplace_back(row_id, sz);
|
||||
(*mem_sz) += sz;
|
||||
} else {
|
||||
// key not found
|
||||
(*out).emplace_back(-1, 0);
|
||||
}
|
||||
}
|
||||
|
@ -252,27 +297,19 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::ve
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheService::CacheSchema(const void *buf, int64_t len) {
|
||||
SharedLock rw(&rw_lock_);
|
||||
if (st_ == State::kFetchPhase) {
|
||||
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
|
||||
// allow other to cache more rows.
|
||||
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
|
||||
}
|
||||
// This is a special request and we need to remember where we store it.
|
||||
UniqueLock rw(&rw_lock_);
|
||||
// In case we are calling the same function from multiple threads, only
|
||||
// the first one is considered. Rest is ignored.
|
||||
CachePool::key_type cur_key = schema_key_;
|
||||
CachePool::key_type key;
|
||||
if (cur_key < 0) {
|
||||
RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key));
|
||||
auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key);
|
||||
MS_LOG(DEBUG) << "Caching Schema. Result = " << result;
|
||||
if (schema_.empty()) {
|
||||
schema_.assign(static_cast<const char *>(buf), len);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Caching Schema already done";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheService::FetchSchema(std::string *out) const {
|
||||
SharedLock rw(&rw_lock_);
|
||||
if (st_ == State::kBuildPhase) {
|
||||
|
@ -283,32 +320,44 @@ Status CacheService::FetchSchema(std::string *out) const {
|
|||
// We are going to use std::string to allocate and hold the result which will be eventually
|
||||
// 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
|
||||
// to minimize memory copy.
|
||||
std::string mem;
|
||||
if (schema_key_ >= 0) {
|
||||
auto len = cp_->GetSize(schema_key_);
|
||||
try {
|
||||
mem.resize(len);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error");
|
||||
} catch (const std::bad_alloc &e) {
|
||||
return Status(StatusCode::kOutOfMemory);
|
||||
}
|
||||
auto slice = WritableSlice(mem.data(), len);
|
||||
RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
|
||||
std::string mem(schema_);
|
||||
if (!mem.empty()) {
|
||||
*out = std::move(mem);
|
||||
} else {
|
||||
return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheService::BuildPhaseDone() {
|
||||
if (HasBuildPhase()) {
|
||||
// Exclusive lock to switch phase
|
||||
UniqueLock rw(&rw_lock_);
|
||||
st_ = State::kFetchPhase;
|
||||
cp_->SetLocking(false);
|
||||
return Status::OK();
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
|
||||
}
|
||||
}
|
||||
|
||||
Status CacheService::ToggleWriteMode(bool on_off) {
|
||||
UniqueLock rw(&rw_lock_);
|
||||
if (HasBuildPhase()) {
|
||||
RETURN_STATUS_UNEXPECTED("Not applicable to non-mappable dataset");
|
||||
} else {
|
||||
// If we stop accepting write request, we turn off locking for the
|
||||
// underlying B+ tree. All future write request we will return kOutOfMemory.
|
||||
if (st_ == State::kNone && !on_off) {
|
||||
st_ = State::kNoLocking;
|
||||
cp_->SetLocking(on_off);
|
||||
MS_LOG(WARNING) << "Locking mode is switched off.";
|
||||
} else if (st_ == State::kNoLocking && on_off) {
|
||||
st_ = State::kNone;
|
||||
cp_->SetLocking(on_off);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
@ -44,9 +45,8 @@ using key_size_pair = std::pair<CachePool::key_type, size_t>;
|
|||
class CacheService : public Service {
|
||||
public:
|
||||
friend class CacheServer;
|
||||
using row_map = BPlusTree<row_id_type, CachePool::key_type>;
|
||||
|
||||
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };
|
||||
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
|
||||
|
@ -97,11 +97,9 @@ class CacheService : public Service {
|
|||
class ServiceStat {
|
||||
public:
|
||||
using state_type = std::underlying_type<State>::type;
|
||||
ServiceStat() : min_(0), max_(0), state_(0) {}
|
||||
ServiceStat() : state_(0) {}
|
||||
~ServiceStat() = default;
|
||||
CachePool::CacheStat stat_{};
|
||||
row_id_type min_;
|
||||
row_id_type max_;
|
||||
state_type state_;
|
||||
};
|
||||
/// \brief Statistics for the current service
|
||||
|
@ -117,9 +115,9 @@ class CacheService : public Service {
|
|||
/// \param out A contiguous memory that contains the serialized form of schema.
|
||||
/// \return Status object
|
||||
Status FetchSchema(std::string *out) const;
|
||||
/// \brief Purge the content of a cache
|
||||
/// \brief Return a set of keys that are definitely cache miss
|
||||
/// \return Status object
|
||||
Status Purge();
|
||||
Status FindKeysMiss(std::vector<row_id_type> *out);
|
||||
/// \brief Overload the << operator to print a cache service
|
||||
/// \param out std::ostream
|
||||
/// \param cs A cache service
|
||||
|
@ -136,19 +134,33 @@ class CacheService : public Service {
|
|||
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
|
||||
/// \return Status object
|
||||
Status BuildPhaseDone();
|
||||
/// \brief Find out the current memory usage
|
||||
int64_t GetMemoryUsage() { return cur_mem_usage_; }
|
||||
/// \brief Find out the current disk usage
|
||||
int64_t GetDiskUsage() { return cur_disk_usage_; }
|
||||
/// \brief For kToggleWriteMode request
|
||||
Status ToggleWriteMode(bool on_off);
|
||||
|
||||
private:
|
||||
mutable RWLock rw_lock_;
|
||||
std::string root_;
|
||||
uint64_t cache_mem_sz_;
|
||||
std::shared_ptr<CachePool> cp_;
|
||||
std::shared_ptr<row_map> map_;
|
||||
std::atomic<row_id_type> next_id_;
|
||||
bool generate_id_;
|
||||
std::atomic<CachePool::key_type> schema_key_;
|
||||
std::string cookie_;
|
||||
State st_;
|
||||
|
||||
std::string schema_;
|
||||
// If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it.
|
||||
// Otherwise we track how much is in memory and how much is on disk (if root_ is not empty).
|
||||
// We use them to control when we should stop caching in memory in the case when there is no
|
||||
// Arena.
|
||||
std::atomic<int64_t> cur_mem_usage_;
|
||||
std::atomic<int64_t> cur_disk_usage_;
|
||||
// We also cache the result from calling FindKeysMiss because it is expensive. Besides user make
|
||||
// this request after we hit memory full or disk full. So the result is unlikely to change.
|
||||
std::mutex get_key_miss_mux_;
|
||||
std::shared_ptr<std::vector<row_id_type>> key_miss_results_;
|
||||
/// \brief Private function to generate a row id
|
||||
/// \return Row id assigned.
|
||||
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
|
||||
|
|
|
@ -92,3 +92,13 @@ table CreateCacheReplyMsg {
|
|||
connection_id:int64;
|
||||
cookie:string;
|
||||
}
|
||||
|
||||
table ListSessionMsg {
|
||||
session_id:uint32;
|
||||
connection_id:uint64;
|
||||
stats:ServiceStatMsg;
|
||||
}
|
||||
|
||||
table ListSessionsMsg {
|
||||
sessions:[ListSessionMsg];
|
||||
}
|
||||
|
|
|
@ -53,22 +53,35 @@ CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t row
|
|||
num_cache_miss_(0),
|
||||
cache_client_(std::move(cache_client)),
|
||||
rows_per_buffer_(rows_per_buf),
|
||||
// We can cause deadlock if this internal Connector size is too small.
|
||||
keys_miss_(num_workers_, 1, connector_capacity_),
|
||||
prefetch_size_(cache_client_->getPrefetchSize()) {
|
||||
prefetch_size_(rows_per_buffer_),
|
||||
num_prefetchers_(num_workers_) {
|
||||
// Adjust the prefetch size based on the number of workers.
|
||||
auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_;
|
||||
if (prefetch_size_ < prefetch_sz_per_thread) {
|
||||
prefetch_size_ = prefetch_sz_per_thread;
|
||||
MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_;
|
||||
}
|
||||
io_block_queues_.Init(num_workers, op_connector_size);
|
||||
prefetch_queues_.Init(num_workers, op_connector_size);
|
||||
sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size);
|
||||
prefetch_queues_.Init(num_prefetchers_, op_connector_size);
|
||||
// We can cause deadlock if this internal Connector size is too small.
|
||||
keys_miss_ = std::make_unique<Connector<std::vector<row_id_type>>>(num_prefetchers_, 1, connector_capacity_);
|
||||
}
|
||||
// Common function to fetch samples from the sampler and send them using the io_block_queues to
|
||||
// the parallel workers
|
||||
Status CacheBase::FetchSamplesToWorkers() {
|
||||
int64_t buf_cnt = 0;
|
||||
int64_t wait_cnt = 0;
|
||||
int64_t prefetch_cnt = 0;
|
||||
// Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_
|
||||
// is too small (1 by default) and won't help performance.
|
||||
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this)));
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
|
||||
auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
|
||||
std::vector<row_id_type> &keys) -> Status {
|
||||
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||
RETURN_IF_NOT_OK(qList[worker_id]->Add(std::move(blk)));
|
||||
return Status::OK();
|
||||
};
|
||||
// Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
|
||||
// to the WorkerEntry.
|
||||
do {
|
||||
|
@ -82,33 +95,54 @@ Status CacheBase::FetchSamplesToWorkers() {
|
|||
++wait_cnt;
|
||||
std::vector<row_id_type> keys;
|
||||
keys.reserve(rows_per_buffer_);
|
||||
std::vector<row_id_type> prefetch_keys;
|
||||
prefetch_keys.reserve(prefetch_size_);
|
||||
std::unique_ptr<DataBuffer> sampler_buffer;
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
while (!sampler_buffer->eoe()) {
|
||||
TensorRow sample_row;
|
||||
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
|
||||
std::shared_ptr<Tensor> sample_ids = sample_row[0];
|
||||
// Send the sampler tensor to other thread for prefetching. We are using shared pointer so it
|
||||
// won't go out scope until it is really not in use.
|
||||
RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids));
|
||||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
|
||||
keys.push_back(*itr);
|
||||
++row_cnt_;
|
||||
if (row_cnt_ % rows_per_buffer_ == 0) {
|
||||
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
|
||||
prefetch_keys.push_back(*itr);
|
||||
// Batch enough rows for performance reason.
|
||||
if (row_cnt_ % prefetch_size_ == 0) {
|
||||
RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
|
||||
// Now we tell the WorkerEntry to wait for them to come back. If prefetch_size_ is a multiple
|
||||
// of rows_per_buffer_, the keys vector will always be empty. But it can be partially filled.
|
||||
// The only requirement we set up is rows_per_buffer_ is less than or equal to prefetch_size_.
|
||||
for (auto row_id : prefetch_keys) {
|
||||
keys.push_back(row_id);
|
||||
if (keys.size() == rows_per_buffer_) {
|
||||
RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
|
||||
keys.clear();
|
||||
}
|
||||
}
|
||||
prefetch_keys.clear();
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
|
||||
}
|
||||
// Deal with any partial keys left.
|
||||
if (!prefetch_keys.empty()) {
|
||||
RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
|
||||
for (auto row_id : prefetch_keys) {
|
||||
keys.push_back(row_id);
|
||||
if (keys.size() == rows_per_buffer_) {
|
||||
RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
|
||||
keys.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!keys.empty()) {
|
||||
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
|
||||
RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
|
||||
}
|
||||
// send the eoe
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
RETURN_IF_NOT_OK(prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
|
||||
// If repeat but the not last repeat, wait for reset.
|
||||
if (!IsLastIteration()) {
|
||||
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
|
||||
|
@ -123,8 +157,6 @@ Status CacheBase::FetchSamplesToWorkers() {
|
|||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
|
||||
// Shutdown threads
|
||||
std::shared_ptr<Tensor> empty;
|
||||
RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty)));
|
||||
for (int32_t i = 0; i < num_workers_; i++) {
|
||||
RETURN_IF_NOT_OK(
|
||||
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
|
||||
|
@ -145,13 +177,6 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
|
|||
if (blk->eof()) {
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
|
||||
} else if (blk->eoe()) {
|
||||
if (AllowCacheMiss()) {
|
||||
// This code path is for CacheLookupOp acting as a sampler. If we get a eoe from
|
||||
// a sampler, send a eoe to physical leaf op as well.
|
||||
std::vector<row_id_type> eoe;
|
||||
eoe.push_back(eoe_row_id);
|
||||
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe));
|
||||
}
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
|
||||
} else {
|
||||
std::vector<int64_t> keys;
|
||||
|
@ -162,22 +187,21 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
|
|||
}
|
||||
std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
|
||||
std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
|
||||
std::vector<row_id_type> cache_miss;
|
||||
cache_miss.reserve(keys.size());
|
||||
for (auto row_id : keys) {
|
||||
TensorRow row;
|
||||
// Block until the row shows up in the pool.
|
||||
RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row));
|
||||
RETURN_IF_NOT_OK(GetPrefetchRow(row_id, &row));
|
||||
if (row.empty()) {
|
||||
cache_miss.push_back(row_id);
|
||||
if (AllowCacheMiss()) {
|
||||
++num_cache_miss_;
|
||||
} else {
|
||||
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
}
|
||||
que->push_back(std::move(row));
|
||||
}
|
||||
db->set_tensor_table(std::move(que));
|
||||
if (AllowCacheMiss()) {
|
||||
// Because of the way connector works, we push unconditionally even cache_miss can be empty.
|
||||
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss));
|
||||
}
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
|
||||
buffer_id += num_workers_;
|
||||
}
|
||||
|
@ -189,7 +213,6 @@ Status CacheBase::RegisterResources() {
|
|||
RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -208,73 +231,97 @@ Status CacheBase::UpdateColumnMapFromCache() {
|
|||
return rc;
|
||||
}
|
||||
|
||||
Status CacheBase::Dispatcher() {
|
||||
TaskManager::FindMe()->Post();
|
||||
int64_t buf_cnt = 0;
|
||||
int64_t num_row = 0;
|
||||
std::vector<row_id_type> keys;
|
||||
keys.reserve(prefetch_size_);
|
||||
do {
|
||||
keys.clear();
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids));
|
||||
if (sample_ids == nullptr) {
|
||||
// A null shared pointer signal times to quit.
|
||||
// Also signal all prefetchers to quit.
|
||||
for (int32_t i = 0; i < num_workers_; i++) {
|
||||
RETURN_IF_NOT_OK(
|
||||
prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
|
||||
}
|
||||
break;
|
||||
}
|
||||
// Now we distribute the sampler ids to each prefetcher according to the prefetch size.
|
||||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
|
||||
keys.push_back(*itr);
|
||||
++num_row;
|
||||
if (num_row % prefetch_size_ == 0) {
|
||||
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||
RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
|
||||
keys.clear();
|
||||
}
|
||||
}
|
||||
// Send the remaining sample id
|
||||
if (!keys.empty()) {
|
||||
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
|
||||
RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
|
||||
}
|
||||
} while (true);
|
||||
Status CacheBase::GetPrefetchRow(row_id_type row_id, TensorRow *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0, "Expect positive row id");
|
||||
RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, out));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss) {
|
||||
RETURN_UNEXPECTED_IF_NULL(cache_miss);
|
||||
std::vector<row_id_type> prefetch_keys;
|
||||
prefetch_keys.reserve(keys.size());
|
||||
|
||||
// Filter out all those keys that unlikely we will find at the server
|
||||
for (auto row_id : keys) {
|
||||
if (cache_client_->KeyIsCacheMiss(row_id)) {
|
||||
// Just put an empty row in the cache.
|
||||
TensorRow row;
|
||||
row.setId(row_id);
|
||||
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
|
||||
cache_miss->push_back(row_id);
|
||||
} else {
|
||||
prefetch_keys.push_back(row_id);
|
||||
}
|
||||
}
|
||||
// Early exit if nothing to fetch
|
||||
if (prefetch_keys.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Get the rows from the server
|
||||
TensorTable ttbl;
|
||||
Status rc = cache_client_->GetRows(prefetch_keys, &ttbl);
|
||||
if (rc.IsOk()) {
|
||||
auto row_it = ttbl.begin();
|
||||
for (auto row_id : prefetch_keys) {
|
||||
auto &row = *row_it;
|
||||
if (row.empty()) {
|
||||
cache_miss->push_back(row_id);
|
||||
}
|
||||
// Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row
|
||||
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
|
||||
++row_it;
|
||||
}
|
||||
} else {
|
||||
// In case any thread is waiting for the rows to come back and blocked on a semaphore,
|
||||
// we will put an empty row in the local cache.
|
||||
for (auto row_id : prefetch_keys) {
|
||||
TensorRow row;
|
||||
row.setId(row_id);
|
||||
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
|
||||
cache_miss->push_back(row_id);
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
Status CacheBase::Prefetcher(int32_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
std::vector<row_id_type> prefetch_keys;
|
||||
prefetch_keys.reserve(prefetch_size_);
|
||||
std::vector<row_id_type> cache_miss;
|
||||
cache_miss.reserve(prefetch_size_);
|
||||
do {
|
||||
prefetch_keys.clear();
|
||||
cache_miss.clear();
|
||||
std::unique_ptr<IOBlock> blk;
|
||||
RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!blk->eof(), "Expect eoe or a regular io block");
|
||||
if (!blk->eoe()) {
|
||||
RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
|
||||
if (prefetch_keys.empty()) {
|
||||
// Empty keys mean time to quit.
|
||||
break;
|
||||
Status rc;
|
||||
const int32_t max_retries = 5;
|
||||
int32_t retry_count = 0;
|
||||
do {
|
||||
rc = PrefetchRows(prefetch_keys, &cache_miss);
|
||||
if (rc.IsNetWorkError() && retry_count < max_retries) {
|
||||
// If we get some network error, we will attempt some retries
|
||||
retry_count++;
|
||||
} else if (rc.IsError()) {
|
||||
return rc;
|
||||
}
|
||||
TensorTable ttbl;
|
||||
RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
|
||||
auto row_it = ttbl.begin();
|
||||
for (auto row_id : prefetch_keys) {
|
||||
auto &row = *row_it;
|
||||
if (row.empty()) {
|
||||
if (AllowCacheMiss()) {
|
||||
++num_cache_miss_;
|
||||
} while (rc.IsNetWorkError());
|
||||
} else {
|
||||
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
if (AllowCacheMiss()) {
|
||||
// This code path is for CacheLookupOp acting as a sampler. If we get a eoe from
|
||||
// a sampler, send a eoe to physical leaf op as well.
|
||||
cache_miss.push_back(eoe_row_id);
|
||||
}
|
||||
}
|
||||
// Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row
|
||||
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
|
||||
++row_it;
|
||||
if (AllowCacheMiss()) {
|
||||
// Because of the way connector works, we push unconditionally even cache_miss can be empty.
|
||||
RETURN_IF_NOT_OK(keys_miss_->Push(worker_id, cache_miss));
|
||||
}
|
||||
} while (true);
|
||||
return Status::OK();
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/connector.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
#include "minddata/dataset/engine/cache/cache_service.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
|
@ -90,8 +91,7 @@ class CacheBase : public ParallelOp {
|
|||
std::shared_ptr<CacheClient> cache_client_;
|
||||
WaitPost epoch_sync_;
|
||||
int32_t rows_per_buffer_;
|
||||
Connector<std::vector<row_id_type>> keys_miss_;
|
||||
QueueMap<row_id_type, TensorRow> prefetch_;
|
||||
std::unique_ptr<Connector<std::vector<row_id_type>>> keys_miss_;
|
||||
|
||||
/// \brief Common function to register resources for interrupt
|
||||
/// \note Derived should override this function for extra resources to be registered
|
||||
|
@ -111,13 +111,16 @@ class CacheBase : public ParallelOp {
|
|||
constexpr static int32_t connector_capacity_ = 1024;
|
||||
int32_t prefetch_size_;
|
||||
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
|
||||
int32_t num_prefetchers_;
|
||||
QueueList<std::unique_ptr<IOBlock>> prefetch_queues_;
|
||||
std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_;
|
||||
QueueMap<row_id_type, TensorRow> prefetch_;
|
||||
|
||||
Status Dispatcher();
|
||||
/// \brief Prefetcher. It prefetch the rows from cache server
|
||||
/// \return Status object.
|
||||
Status Prefetcher(int32_t worker_id);
|
||||
/// \brief Functions used by prefetcher and WorkerEntry
|
||||
Status PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss);
|
||||
Status GetPrefetchRow(row_id_type row_id, TensorRow *out);
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -87,10 +87,10 @@ Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
|
|||
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
|
||||
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
std::vector<row_id_type> cache_miss;
|
||||
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
|
||||
RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss));
|
||||
// Ignore the case we have no cache miss, we can't return empty samples.
|
||||
while (cache_miss.empty()) {
|
||||
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
|
||||
RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss));
|
||||
}
|
||||
// Special code for eoe
|
||||
if (cache_miss.at(0) == eoe_row_id) {
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/util/system_pool.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -48,7 +49,8 @@ CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t
|
|||
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
|
||||
: ParallelOp(numWorkers, opConnectorSize, sampler),
|
||||
num_cleaners_(numCleaners),
|
||||
cache_client_(std::move(cache_client)) {}
|
||||
cache_client_(std::move(cache_client)),
|
||||
cache_missing_rows_(true) {}
|
||||
|
||||
Status CacheMergeOp::operator()() {
|
||||
// A queue of row id to let cleaner send cache miss rows to the cache server
|
||||
|
@ -129,6 +131,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
|
|||
std::string errMsg = "Expect positive row id: " + std::to_string(row_id);
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
if (cache_missing_rows_) {
|
||||
// Technically number of this row shows up in the cache miss stream is equal to the number
|
||||
// of P() call. However the cleaner wants it too. So we need an extra copy.
|
||||
TensorRowCacheRequest *rq;
|
||||
|
@ -142,6 +145,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
|
|||
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row)));
|
||||
}
|
||||
}
|
||||
|
@ -168,15 +172,20 @@ Status CacheMergeOp::Cleaner() {
|
|||
Status rc = rq->CheckCacheResult();
|
||||
if (rc.IsError()) {
|
||||
// If interrupt, time to quit.
|
||||
if (rc.get_code() == StatusCode::kInterrupted) {
|
||||
if (rc.IsInterrupted()) {
|
||||
return Status::OK();
|
||||
}
|
||||
} else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
|
||||
// The server is hitting some limit and we will turn off caching from now on.
|
||||
cache_missing_rows_ = false;
|
||||
cache_client_->ServerRunningOutOfResources();
|
||||
} else {
|
||||
MS_LOG(INFO) << "Cache row not successful: " << rc.ToString();
|
||||
// Bad rc should not bring down the pipeline. We will simply continue and
|
||||
// change the state back to empty. We don't need a CAS from CLEAN back to EMPTY.
|
||||
rq->SetState(TensorRowCacheRequest::State::kEmpty);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -253,7 +262,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
|
|||
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
|
||||
// If we are in a repeat path, send the eoe up.
|
||||
// Otherwise ignore it.
|
||||
if (op_total_repeats_ > 1) {
|
||||
if (op_total_repeats_ != 1) {
|
||||
return DatasetOp::EoeReceived(worker_id);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -281,7 +290,7 @@ Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheReque
|
|||
*out = it->second.GetMutablePointer();
|
||||
} else {
|
||||
// We will create a new one.
|
||||
auto alloc = Services::GetAllocator<TensorRowCacheRequest>();
|
||||
auto alloc = SystemPool::GetAllocator<TensorRowCacheRequest>();
|
||||
auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc));
|
||||
if (r.second) {
|
||||
auto &mem = r.first->second;
|
||||
|
|
|
@ -202,6 +202,7 @@ class CacheMergeOp : public ParallelOp {
|
|||
std::unique_ptr<Queue<row_id_type>> io_que_;
|
||||
std::shared_ptr<CacheClient> cache_client_;
|
||||
int32_t num_cleaners_;
|
||||
std::atomic<bool> cache_missing_rows_;
|
||||
|
||||
/// \brief Locate the cache request from the io_request_ map
|
||||
/// \param row_id
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
|
@ -64,7 +65,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
|
|||
// Constructor of CacheOp
|
||||
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
|
||||
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
|
||||
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
|
||||
: CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)),
|
||||
num_guys_in_(0),
|
||||
phase_(Phase::kBuildPhase) {}
|
||||
|
||||
|
@ -174,7 +175,7 @@ Status CacheOp::WorkerEntry(int32_t worker_id) {
|
|||
Status CacheOp::RegisterResources() {
|
||||
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
|
||||
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
|
||||
|
@ -188,5 +189,11 @@ Status ConcatOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor pre-accept method for NodePass
|
||||
Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -105,6 +105,12 @@ class ConcatOp : public PipelineOp {
|
|||
// @return - Status
|
||||
Status ComputeColMap() override;
|
||||
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);
|
||||
|
||||
|
|
|
@ -243,6 +243,7 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const {
|
|||
out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_
|
||||
<< "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_;
|
||||
if (sampler_) {
|
||||
out << "\nSampler:\n";
|
||||
sampler_->Print(out, show_all);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -268,5 +268,11 @@ Status FilterOp::Accept(NodePass *p, bool *modified) {
|
|||
// Downcast shared pointer then call visitor
|
||||
return p->RunOnNode(shared_from_base<FilterOp>(), modified);
|
||||
}
|
||||
|
||||
// Visitor pre-accept method for NodePass
|
||||
Status FilterOp::PreAccept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->PreRunOnNode(shared_from_base<FilterOp>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -121,6 +121,12 @@ class FilterOp : public ParallelOp {
|
|||
// @param show_all A bool to control if you want to show all info or just a summary.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
|
|
|
@ -458,6 +458,12 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
|
|||
return p->RunOnNode(shared_from_base<MapOp>(), modified);
|
||||
}
|
||||
|
||||
// Visitor pre-accept method for NodePass
|
||||
Status MapOp::PreAccept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->PreRunOnNode(shared_from_base<MapOp>(), modified);
|
||||
}
|
||||
|
||||
Status MapOp::WaitForWorkers() {
|
||||
// reset num_paused workers to 0
|
||||
num_workers_paused_ = 0;
|
||||
|
|
|
@ -177,10 +177,16 @@ class MapOp : public ParallelOp {
|
|||
// @return the number of threads consuming data from previous op's output Connector.
|
||||
int32_t num_consumers() const override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for NodePass visitor acceptor.
|
||||
/// \param[in] p Pointer to the NodePass to be accepted.
|
||||
/// \param[out] modified Whether this node visit modified the pipeline.
|
||||
/// \return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
// Op name getter
|
||||
|
|
|
@ -52,9 +52,9 @@ Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) {
|
|||
void ParallelOp::Print(std::ostream &out, bool show_all) const {
|
||||
// Summary 1-liner print
|
||||
if (!show_all) {
|
||||
out << " [workers: " << num_workers_ << "]";
|
||||
// Call super class printer
|
||||
DatasetOp::Print(out, show_all);
|
||||
out << " [workers: " << num_workers_ << "]";
|
||||
} else {
|
||||
// Detailed print
|
||||
DatasetOp::Print(out, show_all);
|
||||
|
|
|
@ -27,14 +27,14 @@ PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampl
|
|||
void PipelineOp::Print(std::ostream &out, bool show_all) const {
|
||||
// Summary 1-liner print
|
||||
if (!show_all) {
|
||||
// Call super class printer
|
||||
DatasetOp::Print(out, show_all);
|
||||
out << " [workers: ";
|
||||
if (this->inlined()) {
|
||||
out << "0 (inlined)]";
|
||||
} else {
|
||||
out << "1]"; // Pipeline ops only have 1 worker
|
||||
}
|
||||
// Call super class printer
|
||||
DatasetOp::Print(out, show_all);
|
||||
} else {
|
||||
// Detailed print
|
||||
DatasetOp::Print(out, show_all);
|
||||
|
|
|
@ -235,6 +235,12 @@ Status ZipOp::EoeReceived(int32_t) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor pre-accept method for NodePass
|
||||
Status ZipOp::PreAccept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->PreRunOnNode(shared_from_base<ZipOp>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accept method for NodePass
|
||||
Status ZipOp::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
|
|
|
@ -104,10 +104,16 @@ class ZipOp : public PipelineOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
// @param modified - Whether this node visit modified the pipeline.
|
||||
// @return - Status of the node visit.
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for NodePass visitor acceptor.
|
||||
/// \param[in] p Pointer to the NodePass to be accepted.
|
||||
/// \param[out] modified Whether this node visit modified the pipeline.
|
||||
/// \return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
// Op name getter
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
|
||||
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
|
||||
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
|
||||
#include "minddata/dataset/engine/perf/profiling.h"
|
||||
|
@ -235,6 +236,7 @@ Status ExecutionTree::PrepareTreePreAction() {
|
|||
std::vector<std::unique_ptr<Pass>> pre_actions;
|
||||
// Construct pre actions
|
||||
MS_LOG(INFO) << "Running pre pass loops.";
|
||||
pre_actions.push_back(std::make_unique<CacheErrorPass>());
|
||||
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
|
||||
pre_actions.push_back(std::make_unique<RemovalPass>());
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
|
|
@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
|
|||
add_library(engine-opt OBJECT
|
||||
pass.cc
|
||||
post/repeat_pass.cc
|
||||
pre/cache_error_pass.cc
|
||||
pre/cache_transform_pass.cc
|
||||
pre/epoch_injection_pass.cc
|
||||
pre/removal_pass.cc
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
|
@ -143,40 +144,6 @@ Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
|||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
|
@ -207,13 +174,6 @@ Status NodePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) {
|
|||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
|
@ -239,18 +199,6 @@ Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
|
@ -261,18 +209,6 @@ Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
|
@ -283,12 +219,88 @@ Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified
|
|||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
|
||||
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,18 +37,6 @@ class SkipOp;
|
|||
|
||||
class ShuffleOp;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class MindRecordOp;
|
||||
|
||||
class TFReaderOp;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
class FilterOp;
|
||||
|
||||
class GeneratorOp;
|
||||
#endif
|
||||
|
||||
class AlbumOp;
|
||||
|
||||
class RandomDataOp;
|
||||
|
@ -63,10 +51,6 @@ class DeviceQueueOp;
|
|||
|
||||
class ImageFolderOp;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class CacheOp;
|
||||
#endif
|
||||
|
||||
class MnistOp;
|
||||
|
||||
class ManifestOp;
|
||||
|
@ -79,20 +63,32 @@ class CocoOp;
|
|||
|
||||
class CelebAOp;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class CacheMergeOp;
|
||||
|
||||
class CacheLookupOp;
|
||||
#endif
|
||||
|
||||
class EpochCtrlOp;
|
||||
|
||||
class BuildVocabOp;
|
||||
|
||||
class ConcatOp;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class MindRecordOp;
|
||||
|
||||
class TFReaderOp;
|
||||
|
||||
class CacheOp;
|
||||
|
||||
class CacheMergeOp;
|
||||
|
||||
class CacheLookupOp;
|
||||
|
||||
class BuildSentencePieceVocabOp;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
class FilterOp;
|
||||
|
||||
class GeneratorOp;
|
||||
#endif
|
||||
|
||||
// The base class Pass is the basic unit of tree transformation.
|
||||
// The actual implementation of the passes will be derived from here.
|
||||
class Pass : public std::enable_shared_from_this<Pass> {
|
||||
|
@ -168,22 +164,6 @@ class NodePass : public Pass {
|
|||
|
||||
virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified);
|
||||
|
@ -194,10 +174,6 @@ class NodePass : public Pass {
|
|||
|
||||
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);
|
||||
|
@ -210,32 +186,50 @@ class NodePass : public Pass {
|
|||
|
||||
virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
private:
|
||||
// Helper function to perform DFS visit
|
||||
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
|
||||
|
|
|
@ -225,13 +225,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
|
|||
|
||||
// Turns off the tracking for operations under merge op
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
|
||||
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
|
||||
// would not have been consumed yet. In that case, we need to set its total repeats for it.
|
||||
if (cache_lookup_) {
|
||||
cache_lookup_->set_total_repeats(num_repeats_);
|
||||
cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
}
|
||||
// Setting the flag is needed since we didn't call the base class DatasetOp version
|
||||
if (is_repeated_) {
|
||||
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
|
||||
// would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
|
||||
if (cache_lookup_) {
|
||||
cache_lookup_->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
AddToEOEOpStack(std::move(cache_lookup_));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor
|
||||
CacheErrorPass::CacheErrorPass() : is_cached_(false) {}
|
||||
|
||||
// Identifies the subtree below this node as being cached
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
// Turn on the flag that we're under a merge op
|
||||
is_cached_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if ZipOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
|
||||
if (is_cached_) {
|
||||
RETURN_STATUS_UNEXPECTED("ZipOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if MapOp with non-deterministic TensorOps exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||
if (is_cached_) {
|
||||
auto tfuncs = node->TFuncs();
|
||||
for (size_t i = 0; i < tfuncs.size(); i++) {
|
||||
if (!tfuncs[i]->Deterministic()) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache.");
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if ConcatOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) {
|
||||
if (is_cached_) {
|
||||
RETURN_STATUS_UNEXPECTED("ConcatOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
// Returns an error if FilterOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
|
||||
if (is_cached_) {
|
||||
RETURN_STATUS_UNEXPECTED("FilterOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_
|
||||
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \class CacheErrorPass cache_error_pass.h
|
||||
/// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures.
|
||||
class CacheErrorPass : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CacheErrorPass();
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheErrorPass() = default;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if ZipOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if ConcatOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
/// \brief Returns an error if FilterOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
private:
|
||||
bool is_cached_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_
|
|
@ -155,50 +155,77 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> n
|
|||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for AlbumOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for MindRecordOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for GeneratorOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache transform identification
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -40,13 +40,6 @@ Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSe
|
|||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Temporary code to prevent the injection of epoch control when cache op is present
|
||||
// Remove this code in cache op phase 2
|
||||
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
|
||||
|
|
|
@ -54,13 +54,6 @@ class EpochInjectionPass : public TreePass {
|
|||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Temporary code to prevent the injection of epoch control when cache op is present.
|
||||
/// Remove this code in cache op phase 2
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Register the DeviceQueueOp for further action.
|
||||
|
|
|
@ -62,6 +62,7 @@ Status RandomApplyOp::Compute(const TensorRow &input, TensorRow *output) {
|
|||
RandomApplyOp::RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops)
|
||||
: prob_(prob), gen_(GetSeed()), rand_double_(0, 1) {
|
||||
compose_ = std::make_unique<ComposeOp>(ops);
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -92,6 +92,7 @@ RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops
|
|||
} else if (ops_.size() == 1) {
|
||||
MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time.";
|
||||
}
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,7 @@ RandomAffineOp::RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t
|
|||
interpolation_ = interpolation;
|
||||
fill_value_ = fill_value;
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
|
|
|
@ -36,6 +36,7 @@ RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_f
|
|||
hue_factor_start_(s_hue_factor),
|
||||
hue_factor_end_(e_hue_factor) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
Status RandomColorAdjustOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
|
|
|
@ -19,7 +19,9 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {}
|
||||
RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) {
|
||||
IO_CHECK(in, out);
|
||||
|
|
|
@ -41,6 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
|
|||
aspect_ub_(aspect_ub),
|
||||
max_iter_(max_iter) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
|
|
|
@ -46,6 +46,7 @@ RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_
|
|||
fill_g_(fill_g),
|
||||
fill_b_(fill_b) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
Status RandomCropOp::ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image,
|
||||
|
|
|
@ -33,6 +33,7 @@ class RandomHorizontalFlipOp : public TensorOp {
|
|||
static const float kDefProbability;
|
||||
|
||||
explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) {
|
||||
is_deterministic_ = false;
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ class RandomHorizontalFlipWithBBoxOp : public TensorOp {
|
|||
|
||||
explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
~RandomHorizontalFlipWithBBoxOp() override = default;
|
||||
|
|
|
@ -29,6 +29,7 @@ const std::vector<uint8_t> RandomPosterizeOp::kBitRange = {4, 8};
|
|||
RandomPosterizeOp::RandomPosterizeOp(const std::vector<uint8_t> &bit_range)
|
||||
: PosterizeOp(bit_range[0]), bit_range_(bit_range) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
Status RandomPosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
|
|
|
@ -35,6 +35,7 @@ class RandomResizeOp : public ResizeOp {
|
|||
|
||||
explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) {
|
||||
random_generator_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
~RandomResizeOp() = default;
|
||||
|
|
|
@ -36,6 +36,7 @@ class RandomResizeWithBBoxOp : public ResizeWithBBoxOp {
|
|||
static const int32_t kDefTargetWidth;
|
||||
explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) {
|
||||
random_generator_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
~RandomResizeWithBBoxOp() = default;
|
||||
|
|
|
@ -46,6 +46,7 @@ RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float c
|
|||
fill_g_(fill_g),
|
||||
fill_b_(fill_b) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
// main function call for random rotation : Generate the random degrees
|
||||
|
|
|
@ -90,6 +90,7 @@ RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &p
|
|||
if (policy_.empty()) {
|
||||
MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty.";
|
||||
}
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -31,6 +31,7 @@ const float RandomSharpnessOp::kDefEndDegree = 1.9;
|
|||
RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree)
|
||||
: start_degree_(start_degree), end_degree_(end_degree) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
/// main function call for random sharpness : Generate the random degrees
|
||||
|
|
|
@ -32,7 +32,10 @@ namespace dataset {
|
|||
class RandomSolarizeOp : public SolarizeOp {
|
||||
public:
|
||||
// Pick a random threshold value to solarize the image with
|
||||
explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); }
|
||||
explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
~RandomSolarizeOp() = default;
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ class RandomVerticalFlipOp : public TensorOp {
|
|||
|
||||
explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
~RandomVerticalFlipOp() override = default;
|
||||
|
|
|
@ -34,6 +34,7 @@ class RandomVerticalFlipWithBBoxOp : public TensorOp {
|
|||
// @param probability: Probablity of Image flipping, 0.5 by default
|
||||
explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) {
|
||||
rnd_.seed(GetSeed());
|
||||
is_deterministic_ = false;
|
||||
}
|
||||
|
||||
~RandomVerticalFlipWithBBoxOp() override = default;
|
||||
|
|
|
@ -168,6 +168,10 @@ class TensorOp {
|
|||
// @return true/false
|
||||
bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; }
|
||||
|
||||
// Returns true oif the TensorOp produces deterministic result.
|
||||
// @return true/false
|
||||
bool Deterministic() { return is_deterministic_; }
|
||||
|
||||
// Function to determine the number of inputs the TensorOp can take. 0: means undefined.
|
||||
// @return uint32_t
|
||||
virtual uint32_t NumInput() { return 1; }
|
||||
|
@ -191,6 +195,9 @@ class TensorOp {
|
|||
virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs);
|
||||
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
protected:
|
||||
bool is_deterministic_{true};
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -88,21 +88,21 @@ class Allocator {
|
|||
std::shared_ptr<MemoryPool> pool_;
|
||||
};
|
||||
/// \brief It is a wrapper of unique_ptr with a custom Allocator class defined above
|
||||
template <typename T, typename... Args>
|
||||
Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, Allocator<T> alloc, size_t n, Args &&... args) {
|
||||
template <typename T, typename C = std::allocator<T>, typename... Args>
|
||||
Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc, size_t n, Args &&... args) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive");
|
||||
try {
|
||||
T *data = alloc.allocate(n);
|
||||
if (!std::is_arithmetic<T>::value) {
|
||||
for (auto i = 0; i < n; i++) {
|
||||
std::allocator_traits<Allocator<T>>::construct(alloc, &(data[i]), std::forward<Args>(args)...);
|
||||
std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
auto deleter = [](T *p, Allocator<T> f_alloc, size_t f_n) {
|
||||
auto deleter = [](T *p, C f_alloc, size_t f_n) {
|
||||
if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
|
||||
for (auto i = 0; i < f_n; ++i) {
|
||||
std::allocator_traits<Allocator<T>>::destroy(f_alloc, &p[i]);
|
||||
std::allocator_traits<C>::destroy(f_alloc, &p[i]);
|
||||
}
|
||||
}
|
||||
f_alloc.deallocate(p, f_n);
|
||||
|
@ -129,7 +129,7 @@ class MemGuard {
|
|||
MemGuard(const MemGuard &) = delete;
|
||||
MemGuard &operator=(const MemGuard &) = delete;
|
||||
// On the other hand, We can support move constructor
|
||||
MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {}
|
||||
MemGuard(MemGuard &&lhs) noexcept : n_(lhs.n_), alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)) {}
|
||||
MemGuard &operator=(MemGuard &&lhs) noexcept {
|
||||
if (this != &lhs) {
|
||||
this->deallocate();
|
||||
|
|
|
@ -37,7 +37,8 @@ struct MemHdr {
|
|||
ArenaImpl::ArenaImpl(void *ptr, size_t sz) : size_in_bytes_(sz), ptr_(ptr) {
|
||||
// Divide the memory into blocks. Ignore the last partial block.
|
||||
uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ;
|
||||
MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << ".";
|
||||
MS_LOG(DEBUG) << "Arena memory pool is created. Number of blocks : " << num_blks << ". Block size : " << ARENA_BLK_SZ
|
||||
<< ".";
|
||||
tr_.Insert(0, num_blks);
|
||||
}
|
||||
|
||||
|
@ -233,9 +234,9 @@ std::ostream &operator<<(std::ostream &os, const ArenaImpl &s) {
|
|||
|
||||
Status Arena::Init() {
|
||||
try {
|
||||
auto sz = size_in_MB_ * 1048576L;
|
||||
mem_ = std::make_unique<uint8_t[]>(sz);
|
||||
impl_ = std::make_unique<ArenaImpl>(mem_.get(), sz);
|
||||
int64_t sz = size_in_MB_ * 1048576L;
|
||||
RETURN_IF_NOT_OK(mem_.allocate(sz));
|
||||
impl_ = std::make_unique<ArenaImpl>(mem_.GetMutablePointer(), sz);
|
||||
} catch (std::bad_alloc &e) {
|
||||
return Status(StatusCode::kOutOfMemory);
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/util/allocator.h"
|
||||
#include "minddata/dataset/util/memory_pool.h"
|
||||
#include "minddata/dataset/util/treap.h"
|
||||
|
||||
|
@ -140,7 +141,7 @@ class Arena : public MemoryPool {
|
|||
protected:
|
||||
mutable std::mutex mux_;
|
||||
std::unique_ptr<ArenaImpl> impl_;
|
||||
std::unique_ptr<uint8_t[]> mem_;
|
||||
MemGuard<uint8_t> mem_;
|
||||
size_t size_in_MB_;
|
||||
|
||||
explicit Arena(size_t val_in_MB = 4096);
|
||||
|
|
|
@ -131,6 +131,23 @@ class BPlusTree {
|
|||
tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {}
|
||||
};
|
||||
|
||||
/// \brief Statistics functions
|
||||
/// \return Return the height of the tree
|
||||
auto GetHeight() const { return empty() ? 0 : stats_.level_ + 1; }
|
||||
/// \return Order of the B+ tree
|
||||
auto GetOrder() const { return traits::kLeafSlots; }
|
||||
/// \return Number of leaves nodes
|
||||
auto GetNumLeaves() const { return stats_.leaves_; }
|
||||
/// \return Number of inner nodes
|
||||
auto GetNumInnerNodes() const { return stats_.inner_nodes_; }
|
||||
|
||||
/// \brief Toggle locking
|
||||
/// \note Once locking is off. It is user's responsibility to ensure concurrency
|
||||
void SetLocking(bool on_off) {
|
||||
UniqueLock lck(&rw_lock_);
|
||||
acquire_lock_ = on_off;
|
||||
}
|
||||
|
||||
private:
|
||||
// Abstract class of a node (leaf or inner)
|
||||
class BaseNode {
|
||||
|
@ -288,6 +305,17 @@ class BPlusTree {
|
|||
key_compare key_less_;
|
||||
// Stat
|
||||
tree_stats stats_;
|
||||
// lock mode
|
||||
bool acquire_lock_;
|
||||
|
||||
void Init() {
|
||||
typename LeafNode::alloc_type alloc(alloc_);
|
||||
auto *p = alloc.allocate(1);
|
||||
root_ = new (p) LeafNode(alloc_);
|
||||
all_.Prepend(p);
|
||||
leaf_nodes_.Append(p);
|
||||
stats_.leaves_++;
|
||||
}
|
||||
|
||||
bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); }
|
||||
|
||||
|
@ -350,11 +378,11 @@ class BPlusTree {
|
|||
|
||||
~Iterator();
|
||||
|
||||
explicit Iterator(const Iterator &);
|
||||
Iterator(const Iterator &);
|
||||
|
||||
Iterator &operator=(const Iterator &lhs);
|
||||
|
||||
explicit Iterator(Iterator &&);
|
||||
Iterator(Iterator &&) noexcept;
|
||||
|
||||
Iterator &operator=(Iterator &&lhs);
|
||||
|
||||
|
@ -399,11 +427,11 @@ class BPlusTree {
|
|||
ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false)
|
||||
: cur_(leaf), slot_(slot), locked_(locked) {}
|
||||
|
||||
explicit ConstIterator(const ConstIterator &);
|
||||
ConstIterator(const ConstIterator &);
|
||||
|
||||
ConstIterator &operator=(const ConstIterator &lhs);
|
||||
|
||||
explicit ConstIterator(ConstIterator &&);
|
||||
ConstIterator(ConstIterator &&) noexcept;
|
||||
|
||||
ConstIterator &operator=(ConstIterator &&lhs);
|
||||
|
||||
|
|
|
@ -413,11 +413,16 @@ typename BPlusTree<K, V, A, C, T>::IndexRc BPlusTree<K, V, A, C, T>::Locate(RWLo
|
|||
}
|
||||
|
||||
template <typename K, typename V, typename A, typename C, typename T>
|
||||
BPlusTree<K, V, A, C, T>::BPlusTree() : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {}
|
||||
BPlusTree<K, V, A, C, T>::BPlusTree()
|
||||
: leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) {
|
||||
Init();
|
||||
}
|
||||
|
||||
template <typename K, typename V, typename A, typename C, typename T>
|
||||
BPlusTree<K, V, A, C, T>::BPlusTree(const Allocator<V> &alloc)
|
||||
: alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {}
|
||||
: alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) {
|
||||
Init();
|
||||
}
|
||||
|
||||
template <typename K, typename V, typename A, typename C, typename T>
|
||||
BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept {
|
||||
|
@ -446,20 +451,6 @@ BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept {
|
|||
template <typename K, typename V, typename A, typename C, typename T>
|
||||
Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<value_type> &&value) {
|
||||
IndexRc rc;
|
||||
if (root_ == nullptr) {
|
||||
UniqueLock lck(&rw_lock_);
|
||||
// Check again after we get the lock. Other thread may have created the root node already.
|
||||
if (root_ == nullptr) {
|
||||
LeafNode *leaf = nullptr;
|
||||
rc = AllocateLeaf(&leaf);
|
||||
if (rc != IndexRc::kOk) {
|
||||
return IndexRc2Status(rc);
|
||||
}
|
||||
leaf_nodes_.Append(leaf);
|
||||
root_ = leaf;
|
||||
}
|
||||
// lock will be unlocked when it goes out of scope.
|
||||
}
|
||||
bool retry = false;
|
||||
do {
|
||||
// Track all the paths to the target and lock each internal node in S.
|
||||
|
@ -468,7 +459,7 @@ Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<v
|
|||
retry = false;
|
||||
BaseNode *new_child = nullptr;
|
||||
key_type new_key = key_type();
|
||||
rc = InsertKeyValue(&InsCB, root_, key, std::move(value), &new_key, &new_child);
|
||||
rc = InsertKeyValue(acquire_lock_ ? &InsCB : nullptr, root_, key, std::move(value), &new_key, &new_child);
|
||||
if (rc == IndexRc::kRetry) {
|
||||
retry = true;
|
||||
} else if (rc != IndexRc::kOk) {
|
||||
|
@ -511,9 +502,12 @@ std::unique_ptr<V> BPlusTree<K, V, A, C, T>::DoUpdate(const key_type &key, std::
|
|||
if (root_ != nullptr) {
|
||||
LeafNode *leaf = nullptr;
|
||||
slot_type slot;
|
||||
RWLock *myLock = &this->rw_lock_;
|
||||
RWLock *myLock = nullptr;
|
||||
if (acquire_lock_) {
|
||||
myLock = &this->rw_lock_;
|
||||
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
|
||||
myLock->LockShared();
|
||||
}
|
||||
IndexRc rc = Locate(myLock, true, root_, key, &leaf, &slot);
|
||||
if (rc == IndexRc::kOk) {
|
||||
// All locks from the tree to the parent of leaf are all gone. We still have a X lock
|
||||
|
@ -521,7 +515,9 @@ std::unique_ptr<V> BPlusTree<K, V, A, C, T>::DoUpdate(const key_type &key, std::
|
|||
// Swap out the old value and replace it with new value.
|
||||
std::unique_ptr<value_type> old = std::move(leaf->data_[leaf->slot_dir_[slot]]);
|
||||
leaf->data_[leaf->slot_dir_[slot]] = std::move(new_value);
|
||||
if (acquire_lock_) {
|
||||
leaf->rw_lock_.Unlock();
|
||||
}
|
||||
return old;
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Key not found. rc = " << static_cast<int>(rc) << ".";
|
||||
|
|
|
@ -109,7 +109,7 @@ BPlusTree<K, V, A, C, T>::Iterator::Iterator(const BPlusTree<K, V, A, C, T>::Ite
|
|||
}
|
||||
|
||||
template <typename K, typename V, typename A, typename C, typename T>
|
||||
BPlusTree<K, V, A, C, T>::Iterator::Iterator(BPlusTree<K, V, A, C, T>::Iterator &&lhs) {
|
||||
BPlusTree<K, V, A, C, T>::Iterator::Iterator(BPlusTree<K, V, A, C, T>::Iterator &&lhs) noexcept {
|
||||
this->cur_ = lhs.cur_;
|
||||
this->slot_ = lhs.slot_;
|
||||
this->locked_ = lhs.locked_;
|
||||
|
@ -241,7 +241,7 @@ BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(const BPlusTree<K, V, A,
|
|||
}
|
||||
|
||||
template <typename K, typename V, typename A, typename C, typename T>
|
||||
BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(BPlusTree<K, V, A, C, T>::ConstIterator &&lhs) {
|
||||
BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(BPlusTree<K, V, A, C, T>::ConstIterator &&lhs) noexcept {
|
||||
this->cur_ = lhs.cur_;
|
||||
this->slot_ = lhs.slot_;
|
||||
this->locked_ = lhs.locked_;
|
||||
|
@ -290,9 +290,12 @@ std::pair<typename BPlusTree<K, V, A, C, T>::ConstIterator, bool> BPlusTree<K, V
|
|||
if (root_ != nullptr) {
|
||||
LeafNode *leaf = nullptr;
|
||||
slot_type slot;
|
||||
RWLock *myLock = &this->rw_lock_;
|
||||
RWLock *myLock = nullptr;
|
||||
if (acquire_lock_) {
|
||||
myLock = &this->rw_lock_;
|
||||
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
|
||||
myLock->LockShared();
|
||||
}
|
||||
IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot);
|
||||
bool find = (rc == IndexRc::kOk);
|
||||
return std::make_pair(ConstIterator(leaf, slot, find), find);
|
||||
|
@ -306,9 +309,12 @@ std::pair<typename BPlusTree<K, V, A, C, T>::Iterator, bool> BPlusTree<K, V, A,
|
|||
if (root_ != nullptr) {
|
||||
LeafNode *leaf = nullptr;
|
||||
slot_type slot;
|
||||
RWLock *myLock = &this->rw_lock_;
|
||||
RWLock *myLock = nullptr;
|
||||
if (acquire_lock_) {
|
||||
myLock = &this->rw_lock_;
|
||||
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
|
||||
myLock->LockShared();
|
||||
}
|
||||
IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot);
|
||||
bool find = (rc == IndexRc::kOk);
|
||||
return std::make_pair(Iterator(leaf, slot, find), find);
|
||||
|
|
|
@ -69,7 +69,7 @@ Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) n
|
|||
*p = addr;
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore.");
|
||||
return Status(StatusCode::kBuddySpaceFull, "BuddySpace full. Not an error. Please ignore.");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CachePool::CachePool(const value_allocator &alloc, const std::string &root)
|
||||
: alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {}
|
||||
CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root)
|
||||
: alloc_(alloc),
|
||||
root_(root),
|
||||
subfolder_(Services::GetUniqueID()),
|
||||
sm_(nullptr),
|
||||
tree_(nullptr),
|
||||
custom_arena_(ourOwnArena) {}
|
||||
|
||||
Status CachePool::DoServiceStart() {
|
||||
tree_ = std::make_shared<data_index>();
|
||||
|
@ -45,11 +50,14 @@ Status CachePool::DoServiceStop() {
|
|||
}
|
||||
}
|
||||
sm_.reset();
|
||||
// If it is our own arena, skip freeing individual pieces.
|
||||
if (!custom_arena_) {
|
||||
for (auto &bl : *tree_) {
|
||||
if (bl.ptr != nullptr) {
|
||||
alloc_.deallocate(bl.ptr, bl.sz);
|
||||
}
|
||||
}
|
||||
}
|
||||
tree_.reset();
|
||||
if (!root_.toString().empty()) {
|
||||
Path spill = GetSpillPath();
|
||||
|
@ -68,7 +76,7 @@ Status CachePool::DoServiceStop() {
|
|||
return rc2;
|
||||
}
|
||||
CachePool::~CachePool() noexcept { (void)ServiceStop(); }
|
||||
Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) {
|
||||
Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly) {
|
||||
DataLocator bl;
|
||||
Status rc;
|
||||
size_t sz = 0;
|
||||
|
@ -78,6 +86,7 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t
|
|||
}
|
||||
bl.sz = sz;
|
||||
try {
|
||||
if (!writeToDiskDirectly) {
|
||||
bl.ptr = alloc_.allocate(sz);
|
||||
// We will do a piecewise copy.
|
||||
WritableSlice dest(bl.ptr, bl.sz);
|
||||
|
@ -95,6 +104,14 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t
|
|||
bl.ptr = nullptr;
|
||||
return rc;
|
||||
}
|
||||
} else if (sm_ != nullptr) {
|
||||
MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes.";
|
||||
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
|
||||
} else {
|
||||
// If asked to spill to disk instead but there is no storage set up, simply return no memory
|
||||
// instead.
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
|
||||
}
|
||||
} catch (std::bad_alloc &e) {
|
||||
if (sm_ != nullptr) {
|
||||
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
|
||||
|
@ -102,7 +119,13 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t
|
|||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
|
||||
}
|
||||
}
|
||||
rc = tree_->insert(bl, key);
|
||||
// Insert into the B+ tree. We may still get out of memory error. So need to catch it.
|
||||
try {
|
||||
rc = tree_->DoInsert(key, bl);
|
||||
} catch (const std::bad_alloc &e) {
|
||||
rc = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
|
||||
}
|
||||
// Duplicate key is treated as error and we will also free the memory.
|
||||
if (rc.IsError() && bl.ptr != nullptr) {
|
||||
alloc_.deallocate(bl.ptr, sz);
|
||||
}
|
||||
|
@ -138,16 +161,27 @@ Path CachePool::GetSpillPath() const {
|
|||
auto spill = Path(root_) / subfolder_;
|
||||
return spill;
|
||||
}
|
||||
CachePool::CacheStat CachePool::GetStat() const {
|
||||
CacheStat cs{0};
|
||||
CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
|
||||
CacheStat cs{-1, -1, 0, 0, 0};
|
||||
int64_t total_sz = 0;
|
||||
for (auto &it : *tree_) {
|
||||
total_sz += it.sz;
|
||||
if (it.ptr != nullptr) {
|
||||
if (tree_->begin() != tree_->end()) {
|
||||
cs.min_key = tree_->begin().key();
|
||||
cs.max_key = cs.min_key; // will adjust later.
|
||||
for (auto it = tree_->begin(); it != tree_->end(); ++it) {
|
||||
total_sz += it.value().sz;
|
||||
if (it.value().ptr != nullptr) {
|
||||
++cs.num_mem_cached;
|
||||
} else {
|
||||
++cs.num_disk_cached;
|
||||
}
|
||||
auto cur_key = it.key();
|
||||
if (GetMissingKeys) {
|
||||
for (auto i = cs.max_key + 1; i < cur_key; ++i) {
|
||||
cs.gap.push_back((i));
|
||||
}
|
||||
}
|
||||
cs.max_key = cur_key;
|
||||
}
|
||||
}
|
||||
if (total_sz > 0) {
|
||||
// integer arithmetic. NO need to cast to float or double.
|
||||
|
|
|
@ -25,13 +25,13 @@
|
|||
#include "minddata/dataset/util/slice.h"
|
||||
#include "minddata/dataset/util/storage_manager.h"
|
||||
#include "minddata/dataset/util/auto_index.h"
|
||||
#include "minddata/dataset/util/btree.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of
|
||||
/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to
|
||||
/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to
|
||||
/// restore the buffer.
|
||||
/// disk (if a disk directory is provided). User must provide a key to insert the buffer.
|
||||
/// \see ReadableSlice
|
||||
class CachePool : public Service {
|
||||
public:
|
||||
|
@ -73,22 +73,25 @@ class CachePool : public Service {
|
|||
StorageManager::key_type storage_key;
|
||||
};
|
||||
|
||||
using data_index = AutoIndexObj<DataLocator>;
|
||||
using data_index = BPlusTree<int64_t, DataLocator>;
|
||||
using key_type = data_index::key_type;
|
||||
using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other;
|
||||
|
||||
/// \brief Simple statistics returned from CachePool like how many elements are cached in memory and
|
||||
/// how many elements are spilled to disk.
|
||||
struct CacheStat {
|
||||
key_type min_key;
|
||||
key_type max_key;
|
||||
int64_t num_mem_cached;
|
||||
int64_t num_disk_cached;
|
||||
int64_t average_cache_sz;
|
||||
std::vector<key_type> gap;
|
||||
};
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param alloc Allocator to allocate memory from
|
||||
/// \param root Optional disk folder to spill
|
||||
explicit CachePool(const value_allocator &alloc, const std::string &root = "");
|
||||
explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = "");
|
||||
|
||||
CachePool(const CachePool &) = delete;
|
||||
CachePool(CachePool &&) = delete;
|
||||
|
@ -103,10 +106,11 @@ class CachePool : public Service {
|
|||
|
||||
/// \brief Insert a sequence of ReadableSlice objects into the pool.
|
||||
/// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk.
|
||||
/// \param[in] key User supplied key
|
||||
/// \param[in] buf A sequence of ReadableSlice objects.
|
||||
/// \param[out] key Generated key
|
||||
/// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory
|
||||
/// \return Error code
|
||||
Status Insert(const std::vector<ReadableSlice> &buf, key_type *key);
|
||||
Status Insert(key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly);
|
||||
/// \brief Restore a cached buffer (from memory or disk)
|
||||
/// \param[in] key A previous key returned from Insert
|
||||
/// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
|
||||
|
@ -122,18 +126,23 @@ class CachePool : public Service {
|
|||
|
||||
/// \brief Get statistics.
|
||||
/// \return CacheStat object
|
||||
CacheStat GetStat() const;
|
||||
CacheStat GetStat(bool GetMissingKeys = false) const;
|
||||
|
||||
const value_allocator &get_allocator() const;
|
||||
|
||||
std::string MyName() const { return subfolder_; }
|
||||
|
||||
/// \brief Toggle locking
|
||||
/// \note Once locking is off. It is user's responsibility to ensure concurrency
|
||||
void SetLocking(bool on_off) { tree_->SetLocking(on_off); }
|
||||
|
||||
private:
|
||||
value_allocator alloc_;
|
||||
Path root_;
|
||||
const std::string subfolder_;
|
||||
std::shared_ptr<StorageManager> sm_;
|
||||
std::shared_ptr<data_index> tree_;
|
||||
bool custom_arena_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -133,12 +133,13 @@ void CircularPool::Deallocate(void *p) {
|
|||
// Lock in the chain in shared mode and find out which
|
||||
// segment it comes from
|
||||
SharedLock lock(&rw_lock_);
|
||||
auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr<Arena> &b) -> bool {
|
||||
auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool {
|
||||
char *q = reinterpret_cast<char *>(p);
|
||||
char *base = const_cast<char *>(reinterpret_cast<const char *>(b->get_base_addr()));
|
||||
return (q > base && q < base + b->get_max_size());
|
||||
auto *base = reinterpret_cast<const char *>(b->get_base_addr());
|
||||
return (q > base && q < base + arena_size_ * 1048576L);
|
||||
});
|
||||
lock.Unlock();
|
||||
MS_ASSERT(it != mem_segments_.end());
|
||||
it->get()->Deallocate(p);
|
||||
}
|
||||
|
||||
|
@ -150,10 +151,10 @@ Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) {
|
|||
}
|
||||
void *p = *pp;
|
||||
SharedLock lock(&rw_lock_);
|
||||
auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr<Arena> &b) -> bool {
|
||||
auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool {
|
||||
char *q = reinterpret_cast<char *>(p);
|
||||
char *base = const_cast<char *>(reinterpret_cast<const char *>(b->get_base_addr()));
|
||||
return (q > base && q < base + b->get_max_size());
|
||||
auto *base = reinterpret_cast<const char *>(b->get_base_addr());
|
||||
return (q > base && q < base + arena_size_ * 1048576L);
|
||||
});
|
||||
lock.Unlock();
|
||||
MS_ASSERT(it != mem_segments_.end());
|
||||
|
|
|
@ -16,11 +16,14 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <deque>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include "minddata/dataset/util/allocator.h"
|
||||
#include "minddata/dataset/util/system_pool.h"
|
||||
#include "minddata/dataset/util/semaphore.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
namespace mindspore {
|
||||
|
@ -37,7 +40,7 @@ class QueueMap {
|
|||
using key_type = K;
|
||||
using value_type = T;
|
||||
|
||||
QueueMap() = default;
|
||||
QueueMap() : num_rows_(0) {}
|
||||
virtual ~QueueMap() = default;
|
||||
|
||||
/// Add an element <key, T> to the map and wake up any consumer that is waiting
|
||||
|
@ -48,6 +51,7 @@ class QueueMap {
|
|||
RequestQueue *rq = nullptr;
|
||||
RETURN_IF_NOT_OK(GetRq(key, &rq));
|
||||
RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload)));
|
||||
++num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -56,9 +60,35 @@ class QueueMap {
|
|||
RequestQueue *rq = nullptr;
|
||||
RETURN_IF_NOT_OK(GetRq(key, &rq));
|
||||
RETURN_IF_NOT_OK(rq->Wait(out));
|
||||
--num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Get the number of elements in the container
|
||||
/// \return The number of elements in the container
|
||||
int64_t size() const { return num_rows_; }
|
||||
|
||||
/// \return if the container is empty
|
||||
bool empty() const { return num_rows_ == 0; }
|
||||
|
||||
/// Print out some useful information about the container
|
||||
friend std::ostream &operator<<(std::ostream &out, const QueueMap &qm) {
|
||||
std::unique_lock<std::mutex> lck(qm.mux_);
|
||||
out << "Number of elements: " << qm.num_rows_ << "\n";
|
||||
out << "Dumping internal info:\n";
|
||||
int64_t k = 0;
|
||||
for (auto &it : qm.all_) {
|
||||
auto key = it.first;
|
||||
const RequestQueue *rq = it.second.GetPointer();
|
||||
out << "(k:" << key << "," << *rq << ") ";
|
||||
++k;
|
||||
if (k % 6 == 0) {
|
||||
out << "\n";
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
protected:
|
||||
/// This is a handshake structure between producer and consumer
|
||||
class RequestQueue {
|
||||
|
@ -86,8 +116,13 @@ class QueueMap {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &out, const RequestQueue &rq) {
|
||||
out << "sz:" << rq.row_.size() << ",uc:" << rq.use_count_.Peek();
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex dq_mux_;
|
||||
mutable std::mutex dq_mux_;
|
||||
Semaphore use_count_;
|
||||
std::deque<T> row_;
|
||||
};
|
||||
|
@ -104,7 +139,7 @@ class QueueMap {
|
|||
*out = it->second.GetMutablePointer();
|
||||
} else {
|
||||
// We will create a new one.
|
||||
auto alloc = Services::GetAllocator<RequestQueue>();
|
||||
auto alloc = SystemPool::GetAllocator<RequestQueue>();
|
||||
auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc));
|
||||
if (r.second) {
|
||||
auto &mem = r.first->second;
|
||||
|
@ -118,8 +153,9 @@ class QueueMap {
|
|||
}
|
||||
|
||||
private:
|
||||
std::mutex mux_;
|
||||
mutable std::mutex mux_;
|
||||
std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_;
|
||||
std::atomic<int64_t> num_rows_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,10 +29,7 @@ void Semaphore::V() {
|
|||
++value_;
|
||||
wait_cond_.NotifyOne();
|
||||
}
|
||||
int Semaphore::Peek() {
|
||||
std::unique_lock<std::mutex> lck(mutex_);
|
||||
return value_;
|
||||
}
|
||||
int Semaphore::Peek() const { return value_; }
|
||||
Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
|
||||
Status Semaphore::Deregister() { return (wait_cond_.Deregister()); }
|
||||
void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); }
|
||||
|
|
|
@ -38,7 +38,7 @@ class Semaphore {
|
|||
void V();
|
||||
/// \brief Peek the internal value
|
||||
/// \return The internal value
|
||||
int Peek();
|
||||
int Peek() const;
|
||||
Status Register(TaskGroup *vg);
|
||||
Status Deregister();
|
||||
void ResetIntrpState();
|
||||
|
|
|
@ -51,6 +51,12 @@ std::string CodeAsString(const StatusCode c) {
|
|||
case StatusCode::kSyntaxError:
|
||||
s = "Syntax error";
|
||||
break;
|
||||
case StatusCode::kBuddySpaceFull:
|
||||
s = "BuddySpace full";
|
||||
break;
|
||||
case StatusCode::kNetWorkError:
|
||||
s = "Network error";
|
||||
break;
|
||||
case StatusCode::kUnexpectedError:
|
||||
default:
|
||||
s = "Unexpected error";
|
||||
|
|
|
@ -82,6 +82,8 @@ enum class StatusCode : char {
|
|||
kBoundingBoxInvalidShape = 12,
|
||||
kSyntaxError = 13,
|
||||
kTimeOut = 14,
|
||||
kBuddySpaceFull = 14,
|
||||
kNetWorkError = 15,
|
||||
// Make this error code the last one. Add new error code above it.
|
||||
kUnexpectedError = 127
|
||||
};
|
||||
|
@ -137,6 +139,8 @@ class Status {
|
|||
|
||||
bool IsNoSpace() const { return (get_code() == StatusCode::kNoSpace); }
|
||||
|
||||
bool IsNetWorkError() const { return (get_code() == StatusCode::kNetWorkError); }
|
||||
|
||||
private:
|
||||
StatusCode code_;
|
||||
std::string err_msg_;
|
||||
|
|
|
@ -99,8 +99,12 @@ Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const
|
|||
#endif
|
||||
if (r_sz != sz) {
|
||||
errno_t err = (r_sz == 0) ? EOF : errno;
|
||||
if (errno == ENOSPC) {
|
||||
return Status(StatusCode::kNoSpace, __LINE__, __FILE__);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED(strerror(err));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -71,10 +71,11 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu
|
|||
key_type out_key;
|
||||
value_type out_value;
|
||||
bool create_new_container = false;
|
||||
size_t last_num_container = -1;
|
||||
do {
|
||||
SharedLock lock_s(&rw_lock_);
|
||||
size_t num_containers = containers_.size();
|
||||
if (create_new_container) {
|
||||
if (create_new_container && (num_containers == last_num_container)) {
|
||||
// Upgrade to exclusvie lock.
|
||||
lock_s.Upgrade();
|
||||
create_new_container = false;
|
||||
|
@ -95,8 +96,11 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu
|
|||
cont = containers_.at(num_containers - 1);
|
||||
off64_t offset;
|
||||
Status rc = cont->Insert(buf, &offset);
|
||||
if (rc.IsNoSpace()) {
|
||||
if (rc.get_code() == StatusCode::kBuddySpaceFull) {
|
||||
create_new_container = true;
|
||||
// Remember how many containers we saw. In the next iteration we will do a comparision to see
|
||||
// if someone has already created it.
|
||||
last_num_container = num_containers;
|
||||
} else if (rc.IsOk()) {
|
||||
out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz));
|
||||
RETURN_IF_NOT_OK(index_.insert(out_value, &out_key));
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""Cache client
|
||||
"""
|
||||
|
||||
import os
|
||||
import copy
|
||||
|
||||
from ..core.validator_helpers import type_check, check_uint32, check_uint64
|
||||
|
@ -25,11 +26,11 @@ class DatasetCache:
|
|||
A client to interface with tensor caching service
|
||||
"""
|
||||
|
||||
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20):
|
||||
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None,
|
||||
prefetch_size=None):
|
||||
check_uint32(session_id, "session_id")
|
||||
check_uint64(size, "size")
|
||||
type_check(spilling, (bool,), "spilling")
|
||||
check_uint32(prefetch_size, "prefetch size")
|
||||
|
||||
self.session_id = session_id
|
||||
self.size = size
|
||||
|
@ -37,8 +38,13 @@ class DatasetCache:
|
|||
self.hostname = hostname
|
||||
self.port = port
|
||||
self.prefetch_size = prefetch_size
|
||||
self.num_connections = num_connections
|
||||
if os.getenv('MS_ENABLE_CACHE') != 'TRUE':
|
||||
# temporary disable cache feature in the current release
|
||||
self.cache_client = None
|
||||
else:
|
||||
from mindspore._c_dataengine import CacheClient
|
||||
self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size)
|
||||
|
||||
def GetStat(self):
|
||||
return self.cache_client.GetStat()
|
||||
|
@ -55,5 +61,6 @@ class DatasetCache:
|
|||
new_cache.hostname = copy.deepcopy(self.hostname, memodict)
|
||||
new_cache.port = copy.deepcopy(self.port, memodict)
|
||||
new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
|
||||
new_cache.num_connections = copy.deepcopy(self.num_connections, memodict)
|
||||
new_cache.cache_client = self.cache_client
|
||||
return new_cache
|
||||
|
|
|
@ -1234,5 +1234,8 @@ def check_paddeddataset(method):
|
|||
def check_cache_option(cache):
|
||||
"""Sanity check for cache parameter"""
|
||||
if cache is not None:
|
||||
if os.getenv('MS_ENABLE_CACHE') != 'TRUE':
|
||||
# temporary disable cache feature in the current release
|
||||
raise ValueError("Caching is disabled in the current release")
|
||||
from . import cache_client
|
||||
type_check(cache, (cache_client.DatasetCache,), "cache")
|
||||
|
|
19
setup.py
19
setup.py
|
@ -156,6 +156,24 @@ def update_permissions(path):
|
|||
if filename == "ms_serving":
|
||||
os.chmod(file_fullpath, stat.S_IREAD | stat.S_IEXEC)
|
||||
|
||||
def bin_files():
|
||||
"""
|
||||
Gets the binary files to be installed.
|
||||
"""
|
||||
data_files = []
|
||||
binary_files = []
|
||||
|
||||
cache_server_bin = os.path.join('mindspore', 'bin', 'cache_server')
|
||||
if not os.path.exists(cache_server_bin):
|
||||
return data_files
|
||||
binary_files.append(cache_server_bin)
|
||||
cache_admin_bin = os.path.join('mindspore', 'bin', 'cache_admin')
|
||||
if not os.path.exists(cache_admin_bin):
|
||||
return data_files
|
||||
binary_files.append(cache_admin_bin)
|
||||
data_files.append(('bin', binary_files))
|
||||
return data_files
|
||||
|
||||
|
||||
class EggInfo(egg_info):
|
||||
"""Egg info."""
|
||||
|
@ -192,6 +210,7 @@ setup(
|
|||
'framework that could be used for mobile, edge and cloud scenarios.',
|
||||
long_description="\n\n".join([readme, release]),
|
||||
long_description_content_type="text/markdown",
|
||||
data_files=bin_files(),
|
||||
packages=find_packages(),
|
||||
package_data=package_data,
|
||||
include_package_data=True,
|
||||
|
|
|
@ -24,9 +24,9 @@
|
|||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
// For testing purposes, we will make the branching factor very low.
|
||||
struct mytraits {
|
||||
|
@ -35,7 +35,6 @@ struct mytraits {
|
|||
static const slot_type kInnerSlots = 3;
|
||||
};
|
||||
|
||||
|
||||
class MindDataTestBPlusTree : public UT::Common {
|
||||
public:
|
||||
MindDataTestBPlusTree() = default;
|
||||
|
@ -44,7 +43,7 @@ class MindDataTestBPlusTree : public UT::Common {
|
|||
// Test serial insert.
|
||||
TEST_F(MindDataTestBPlusTree, Test1) {
|
||||
Allocator<std::string> alloc(std::make_shared<SystemPool>());
|
||||
BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<uint64_t>, mytraits> btree(alloc);
|
||||
BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<>, mytraits> btree(alloc);
|
||||
Status rc;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
uint64_t key = 2 * i;
|
||||
|
@ -109,7 +108,7 @@ TEST_F(MindDataTestBPlusTree, Test1) {
|
|||
// Test concurrent insert.
|
||||
TEST_F(MindDataTestBPlusTree, Test2) {
|
||||
Allocator<std::string> alloc(std::make_shared<SystemPool>());
|
||||
BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<uint64_t>, mytraits> btree(alloc);
|
||||
BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<>, mytraits> btree(alloc);
|
||||
TaskGroup vg;
|
||||
auto f = [&](int k) -> Status {
|
||||
TaskManager::FindMe()->Post();
|
||||
|
@ -125,7 +124,8 @@ TEST_F(MindDataTestBPlusTree, Test2) {
|
|||
auto g = [&](int k) -> Status {
|
||||
TaskManager::FindMe()->Post();
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
uint64_t key = rand() % 10000;;
|
||||
uint64_t key = rand() % 10000;
|
||||
;
|
||||
auto it = btree.Search(key);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -226,3 +226,22 @@ TEST_F(MindDataTestBPlusTree, Test4) {
|
|||
EXPECT_EQ(cnt, 1000);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestBPlusTree, TestPerfNoLocking) {
|
||||
AutoIndexObj<int64_t> btree;
|
||||
// No locking test
|
||||
btree.SetLocking(false);
|
||||
// Insert a million entries using the default traits.
|
||||
for (auto i = 0; i < 1000000; ++i) {
|
||||
ASSERT_TRUE(btree.insert(i));
|
||||
}
|
||||
std::cout << "Tree height : " << btree.GetHeight() << std::endl;
|
||||
std::cout << "Tree Order : " << btree.GetOrder() << std::endl;
|
||||
std::cout << "Number of leaves : " << btree.GetNumLeaves() << std::endl;
|
||||
std::cout << "Number of inner nodes : " << btree.GetNumInnerNodes() << std::endl;
|
||||
|
||||
auto r = btree.Search(3);
|
||||
EXPECT_TRUE(r.second);
|
||||
r = btree.Search(999999);
|
||||
EXPECT_TRUE(r.second);
|
||||
}
|
||||
|
|
|
@ -35,6 +35,23 @@ using mindspore::dataset::TaskGroup;
|
|||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
// Helper function to get the session id from SESSION_ID env variable
|
||||
Status GetSessionFromEnv(session_id_type *session_id) {
|
||||
RETURN_UNEXPECTED_IF_NULL(session_id);
|
||||
if (const char *session_env = std::getenv("SESSION_ID")) {
|
||||
std::string session_id_str(session_env);
|
||||
try {
|
||||
*session_id = std::stoul(session_id_str);
|
||||
} catch (const std::exception &e) {
|
||||
std::string err_msg = "Invalid numeric value for session id in env var: " + session_id_str;
|
||||
return Status(StatusCode::kSyntaxError, err_msg);
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Test case requires a session id to be provided via SESSION_ID environment variable.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class MindDataTestCacheOp : public UT::DatasetOpTesting {
|
||||
public:
|
||||
void SetUp() override {
|
||||
|
@ -46,8 +63,12 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting {
|
|||
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
|
||||
Status rc;
|
||||
CacheClient::Builder builder;
|
||||
session_id_type env_session;
|
||||
rc = GetSessionFromEnv(&env_session);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// use arbitrary session of 1, size of 0, spilling// is true
|
||||
builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
|
||||
builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
|
||||
std::shared_ptr<CacheClient> myClient;
|
||||
rc = builder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
@ -118,9 +139,6 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
|
|||
cmp = (map_out == map);
|
||||
ASSERT_TRUE(cmp);
|
||||
|
||||
// Test Purge and Destroy
|
||||
rc = myClient->PurgeCache();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = myClient->DestroyCache();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
}
|
||||
|
@ -130,10 +148,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
|
|||
(void)TaskManager::GetMasterThreadRc();
|
||||
TaskGroup vg;
|
||||
Status rc;
|
||||
|
||||
session_id_type env_session;
|
||||
rc = GetSessionFromEnv(&env_session);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// use arbitrary session of 1, size 1, spilling is true
|
||||
CacheClient::Builder builder;
|
||||
// use arbitrary session of 1, size of 0, spilling// is true
|
||||
builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true);
|
||||
builder.SetSessionId(env_session).SetCacheMemSz(1).SetSpill(true);
|
||||
std::shared_ptr<CacheClient> myClient;
|
||||
rc = builder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
@ -199,8 +222,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
|
|||
// RandomDataOp
|
||||
//
|
||||
TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
|
||||
// Clear the rc of the master thread if any
|
||||
(void)TaskManager::GetMasterThreadRc();
|
||||
Status rc;
|
||||
int32_t rank = 0; // not used
|
||||
|
||||
session_id_type env_session;
|
||||
rc = GetSessionFromEnv(&env_session);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "UT test TestRandomDataCache1";
|
||||
// Start with an empty execution tree
|
||||
auto myTree = std::make_shared<ExecutionTree>();
|
||||
|
@ -236,8 +266,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
|
|||
// CacheOp
|
||||
// size of 0, spilling is true
|
||||
CacheClient::Builder builder;
|
||||
// use arbitrary session of 1, size of 0, spilling// is true
|
||||
builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
|
||||
builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
|
||||
std::shared_ptr<CacheClient> myClient;
|
||||
rc = builder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
@ -273,7 +302,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration";
|
||||
rc = myTree->Prepare();
|
||||
rc = myTree->Prepare(1);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// quick check to see what tree looks like
|
||||
|
@ -314,9 +343,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
|
|||
//// RandomDataOp
|
||||
////
|
||||
TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
|
||||
// Clear the rc of the master thread if any
|
||||
(void)TaskManager::GetMasterThreadRc();
|
||||
Status rc;
|
||||
int32_t rank = 0; // not used
|
||||
MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
|
||||
|
||||
session_id_type env_session;
|
||||
rc = GetSessionFromEnv(&env_session);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// Start with an empty execution tree
|
||||
auto myTree = std::make_shared<ExecutionTree>();
|
||||
|
||||
|
@ -353,8 +389,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
|
|||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
CacheClient::Builder builder;
|
||||
// use arbitrary session of 1, size of 0, spilling// is true
|
||||
builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true);
|
||||
builder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
|
||||
std::shared_ptr<CacheClient> myClient;
|
||||
rc = builder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
@ -386,7 +421,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration";
|
||||
rc = myTree->Prepare();
|
||||
rc = myTree->Prepare(1);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
std::cout << *myClient << std::endl;
|
||||
|
@ -413,14 +448,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
|
||||
// Clear the rc of the master thread if any
|
||||
(void)TaskManager::GetMasterThreadRc();
|
||||
Status rc;
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
|
||||
session_id_type env_session;
|
||||
rc = GetSessionFromEnv(&env_session);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
|
||||
CacheClient::Builder ccbuilder;
|
||||
// use arbitrary session of 1, size of 0, spilling// is true
|
||||
ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
|
||||
ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
|
||||
std::shared_ptr<CacheClient> myClient;
|
||||
rc = ccbuilder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
@ -468,7 +509,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
|
|||
rc = myCacheOp->AddChild(so);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = myTree->Prepare();
|
||||
rc = myTree->Prepare(1);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = myTree->Launch();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
@ -507,10 +548,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
|
|||
//// RandomDataOp
|
||||
////
|
||||
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
|
||||
// Clear the rc of the master thread if any
|
||||
(void)TaskManager::GetMasterThreadRc();
|
||||
Status rc;
|
||||
int32_t rank = 0; // not used
|
||||
MS_LOG(INFO) << "UT test TestCacheInheritSampler";
|
||||
|
||||
session_id_type env_session;
|
||||
rc = GetSessionFromEnv(&env_session);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
|
@ -550,7 +597,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
|
|||
// CacheOp
|
||||
CacheClient::Builder ccbuilder;
|
||||
// use arbitrary session of 1, size of 0, spilling// is true
|
||||
ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true);
|
||||
ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
|
||||
std::shared_ptr<CacheClient> myClient;
|
||||
rc = ccbuilder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
@ -577,7 +624,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration";
|
||||
rc = myTree->Prepare();
|
||||
rc = myTree->Prepare(1);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
std::cout << *myClient << std::endl;
|
||||
|
|
|
@ -62,9 +62,8 @@ TEST_F(MindDataTestMemoryPool, TestAllocator) {
|
|||
class A {
|
||||
public:
|
||||
explicit A(int x) : a(x) {}
|
||||
int val_a() const {
|
||||
return a;
|
||||
}
|
||||
int val_a() const { return a; }
|
||||
|
||||
private:
|
||||
int a;
|
||||
};
|
||||
|
@ -74,3 +73,16 @@ TEST_F(MindDataTestMemoryPool, TestAllocator) {
|
|||
ASSERT_EQ(v, 3);
|
||||
MS_LOG(DEBUG) << *(std::dynamic_pointer_cast<CircularPool>(mp_)) << std::endl;
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestMemoryPool, TestMemGuard) {
|
||||
MemGuard<uint8_t> mem;
|
||||
// Try some large value.
|
||||
int64_t sz = 5LL * 1024LL * 1024LL * 1024LL;
|
||||
Status rc = mem.allocate(sz);
|
||||
ASSERT_TRUE(rc.IsOk() || rc.IsOutofMemory());
|
||||
if (rc.IsOk()) {
|
||||
// Try write a character half way.
|
||||
auto *p = mem.GetMutablePointer();
|
||||
p[sz / 2] = 'a';
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# This script is the driver of the individual test scenarios
|
||||
CURRPATH=$(cd $(dirname $0); pwd)
|
||||
|
||||
echo "----------------------------------------------"
|
||||
echo "Invalid syntax and cache_admin failure testing"
|
||||
echo "----------------------------------------------"
|
||||
echo
|
||||
${CURRPATH}/cachetest_args.sh
|
||||
num_failures=$?
|
||||
echo
|
||||
echo "Invalid syntax and cache_admin failure testing complete. Number of failures: $num_failures"
|
||||
echo
|
||||
|
||||
echo "----------------------------------------------"
|
||||
echo "Test pipelines with cache (python)"
|
||||
echo "----------------------------------------------"
|
||||
echo
|
||||
${CURRPATH}/cachetest_py.sh
|
||||
num_failures=$?
|
||||
echo
|
||||
echo "Test pipelines with cache complete. Number of failures: $num_failures"
|
||||
echo
|
||||
|
||||
echo "----------------------------------------------"
|
||||
echo "Cache cpp tests"
|
||||
echo "----------------------------------------------"
|
||||
echo
|
||||
${CURRPATH}/cachetest_cpp.sh
|
||||
num_failures=$?
|
||||
echo
|
||||
echo "Cache cpp tests complete. Number of failures: $num_failures"
|
||||
echo
|
|
@ -0,0 +1,207 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# source the globals and functions for use with cache testing
|
||||
SKIP_ADMIN_COUNTER=false
|
||||
. cachetest_lib.sh
|
||||
echo
|
||||
|
||||
################################################################################
|
||||
# Cache testing: cache_admin argument testing #
|
||||
# Summary: Various tests that expect to get failure messages returned #
|
||||
################################################################################
|
||||
|
||||
# Double-command test
|
||||
cmd="${CACHE_ADMIN} --start --stop"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# missing command test
|
||||
cmd="${CACHE_ADMIN} --port 50082"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# bad arg test
|
||||
cmd="${CACHE_ADMIN} -p abc --start"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# missing arg test
|
||||
cmd="${CACHE_ADMIN} -p --start"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# invalid command
|
||||
cmd="${CACHE_ADMIN} -p 50082 --start --not_exist_cmd"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# spill directory does not exist
|
||||
cmd="${CACHE_ADMIN} --start --spilldir /path_that_does_not_exist"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# start cache server twice
|
||||
StartServer
|
||||
HandleRcExit $? 1 1
|
||||
# start the cache server again, however, this time we expect an error
|
||||
cmd="${CACHE_ADMIN} --start"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
StopServer
|
||||
HandleRcExit $? 1 1
|
||||
|
||||
# start cache server twice with different ports
|
||||
# this one starts with the default port 50052
|
||||
StartServer
|
||||
HandleRcExit $? 1 1
|
||||
# this one starts with port 50053
|
||||
cmd="${CACHE_ADMIN} --start -p 50053"
|
||||
CacheAdminCmd "${cmd}" 0
|
||||
HandleRcExit $? 1 1
|
||||
# stop the cache server with default port
|
||||
StopServer
|
||||
HandleRcExit $? 1 1
|
||||
# stop the cache server with port 50053
|
||||
cmd="${CACHE_ADMIN} --stop -p 50053"
|
||||
CacheAdminCmd "${cmd}" 0
|
||||
HandleRcExit $? 1 1
|
||||
|
||||
# stop the cache server without bringing it up
|
||||
cmd="${CACHE_ADMIN} --stop"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
|
||||
# start the cache server with illegal hostname
|
||||
cmd="${CACHE_ADMIN} --start -h 0.0.0.0"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
cmd="${CACHE_ADMIN} --start -h illegal"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
cmd="${CACHE_ADMIN} --start -h"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
cmd="${CACHE_ADMIN} --start -h --hostname"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
cmd="${CACHE_ADMIN} --start -h --hostname 127.0.0.1"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
|
||||
# start the cache server with illegal port
|
||||
cmd="${CACHE_ADMIN} --start -p 0"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
cmd="${CACHE_ADMIN} --start -p -1"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
cmd="${CACHE_ADMIN} --start -p 65536"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
cmd="${CACHE_ADMIN} --start -p illegal"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
cmd="${CACHE_ADMIN} --start -p"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
|
||||
# find a port that is occupied using netstat
|
||||
if [ -x "$(command -v netstat)" ]; then
|
||||
port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1)
|
||||
if [ ${port} -gt 1025 ]; then
|
||||
# start cache server with occupied port
|
||||
cmd="${CACHE_ADMIN} --start -p ${port}"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# generate session before starting the cache server
|
||||
cmd="${CACHE_ADMIN} -g"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# illegal generate session command
|
||||
StartServer
|
||||
HandleRcExit $? 1 1
|
||||
cmd="${CACHE_ADMIN} -g 1"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# illegal destroy session command
|
||||
cmd="${CACHE_ADMIN} -d -2"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} -d illegal"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} -d"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
# destroy a non-existing session
|
||||
cmd="${CACHE_ADMIN} -d 99999"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# stop cache server at this point
|
||||
StopServer
|
||||
HandleRcExit $? 1 1
|
||||
|
||||
# illegal number of workers
|
||||
cmd="${CACHE_ADMIN} --start -w 0"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} --start -w -1"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} --start -w illegal"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} --start -w 101"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} --start -w 9999999"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} --start -w"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# illegal spill path
|
||||
cmd="${CACHE_ADMIN} --start -s"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# spill path without writing perm
|
||||
if [ "$EUID" -ne 0 ]; then
|
||||
cmd="${CACHE_ADMIN} --start -s /"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
fi
|
||||
|
||||
# illegal log level
|
||||
cmd="${CACHE_ADMIN} --start -l 4"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} --start -l -1"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
cmd="${CACHE_ADMIN} --start -l"
|
||||
CacheAdminCmd "${cmd}" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
exit ${failed_tests}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue