!5930 Cache server phase 2 single node

Merge pull request !5930 from Jamie/CacheOp_dev
This commit is contained in:
mindspore-ci-bot 2020-09-25 03:54:55 +08:00 committed by Gitee
commit e88e114a50
106 changed files with 5764 additions and 969 deletions

View File

@ -36,6 +36,7 @@ include(CPack)
set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries")
set(INSTALL_PY_DIR ".")
set(INSTALL_BASE_DIR ".")
set(INSTALL_BIN_DIR "bin")
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
set(INSTALL_LIB_DIR ".")
@ -78,7 +79,14 @@ if (ENABLE_MINDDATA)
DESTINATION ${INSTALL_BASE_DIR}
COMPONENT mindspore
)
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
install(
TARGETS cache_admin cache_server
OPTIONAL
DESTINATION ${INSTALL_BIN_DIR}
COMPONENT mindspore
)
endif()
file(GLOB_RECURSE OPENCV_LIB_LIST
${opencv_LIBPATH}/libopencv_core*
${opencv_LIBPATH}/libopencv_imgcodecs*

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <optional>
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/cache/cache_client.h"
@ -22,17 +23,19 @@ namespace dataset {
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(
py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
std::optional<int32_t> port, int32_t prefetch_sz) {
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz);
if (hostname) builder.SetHostname(hostname.value());
if (port) builder.SetPort(port.value());
THROW_IF_ERROR(builder.Build(&cc));
return cc;
}))
.def(py::init([](session_id_type id, uint64_t mem_sz, bool spill,
std::optional<std::string> hostname, std::optional<int32_t> port,
std::optional<int32_t> num_connections, std::optional<int32_t> prefetch_sz) {
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill);
if (hostname) builder.SetHostname(hostname.value());
if (port) builder.SetPort(port.value());
if (num_connections) builder.SetNumConnections(num_connections.value());
if (prefetch_sz) builder.SetPrefetchSize(prefetch_sz.value());
THROW_IF_ERROR(builder.Build(&cc));
return cc;
}))
.def("GetStat", [](CacheClient &cc) {
CacheServiceStat stat{};
THROW_IF_ERROR(cc.GetStat(&stat));

View File

@ -18,6 +18,7 @@
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include "mindspore/core/utils/log_adapter.h"
#include "minddata/dataset/util/system_pool.h"
@ -33,7 +34,9 @@ ConfigManager::ConfigManager()
monitor_sampling_interval_(kCfgMonitorSamplingInterval),
callback_timout_(kCfgCallbackTimeout),
cache_host_(kCfgDefaultCacheHost),
cache_port_(kCfgDefaultCachePort) {
cache_port_(kCfgDefaultCachePort),
num_connections_(kDftNumConnections),
prefetch_size_(kDftPrefetchSize) {
auto env_cache_host = std::getenv("MS_CACHE_HOST");
auto env_cache_port = std::getenv("MS_CACHE_PORT");
if (env_cache_host != nullptr) {
@ -71,6 +74,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
set_cache_host(j.value("cacheHost", cache_host_));
set_cache_port(j.value("cachePort", cache_port_));
set_num_connections(j.value("numConnections", num_connections_));
set_prefetch_size(j.value("prefetchSize", prefetch_size_));
return Status::OK();
}
@ -120,8 +125,12 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s
void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; }
void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; }
void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = std::move(cache_host); }
void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; }
void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; }
void ConfigManager::set_prefetch_size(int32_t prefetch_size) { prefetch_size_ = prefetch_size; }
} // namespace dataset
} // namespace mindspore

View File

@ -97,6 +97,14 @@ class ConfigManager {
// @return The port of cache server
int32_t cache_port() const { return cache_port_; }
/// getter function
/// \return Number of tcp/ip connection
int32_t num_connections() const { return num_connections_; }
/// getter function
/// \return Prefetch size
int32_t prefetch_size() const { return prefetch_size_; }
// setter function
// @param rows_per_buffer - The setting to apply to the config
void set_rows_per_buffer(int32_t rows_per_buffer);
@ -121,6 +129,14 @@ class ConfigManager {
// @param cache_port - The port of cache server
void set_cache_port(int32_t cache_port);
/// setter function
/// \param num_connections
void set_num_connections(int32_t num_connections);
/// setter function
/// \param prefetch_size
void set_prefetch_size(int32_t prefetch_size);
uint32_t seed() const;
// setter function
@ -153,6 +169,8 @@ class ConfigManager {
uint32_t callback_timout_;
std::string cache_host_;
int32_t cache_port_;
int32_t num_connections_;
int32_t prefetch_size_;
// Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info

View File

@ -71,6 +71,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
constexpr int32_t kCfgDefaultCachePort = 50052;
constexpr char kCfgDefaultCacheHost[] = "127.0.0.1";
constexpr int32_t kDftPrefetchSize = 20;
constexpr int32_t kDftNumConnections = 12;
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;

View File

@ -79,6 +79,14 @@ class TensorRow {
const vector_type &getRow() const { return row_; }
int64_t SizeInBytes() const {
size_t sz = 0;
for (auto &it : row_) {
sz += it->SizeInBytes();
}
return sz;
}
// Wrapper functions to support vector operations
void emplace_back(value_type t) { row_.emplace_back(t); }

View File

@ -12,7 +12,9 @@ add_library(engine-cache-client OBJECT
if (ENABLE_CACHE)
ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS}
cache_grpc_client.cc
cache_ipc.cc)
add_library(engine-cache-server OBJECT
${CACHE_GRPC_SRCS}

View File

@ -37,12 +37,17 @@ int main(int argc, char **argv) {
warningMsg += "WARNING:\n";
warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
warningMsg += " purposes at this time.\n";
warningMsg += "This command is currently disabled. Quitting.\n";
auto env_enable_cache = std::getenv("MS_ENABLE_CACHE");
if (env_enable_cache == nullptr || strcmp(env_enable_cache, "TRUE") != 0) {
// temporary disable cache feature in the current release
warningMsg += "This command is currently disabled. Quitting.\n";
std::cerr << warningMsg << std::endl;
return 0;
}
warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n";
// A warning message until the code is mature enough.
std::cerr << warningMsg << std::endl;
// temporary disable cache feature in the current release
return 0;
if (argc == 1) {
args.Help();

View File

@ -19,9 +19,11 @@
#include <sys/wait.h>
#include <unistd.h>
#include <cerrno>
#include <iomanip>
#include <iostream>
#include <string>
#include <cstdlib>
#include <vector>
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/util/path.h"
@ -39,6 +41,7 @@ CacheAdminArgHandler::CacheAdminArgHandler()
num_workers_(kDefaultNumWorkers),
shm_mem_sz_(kDefaultSharedMemorySizeInGB),
log_level_(kDefaultLogLevel),
memory_cap_ratio_(kMemoryCapRatio),
hostname_(kCfgDefaultCacheHost),
spill_dir_(kDefaultSpillDir),
command_id_(CommandId::kCmdUnknown) {
@ -62,6 +65,9 @@ CacheAdminArgHandler::CacheAdminArgHandler()
arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize;
arg_map_["-l"] = ArgValue::kArgLogLevel;
arg_map_["--minloglevel"] = ArgValue::kArgLogLevel;
arg_map_["-r"] = ArgValue::kArgMemoryCapRatio;
arg_map_["--memory_cap_ratio"] = ArgValue::kArgMemoryCapRatio;
arg_map_["--list_sessions"] = ArgValue::kArgListSessions;
// Initialize argument tracker with false values
for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) {
ArgValue currAV = static_cast<ArgValue>(i);
@ -69,6 +75,8 @@ CacheAdminArgHandler::CacheAdminArgHandler()
}
}
CacheAdminArgHandler::~CacheAdminArgHandler() = default;
Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id) {
// Detect if the user tried to provide this argument more than once
@ -102,7 +110,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std
return Status(StatusCode::kSyntaxError, err_msg);
}
// Now, attempt to convert the value into it's string format for output
// Now, attempt to convert the value into it's numeric format for output
try {
*out_arg = std::stoul(value_as_string);
} catch (const std::exception &e) {
@ -140,7 +148,13 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
// If there is no argument to get, such as the --start command, then out_arg will be a nullptr.
if (out_arg != nullptr) {
// Fetch the argument from the arg stream into a string
*arg_stream >> *out_arg;
if (arg_stream->rdbuf()->in_avail() != 0) {
*arg_stream >> *out_arg;
} else {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
}
if (out_arg->empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
@ -150,12 +164,62 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg,
return Status::OK();
}
Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream,
CommandId command_id) {
// Detect if the user tried to provide this argument more than once
ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg);
}
// Flag that this arg is used now
used_args_[selected_arg] = true;
// Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument.
// Other options are actual commands, for example "--start".
// If this option is also a command, make sure there has not been multiple commands given before assigning it.
if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg);
} else {
command_id_ = command_id;
}
}
std::string value_as_string;
// Fetch the argument from the arg stream into a string
*arg_stream >> value_as_string;
if (value_as_string.empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
}
// Now, attempt to convert the value into it's string format for output
try {
*out_arg = std::stof(value_as_string, nullptr);
} catch (const std::exception &e) {
std::string err_msg = "Invalid numeric value: " + value_as_string;
return Status(StatusCode::kSyntaxError, err_msg);
}
return Status::OK();
}
Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
std::string tok;
while (*arg_stream >> tok) {
switch (arg_map_[tok]) {
case ArgValue::kArgHost: {
RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream));
// Temporary sanity check. We only support localhost for now
if (hostname_ != std::string(kCfgDefaultCacheHost)) {
std::string err_msg =
"Invalid host interface: " + hostname_ + ". Current limitation, only 127.0.0.1 can be used.";
return Status(StatusCode::kSyntaxError, err_msg);
}
break;
}
case ArgValue::kArgPort: {
@ -203,6 +267,14 @@ Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream));
break;
}
case ArgValue::kArgMemoryCapRatio: {
RETURN_IF_NOT_OK(AssignArg(tok, &memory_cap_ratio_, arg_stream));
break;
}
case ArgValue::kArgListSessions: {
RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdListSessions));
break;
}
default: {
// Save space delimited trailing arguments
trailing_args_ += (" " + tok);
@ -232,9 +304,12 @@ Status CacheAdminArgHandler::Validate() {
}
// Additional checks here
if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be positive value.");
if (num_workers_ < 1 || num_workers_ > 100)
return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100.");
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
// port range check?
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1");
if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kSyntaxError, "Port must be in range (1025..65535).");
return Status::OK();
}
@ -245,12 +320,9 @@ Status CacheAdminArgHandler::RunCommand() {
Help();
break;
}
case CommandId::kCmdStart: {
RETURN_IF_NOT_OK(StartServer());
break;
}
case CommandId::kCmdStart:
case CommandId::kCmdStop: {
RETURN_IF_NOT_OK(StopServer());
RETURN_IF_NOT_OK(StartStopServer(command_id_));
break;
}
case CommandId::kCmdGenerateSession: {
@ -259,7 +331,7 @@ Status CacheAdminArgHandler::RunCommand() {
auto rq = std::make_shared<GenerateSessionIdRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
std::cout << rq->GetSessionId() << std::endl;
std::cout << "Session: " << rq->GetSessionId() << std::endl;
break;
}
case CommandId::kCmdDestroySession: {
@ -273,6 +345,39 @@ Status CacheAdminArgHandler::RunCommand() {
std::cout << "Drop session successful" << std::endl;
break;
}
case CommandId::kCmdListSessions: {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
auto rq = std::make_shared<ListSessionsRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
std::vector<SessionCacheInfo> session_info = rq->GetSessionCacheInfo();
if (!session_info.empty()) {
std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached"
<< std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl;
for (auto curr_session : session_info) {
std::string cache_id;
std::string stat_mem_cached;
std::string stat_disk_cached;
std::string stat_avg_cached;
int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF);
cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc);
stat_mem_cached =
(curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached);
stat_disk_cached =
(curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached);
stat_avg_cached =
(curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz);
std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12)
<< stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached
<< std::endl;
}
} else {
std::cout << "No active sessions." << std::endl;
}
break;
}
default: {
RETURN_STATUS_UNEXPECTED("Invalid cache admin command id.");
break;
@ -282,7 +387,7 @@ Status CacheAdminArgHandler::RunCommand() {
return Status::OK();
}
Status CacheAdminArgHandler::StartServer() {
Status CacheAdminArgHandler::StartStopServer(CommandId command_id) {
// There currently does not exist any "install path" or method to identify which path the installed binaries will
// exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
// cache_admin binary (this process).
@ -324,7 +429,10 @@ Status CacheAdminArgHandler::StartServer() {
close(fd[1]);
dup2(fd[0], 0);
close(fd[0]);
wait(nullptr);
int status;
if (waitpid(pid, &status, 0) == -1) {
RETURN_STATUS_UNEXPECTED("waitpid fails. errno = " + std::to_string(errno));
}
std::string msg;
const int32_t buf_sz = 1024;
msg.resize(buf_sz);
@ -335,6 +443,13 @@ Status CacheAdminArgHandler::StartServer() {
}
msg.resize(n);
std::cout << msg << std::endl;
if (WIFEXITED(status)) {
auto exit_status = WEXITSTATUS(status);
if (exit_status) {
std::string errMsg = "Child exit status " + std::to_string(exit_status);
return Status(StatusCode::kUnexpectedError, errMsg);
}
}
return Status::OK();
} else {
// Child here ...
@ -350,19 +465,29 @@ Status CacheAdminArgHandler::StartServer() {
std::string shared_memory_string = std::to_string(shm_mem_sz_);
std::string minloglevel_string = std::to_string(log_level_);
std::string daemonize_string = "true";
std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_);
char *argv[8];
argv[0] = cache_server_binary.data(); // First arg is usually the binary name
argv[1] = spill_dir_.data();
argv[2] = workers_string.data();
argv[3] = port_string.data();
argv[4] = shared_memory_string.data();
argv[5] = minloglevel_string.data();
argv[6] = daemonize_string.data();
argv[7] = nullptr;
char *argv[9];
if (command_id == CommandId::kCmdStart) {
argv[0] = cache_server_binary.data();
argv[1] = spill_dir_.data();
argv[2] = workers_string.data();
argv[3] = port_string.data();
argv[4] = shared_memory_string.data();
argv[5] = minloglevel_string.data();
argv[6] = daemonize_string.data();
argv[7] = memory_cap_ratio_string.data();
argv[8] = nullptr;
} else {
// We are doing a --stop. Change the name to '-' and we also need the port number.
// The rest we don't need.
argv[0] = std::string("-").data();
argv[1] = port_string.data();
argv[2] = nullptr;
}
// Now exec the binary
execv(argv[0], argv);
execv(cache_server_binary.data(), argv);
// If the exec was successful, this line will never be reached due to process image being replaced.
// ..unless exec failed.
std::string err_msg = "Failed to exec cache server: " + cache_server_binary;
@ -371,16 +496,6 @@ Status CacheAdminArgHandler::StartServer() {
}
}
Status CacheAdminArgHandler::StopServer() {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
auto rq = std::make_shared<ShutdownRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
// We will ignore the rc because if the shutdown is successful, the server will not reply back.
(void)rq->Wait();
return Status::OK();
}
void CacheAdminArgHandler::Help() {
std::cerr << "Syntax:\n";
std::cerr << " cache_admin [--start | --stop]\n";
@ -390,8 +505,12 @@ void CacheAdminArgHandler::Help() {
std::cerr << " [ [-d | --destroy_session] <session id> ]\n";
std::cerr << " [ [-w | --workers] <number of workers> ]\n";
std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n";
std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
std::cerr << " [ [-l | --minloglevel] <log level> ]\n";
std::cerr << " [ --list_sessions ]\n";
// Do not expose these option to the user via help or documentation, but the options do exist to aid with
// development and tuning.
// std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
// std::cerr << " [ [-r | --memory_cap_ratio] <float percent value>]\n";
std::cerr << " [--help]" << std::endl;
}
} // namespace dataset

View File

@ -32,6 +32,7 @@ class CacheAdminArgHandler {
static constexpr int32_t kDefaultNumWorkers = 32;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1;
static constexpr float kMemoryCapRatio = 0.8;
static const char kServerBinary[];
static const char kDefaultSpillDir[];
@ -42,12 +43,13 @@ class CacheAdminArgHandler {
kCmdStop = 2,
kCmdGenerateSession = 3,
kCmdDestroySession = 4,
kCmdListSessions = 5,
kCmdUnknown = 32767
};
CacheAdminArgHandler();
~CacheAdminArgHandler() = default;
virtual ~CacheAdminArgHandler();
Status ParseArgStream(std::stringstream *arg_stream);
@ -70,12 +72,12 @@ class CacheAdminArgHandler {
kArgNumWorkers = 9,
kArgSharedMemorySize = 10,
kArgLogLevel = 11,
kArgNumArgs = 12 // Must be the last position to provide a count
kArgMemoryCapRatio = 12,
kArgListSessions = 13,
kArgNumArgs = 14 // Must be the last position to provide a count
};
Status StartServer();
Status StopServer();
Status StartStopServer(CommandId);
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
@ -83,6 +85,9 @@ class CacheAdminArgHandler {
Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
Status AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
Status Validate();
CommandId command_id_;
@ -90,6 +95,7 @@ class CacheAdminArgHandler {
int32_t num_workers_;
int32_t shm_mem_sz_;
int32_t log_level_;
float memory_cap_ratio_;
session_id_type session_id_;
std::string hostname_;
std::string spill_dir_;

View File

@ -17,27 +17,19 @@
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
: ptr_(nullptr), val_in_GB_(val_in_GB), port_(port), shmid_(-1) {}
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) {
// We create the shared memory and we will destroy it. All other client just detach only.
shm_.RemoveResourcesOnExit();
}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
#if CACHE_LOCAL_CLIENT
if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
shmdt(this->ptr_);
}
this->ptr_ = nullptr;
if (shmid_ != -1) {
shmctl(shmid_, IPC_RMID, nullptr);
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
#endif
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
#if CACHE_LOCAL_CLIENT
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
if (ba == nullptr) {
return Status(StatusCode::kOutOfMemory);
@ -46,26 +38,13 @@ Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryAr
// the destructor of *out to deal.
(*out).reset(ba);
// Generate the ftok using a combination of port.
int err;
auto shm_key = PortToFtok(port, &err);
if (shm_key == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
SharedMemory::shm_key_t shm_key;
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
ba->shm_.SetPublicKey(shm_key);
// Value is in GB. Convert into bytes.
int64_t sz = val_in_GB * 1073741824L;
ba->shmid_ = shmget(shm_key, sz, IPC_CREAT | IPC_EXCL | access_mode);
if (ba->shmid_) {
ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
ba->impl_ = std::make_unique<ArenaImpl>(ba->ptr_, sz);
} else {
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
}
#endif
RETURN_IF_NOT_OK(ba->shm_.Create(sz));
ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz);
return Status::OK();
}
} // namespace dataset

View File

@ -21,6 +21,7 @@
#include <string>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
@ -73,10 +74,9 @@ class CachedSharedMemoryArena : public MemoryPool {
private:
mutable std::mutex mux_;
void *ptr_;
int32_t val_in_GB_;
int32_t port_;
int shmid_;
SharedMemory shm_;
std::unique_ptr<ArenaImpl> impl_;
/// Private constructor. Not to be called directly.
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);

View File

@ -24,26 +24,26 @@
namespace mindspore {
namespace dataset {
CacheClient::Builder::Builder()
: session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) {
: session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_connections_(0), prefetch_size_(0) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
hostname_ = cfg->cache_host();
port_ = cfg->cache_port();
num_workers_ = cfg->num_parallel_workers();
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
num_connections_ = cfg->num_connections(); // number of async tcp/ip connections
prefetch_size_ = cfg->prefetch_size(); // prefetch size
}
Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) {
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_IF_NOT_OK(SanityCheck());
*out =
std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_);
*out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_connections_,
prefetch_size_);
return Status::OK();
}
Status CacheClient::Builder::SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited");
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
@ -55,26 +55,32 @@ Status CacheClient::Builder::SanityCheck() {
// Constructor
CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
int32_t port, int32_t num_workers, int32_t prefetch_size)
int32_t port, int32_t num_connections, int32_t prefetch_size)
: server_connection_id_(0),
cache_mem_sz_(cache_mem_sz),
spill_(spill),
local_bypass_(false),
hostname_(std::move(hostname)),
port_(port),
num_workers_(num_workers),
prefetch_size_(prefetch_size) {
num_connections_(num_connections),
prefetch_size_(prefetch_size),
fetch_all_keys_(true) {
cinfo_.set_session_id(session_id);
comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_connections_);
}
CacheClient::~CacheClient() {
cache_miss_keys_wp_.Set();
(void)comm_->ServiceStop();
}
// print method for display cache details
void CacheClient::Print(std::ostream &out) const {
out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc()
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz()
<< "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname()
<< "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers()
<< "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << GetCacheMemSz()
<< "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << GetHostname()
<< "\n Port: " << GetPort() << "\n Number of rpc workers: " << GetNumConnections()
<< "\n Prefetch size: " << GetPrefetchSize() << "\n Local client support: " << std::boolalpha
<< SupportLocalClient();
}
@ -199,14 +205,6 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
return Status::OK();
}
Status CacheClient::PurgeCache() {
UniqueLock lck(&mux_);
auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}
Status CacheClient::DestroyCache() {
UniqueLock lck(&mux_);
auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
@ -253,5 +251,71 @@ Status CacheClient::BuildPhaseDone() const {
}
Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
void CacheClient::ServerRunningOutOfResources() {
bool expected = true;
if (fetch_all_keys_.compare_exchange_strong(expected, false)) {
Status rc;
// Server runs out of memory or disk space to cache any more rows.
// First of all, we will turn off the locking.
auto toggle_write_mode_rq = std::make_shared<ToggleWriteModeRequest>(server_connection_id_, false);
rc = PushRequest(toggle_write_mode_rq);
if (rc.IsError()) {
return;
}
// Wait until we can toggle the state of the server to non-locking
rc = toggle_write_mode_rq->Wait();
if (rc.IsError()) {
return;
}
// Now we get a list of all the keys not cached at the server so
// we can filter out at the prefetch level.
auto cache_miss_rq = std::make_shared<GetCacheMissKeysRequest>(server_connection_id_);
rc = PushRequest(cache_miss_rq);
if (rc.IsError()) {
return;
}
rc = cache_miss_rq->Wait();
if (rc.IsError()) {
return;
}
// We will get back a vector of row id between [min,max] that are absent in the server.
auto &row_id_buf = cache_miss_rq->reply_.result();
auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
std::vector<row_id_type> row_ids;
auto sz = p->row_id()->size();
row_ids.reserve(sz);
for (auto i = 0; i < sz; ++i) {
row_ids.push_back(p->row_id()->Get(i));
}
cache_miss_keys_ = std::make_unique<CacheMissKeys>(row_ids);
// We are all set.
cache_miss_keys_wp_.Set();
}
}
CacheClient::CacheMissKeys::CacheMissKeys(const std::vector<row_id_type> &v) {
auto it = v.begin();
min_ = *it;
++it;
max_ = *it;
++it;
while (it != v.end()) {
gap_.insert(*it);
++it;
}
MS_LOG(WARNING) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size();
}
bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
if (key > max_ || key < min_) {
return true;
} else if (key == min_ || key == max_) {
return false;
} else {
auto it = gap_.find(key);
return it != gap_.end();
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -16,8 +16,13 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
#include <atomic>
#include <iostream>
#include <limits>
#include <memory>
#include <map>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
@ -31,6 +36,8 @@
#endif
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/cond_var.h"
#include "minddata/dataset/util/queue_map.h"
namespace mindspore {
namespace dataset {
@ -89,10 +96,10 @@ class CacheClient {
}
/// Setter function to set number of async rpc workers
/// \param num_workers
/// \param num_connections
/// \return Builder object itself
Builder &SetNumWorkers(int32_t num_workers) {
num_workers_ = num_workers;
Builder &SetNumConnections(int32_t num_connections) {
num_connections_ = num_connections;
return *this;
}
@ -105,13 +112,13 @@ class CacheClient {
}
/// Getter functions
session_id_type getSessionId() const { return session_id_; }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
session_id_type GetSessionId() const { return session_id_; }
uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
Status SanityCheck();
@ -123,7 +130,7 @@ class CacheClient {
bool spill_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t num_connections_;
int32_t prefetch_size_;
};
@ -132,10 +139,10 @@ class CacheClient {
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
int32_t num_workers, int32_t prefetch_size);
int32_t num_connections, int32_t prefetch_size);
/// \brief Destructor
~CacheClient() { (void)comm_->ServiceStop(); }
~CacheClient();
/// \brief Send a TensorRow to the cache server
/// \param[in] row
@ -161,10 +168,6 @@ class CacheClient {
/// \return Status object
Status CreateCache(uint32_t tree_crc, bool generate_id);
/// \brief Purge a cache. Cache can be reused after reset.
/// \return Status object
Status PurgeCache();
/// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
/// \return Status object
Status DestroyCache();
@ -218,12 +221,31 @@ class CacheClient {
/// Getter functions
session_id_type session_id() const { return cinfo_.session_id(); }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }
const std::string &GetHostname() const { return hostname_; }
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
/// MergeOp will notify us when the server can't cache any more rows.
/// We will stop any attempt to fetch any rows that are most likely
/// not present at the server.
void ServerRunningOutOfResources();
/// \brief Check if a row is 100% cache miss at the server by checking the local information
/// \param key row id to be test
/// \return true if not at the server
bool KeyIsCacheMiss(row_id_type key) {
if (cache_miss_keys_) {
// Make sure it is fully built even though the pointer is not null
Status rc = cache_miss_keys_wp_.Wait();
if (rc.IsOk()) {
return cache_miss_keys_->KeyIsCacheMiss(key);
}
}
return false;
}
private:
mutable RWLock mux_;
@ -240,9 +262,27 @@ class CacheClient {
bool local_bypass_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t num_connections_;
int32_t prefetch_size_;
mutable std::shared_ptr<CacheClientGreeter> comm_;
std::atomic<bool> fetch_all_keys_;
WaitPost cache_miss_keys_wp_;
/// A structure shared by all the prefetchers to know what keys are missing at the server.
class CacheMissKeys {
public:
explicit CacheMissKeys(const std::vector<row_id_type> &v);
~CacheMissKeys() = default;
/// This checks if a key is missing.
/// \param key
/// \return true if definitely a key miss
bool KeyIsCacheMiss(row_id_type key);
private:
row_id_type min_;
row_id_type max_;
std::set<row_id_type> gap_;
};
std::unique_ptr<CacheMissKeys> cache_miss_keys_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -25,13 +25,6 @@
#define CACHE_LOCAL_CLIENT 1
#endif
#ifdef CACHE_LOCAL_CLIENT
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#else
typedef int key_t;
#endif
#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
@ -54,6 +47,8 @@ constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
/// inline in the protobuf. This also implies kLocalClientSupport is also true.
constexpr static uint32_t kDataIsInSharedMemory = 2;
/// \brief Size of each message used in message queue.
constexpr static int32_t kSharedMessageSize = 2048;
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
@ -62,29 +57,10 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
reply->set_rc(static_cast<int32_t>(rc.get_code()));
reply->set_msg(rc.ToString());
}
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \note Caller must check the return value. -1 means ftok failed.
/// \param[in] port
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
/// \return key
inline key_t PortToFtok(int port, int *err) {
key_t shmkey = -1;
#ifdef CACHE_LOCAL_CLIENT
const std::string unix_path = PortToUnixSocketPath(port);
shmkey = ftok(unix_path.data(), 'a');
if (err != nullptr && shmkey == (key_t)-1) {
*err = errno;
}
#endif
return shmkey;
}
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

View File

@ -17,34 +17,10 @@
#include <chrono>
namespace mindspore {
namespace dataset {
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag) {
// If there is anything extra we need to do before we send.
RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
// One minute timeout
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
tag->ctx_.set_deadline(deadline);
tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
tag->rpc_->StartCall();
// Last step is we release the ownership and transfer it to the completion queue.
// The memory will be released by WorkerEntry or by the destructor when we drain the queue
auto ccReqTag = tag.release();
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
ccReqTag); // inject this object into the completion queue
return Status::OK();
}
CacheClientGreeter::~CacheClientGreeter() { (void)ServiceStop(); }
CacheClientGreeter::~CacheClientGreeter() {
(void)ServiceStop();
// Detach from shared memory if any
if (shmat_addr_ != nullptr) {
shmdt(shmat_addr_);
shmat_addr_ = nullptr;
}
}
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers)
: num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) {
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections)
: num_connections_(num_connections), request_cnt_(0) {
grpc::ChannelArguments args;
// We need to bump up the message size to unlimited. The default receiving
// message limit is 4MB which is not big enough.
@ -68,21 +44,11 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port
Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) {
*local_bypass = false;
#if CACHE_LOCAL_CLIENT
int err;
shm_key_ = PortToFtok(port, &err);
if (shm_key_ == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
SharedMemory::shm_key_t shm_key;
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
// Attach to the shared memory
shm_id_ = shmget(shm_key_, 0, 0);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
}
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
mem_.SetPublicKey(shm_key);
RETURN_IF_NOT_OK(mem_.Attach());
*local_bypass = true;
#endif
return Status::OK();
@ -90,7 +56,7 @@ Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass
Status CacheClientGreeter::DoServiceStart() {
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(DispatchWorkers(num_workers_));
RETURN_IF_NOT_OK(DispatchWorkers(num_connections_));
return Status::OK();
}
@ -100,19 +66,40 @@ Status CacheClientGreeter::DoServiceStop() {
// Shutdown the TaskGroup.
vg_.interrupt_all();
vg_.join_all(Task::WaitFlag::kNonBlocking);
// Drain the queue
bool success;
void *tag;
while (cq_.Next(&tag, &success)) {
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
delete r;
// Drain the queue. We know how many requests we send out
while (!req_.empty()) {
bool success;
void *tag;
while (cq_.Next(&tag, &success)) {
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
req_.erase(r->seqNo_);
}
}
return Status::OK();
}
Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq));
return tag->MakeCall(stub_.get(), &cq_, std::move(tag));
// If there is anything extra we need to do before we send.
RETURN_IF_NOT_OK(rq->Prepare());
auto seqNo = request_cnt_.fetch_add(1);
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq), seqNo);
// One minute timeout
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
tag->ctx_.set_deadline(deadline);
tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_);
tag->rpc_->StartCall();
auto ccReqTag = tag.get();
// Insert it into the map.
{
std::unique_lock<std::mutex> lck(mux_);
auto r = req_.emplace(seqNo, std::move(tag));
if (!r.second) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__);
}
}
// Last step is to tag the request.
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, ccReqTag);
return Status::OK();
}
Status CacheClientGreeter::WorkerEntry() {
@ -129,15 +116,26 @@ Status CacheClientGreeter::WorkerEntry() {
auto &rc = rq->rc_;
if (!rc.ok()) {
auto error_code = rq->rc_.error_code();
std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
std::string err_msg;
if (error_code == grpc::StatusCode::UNAVAILABLE) {
err_msg =
"Cache server is unreachable. Make sure the server is running. GRPC Code" + std::to_string(error_code);
} else {
err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
}
Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg);
Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
}
// Notify the waiting thread.
rq->Notify();
}
// We can now free the memory
delete rq;
{
// We can now free the memory
std::unique_lock<std::mutex> lck(mux_);
auto seqNo = rq->seqNo_;
auto n = req_.erase(seqNo);
CHECK_FAIL_RETURN_UNEXPECTED(n == 1, "Sequence " + std::to_string(seqNo) + " not found");
}
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED();

View File

@ -16,10 +16,14 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
@ -34,16 +38,10 @@ namespace dataset {
class CacheClientRequestTag {
public:
friend class CacheClientGreeter;
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq, int64_t seqNo)
: base_rq_(std::move(rq)), seqNo_(seqNo) {}
~CacheClientRequestTag() = default;
/// \brief Make a RPC call
/// \param stub from CacheClientGreeter
/// \param cq from CacheClientGreeter
/// \return Status object
static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag);
/// \brief Notify the client that a result has come back from the server
void Notify() { base_rq_->wp_.Set(); }
@ -52,6 +50,7 @@ class CacheClientRequestTag {
grpc::Status rc_;
grpc::ClientContext ctx_;
std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
int64_t seqNo_;
};
/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC
@ -60,7 +59,7 @@ class CacheClientGreeter : public Service {
friend class CacheClient;
public:
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections);
~CacheClientGreeter();
/// Override base Service class
@ -85,17 +84,18 @@ class CacheClientGreeter : public Service {
/// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
const void *SharedMemoryBaseAddr() const { return mem_.SharedMemoryBaseAddr(); }
private:
std::shared_ptr<grpc::Channel> channel_;
std::unique_ptr<CacheServerGreeter::Stub> stub_;
grpc::CompletionQueue cq_;
TaskGroup vg_;
int32_t num_workers_;
key_t shm_key_;
int32_t shm_id_;
void *shmat_addr_;
int32_t num_connections_;
std::atomic<int64_t> request_cnt_;
mutable std::mutex mux_;
std::map<int64_t, std::unique_ptr<CacheClientRequestTag>> req_;
SharedMemory mem_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -47,53 +47,10 @@ void CacheServerGreeterImpl::Shutdown() {
CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); }
Status CacheServerGreeterImpl::IpcResourceCleanup() {
#if CACHE_LOCAL_CLIENT
int err;
auto shm_key = PortToFtok(port_, &err);
// We are expecting the unix path doesn't exist.
if (shm_key == (key_t)-1) {
return Status::OK();
}
// Attach to the shared memory
auto shm_id = shmget(shm_key, 0, 0);
if (shm_id == -1) {
return Status::OK();
}
struct shmid_ds ds {};
auto inx = shmctl(shm_id, IPC_STAT, &ds);
if (inx == -1) {
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
if (ds.shm_nattch == 0) {
// Stale shared memory from last time.
// Remove both the memory and the socket path
inx = shmctl(shm_id, IPC_RMID, nullptr);
if (inx == -1) {
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id);
errMsg += ". Errno :" + std::to_string(errno);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
Path p(unix_socket_);
(void)p.Remove();
} else {
// Server is already up.
MS_LOG(ERROR) << "Cache server is already up and running";
// We return a duplicate error. The main() will intercept
// and output a proper message
return Status(StatusCode::kDuplicateKey);
}
#endif
return Status::OK();
}
Status CacheServerGreeterImpl::Run() {
// To listen on all interfaces, use 0.0.0.0
// Use 127.0.0.1 if just locally on the same machine.
std::string host("0.0.0.0"); // listen on all interfaces.
// Future, allow the user to choose listening interface. For now, default to localhost
std::string host("127.0.0.1");
std::string server_address = host + ":" + std::to_string(port_);
grpc::ServerBuilder builder;
// Default message size for gRPC is 4MB. Increase it to 2g-1
@ -101,9 +58,6 @@ Status CacheServerGreeterImpl::Run() {
int port_tcpip = 0;
#if CACHE_LOCAL_CLIENT
int port_local = 0;
// Check if we need to do clean up on the shared memory if the server
// came down unexpectedly like SEGV
RETURN_IF_NOT_OK(IpcResourceCleanup());
// We also optimize on local clients on the same machine using unix socket
builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
#endif

View File

@ -41,7 +41,7 @@ class CacheServerRequest : public BaseRequest {
st_(STATE::CREATE),
responder_(&ctx_) {}
~CacheServerRequest() = default;
~CacheServerRequest() override = default;
/// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this
/// functor will translate each protobuf into some form understood by by CacheService class.
@ -87,8 +87,6 @@ class CacheServerGreeterImpl final {
void Shutdown();
Status IpcResourceCleanup();
private:
int32_t port_;
size_t shm_pool_sz_in_gb_;

View File

@ -0,0 +1,163 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include <sys/stat.h>
namespace mindspore {
namespace dataset {
Status PortToFtok(int port, SharedMemory::shm_key_t *out) {
RETURN_UNEXPECTED_IF_NULL(out);
key_t shmkey = -1;
const std::string unix_path = PortToUnixSocketPath(port);
shmkey = ftok(unix_path.data(), 'a');
if (shmkey == (key_t)-1) {
std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno);
return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg);
}
*out = shmkey;
return Status::OK();
}
SharedMessage::~SharedMessage() {
// Only remove the queue if we are asked to.
if (remove_ipc_on_exit_ && msg_qid_ != -1) {
// Remove the message que and never mind about the return code.
(void)msgctl(msg_qid_, IPC_RMID, nullptr);
msg_qid_ = -1;
}
}
Status SharedMessage::Create() {
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ == -1, "Message queue already created");
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
msg_qid_ = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode);
if (msg_qid_ == -1) {
std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status SharedMessage::SendStatus(const Status &rc) {
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
StatusMsgBuf msg{
1,
};
msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size());
CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err));
msg.body.status.err_msg[rc.ToString().size()] = '\0';
err = msgsnd(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), IPC_NOWAIT);
if (err == -1) {
std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status SharedMessage::ReceiveStatus(Status *rc) {
RETURN_UNEXPECTED_IF_NULL(rc);
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
struct StatusMsgBuf msg {};
auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR);
if (err == -1) {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(errMsg);
}
Status rc_recv(static_cast<StatusCode>(msg.body.status.err_code), msg.body.status.err_msg);
*rc = std::move(rc_recv);
return Status::OK();
}
SharedMemory::~SharedMemory() {
if (shmat_addr_) {
(void)Detach();
}
if (remove_ipc_on_exit_ && shm_id_ != -1) {
// Remove the shared memory and never mind about the return code.
Status rc = Destroy();
if (rc.IsError()) {
MS_LOG(ERROR) << rc.ToString();
}
}
shm_id_ = -1;
shmat_addr_ = nullptr;
}
Status SharedMemory::Create(int64_t sz) {
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
shm_id_ = shmget(shm_key_, sz, IPC_CREAT | IPC_EXCL | access_mode);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
} else {
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
}
return Status::OK();
}
Status SharedMemory::Attach() {
shm_id_ = shmget(shm_key_, 0, 0);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
}
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
return Status::OK();
}
Status SharedMemory::Detach() {
if (shmat_addr_) {
auto err = shmdt(shmat_addr_);
if (err == -1) {
RETURN_STATUS_UNEXPECTED("Shared memory detach failed. Errno " + std::to_string(errno));
}
}
shmat_addr_ = nullptr;
return Status::OK();
}
Status SharedMemory::Destroy() {
// Remove the shared memory and never mind about the return code.
auto err = shmctl(shm_id_, IPC_RMID, nullptr);
if (err == -1) {
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_);
errMsg += ". Errno :" + std::to_string(errno);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status SharedMemory::GetNumAttached(int32_t *num) {
RETURN_UNEXPECTED_IF_NULL(num);
struct shmid_ds ds {};
auto err = shmctl(shm_id_, IPC_STAT, &ds);
if (err == -1) {
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id_);
errMsg += "\nPlease remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
*num = ds.shm_nattch;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,207 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/msg.h>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// A message queue structure between the parent and the child process
struct StatusMsgBuf {
int64_t mtype;
union {
char mtext[1];
struct {
int32_t err_code;
char err_msg[kSharedMessageSize];
} status;
} body;
};
class BaseIPC {
public:
BaseIPC() : remove_ipc_on_exit_(false) {}
virtual ~BaseIPC() {}
/// Indicate if we should remove the ipc resource on exit. Usually this is done by parent process.
void RemoveResourcesOnExit() { remove_ipc_on_exit_ = true; }
/// Copy constructors
BaseIPC(const BaseIPC &rhs) : remove_ipc_on_exit_(false) {}
BaseIPC &operator=(const BaseIPC &rhs) {
if (&rhs != this) {
remove_ipc_on_exit_ = false;
}
return *this;
}
/// Move constructors
BaseIPC(BaseIPC &&rhs) noexcept : remove_ipc_on_exit_(rhs.remove_ipc_on_exit_) { rhs.remove_ipc_on_exit_ = false; }
BaseIPC &operator=(BaseIPC &&rhs) noexcept {
if (&rhs != this) {
remove_ipc_on_exit_ = rhs.remove_ipc_on_exit_;
rhs.remove_ipc_on_exit_ = false;
}
return *this;
}
protected:
bool remove_ipc_on_exit_;
};
/// \brief This wraps a shared message for the communication between processes. It is used primarily
/// for starting and stopping a server.
class SharedMessage : public BaseIPC {
public:
using queue_id_t = int;
SharedMessage() : msg_qid_(-1) {}
explicit SharedMessage(queue_id_t qid) : msg_qid_(qid) {}
~SharedMessage() override;
/// Copy constructors
SharedMessage(const SharedMessage &rhs) : BaseIPC(rhs), msg_qid_(rhs.msg_qid_) {}
SharedMessage &operator=(const SharedMessage &rhs) {
if (&rhs != this) {
msg_qid_ = rhs.msg_qid_;
BaseIPC::operator=(rhs);
}
return *this;
}
/// Move constructors
SharedMessage(SharedMessage &&rhs) noexcept : BaseIPC(std::move(rhs)) {
msg_qid_ = rhs.msg_qid_;
rhs.msg_qid_ = -1;
}
SharedMessage &operator=(SharedMessage &&rhs) noexcept {
if (&rhs != this) {
msg_qid_ = rhs.msg_qid_;
rhs.msg_qid_ = -1;
BaseIPC::operator=(std::move(rhs));
}
return *this;
}
/// Return the private id
queue_id_t GetMsgQueueId() const { return msg_qid_; }
/// \brief Create a private message queue
Status Create();
/// Send a Status object
Status SendStatus(const Status &rc);
/// Retrieve a Status object
Status ReceiveStatus(Status *rc);
private:
queue_id_t msg_qid_;
};
/// \brief This wraps a shared memory for the communication between processes. It is used primarily
/// for transporting large tensor rows.
class SharedMemory : public BaseIPC {
public:
using shm_key_t = int;
using shm_id_t = int;
SharedMemory() : shm_id_(-1), shm_key_(-1), shmat_addr_(nullptr) {}
explicit SharedMemory(shm_key_t public_key) : shm_id_(-1), shm_key_(public_key), shmat_addr_(nullptr) {}
~SharedMemory() override;
/// Copy constructors
SharedMemory(const SharedMemory &rhs)
: BaseIPC(rhs), shm_id_(rhs.shm_id_), shm_key_(rhs.shm_key_), shmat_addr_(rhs.shmat_addr_) {}
SharedMemory &operator=(const SharedMemory &rhs) {
if (&rhs != this) {
shm_id_ = rhs.shm_id_;
shm_key_ = rhs.shm_key_;
shmat_addr_ = rhs.shmat_addr_;
BaseIPC::operator=(rhs);
}
return *this;
}
/// Move constructors
SharedMemory(SharedMemory &&rhs) noexcept : BaseIPC(std::move(rhs)) {
shm_id_ = rhs.shm_id_;
shm_key_ = rhs.shm_key_;
shmat_addr_ = rhs.shmat_addr_;
rhs.shm_id_ = -1;
rhs.shm_key_ = -1;
rhs.shmat_addr_ = nullptr;
}
SharedMemory &operator=(SharedMemory &&rhs) noexcept {
if (&rhs != this) {
shm_id_ = rhs.shm_id_;
shm_key_ = rhs.shm_key_;
shmat_addr_ = rhs.shmat_addr_;
rhs.shm_id_ = -1;
rhs.shm_key_ = -1;
rhs.shmat_addr_ = nullptr;
BaseIPC::operator=(std::move(rhs));
}
return *this;
}
/// \brief Set the public key
void SetPublicKey(key_t public_key) { shm_key_ = public_key; }
/// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
void *SharedMemoryBaseAddr() { return shmat_addr_; }
/// \brief Attach to shared memory
/// \return Status object
Status Attach();
/// Detach from shared memory
/// \return Status object
Status Detach();
/// Create shared memory
/// \return Status object
Status Create(int64_t sz);
/// Destroy shared memory
/// \return Status object
Status Destroy();
/// \brief Return the shared memory id
shm_id_t GetSharedMemoryId() const { return shm_id_; }
/// \brief Get number of processes attached to the shared memory
/// \return Status object
Status GetNumAttached(int32_t *num);
private:
shm_id_t shm_id_;
shm_key_t shm_key_;
void *shmat_addr_;
};
/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \note Caller must check the return value. -1 means ftok failed.
/// \param[in] port
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
/// \return key
Status PortToFtok(int port, SharedMemory::shm_key_t *);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_

View File

@ -21,24 +21,136 @@
#include <glog/logging.h>
#endif
#include <cstdlib>
#include <thread>
#include <chrono>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace ds = mindspore::dataset;
int main(int argc, char **argv) {
/// Send a synchronous command to the local server using tcp/ip.
/// We aren't using any client code because this binary is not necessarily linked with the client library.
/// So just using grpc call directly.
/// \param port tcp/ip port to use
/// \param type Type of command.
/// \param out grpc result
/// \return Status object
ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply,
grpc::Status *out) {
if (rq == nullptr) {
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null");
}
if (reply == nullptr) {
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null");
}
if (out == nullptr) {
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null");
}
const std::string hostname = "127.0.0.1";
auto unix_socket = ds::PortToUnixSocketPath(port);
#if CACHE_LOCAL_CLIENT
const std::string target = "unix://" + unix_socket;
#else
const std::string target = hostname + ":" + std::to_string(port);
#endif
try {
rq->set_type(static_cast<int16_t>(type));
grpc::ChannelArguments args;
grpc::ClientContext ctx;
grpc::CompletionQueue cq;
// Standard async rpc call
std::shared_ptr<grpc::Channel> channel =
grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
std::unique_ptr<ds::CacheServerGreeter::Stub> stub = ds::CacheServerGreeter::NewStub(channel);
std::unique_ptr<grpc::ClientAsyncResponseReader<ds::CacheReply>> rpc =
stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq);
rpc->StartCall();
// We need to pass a tag. But since this is the only request in the completion queue and so we
// just pass a dummy
int64_t dummy;
void *tag;
bool success;
rpc->Finish(reply, out, &dummy);
// Now we wait on the completion queue synchronously.
auto r = cq.Next(&tag, &success);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
if (!success || tag != &dummy) {
std::string errMsg = "Unexpected programming error ";
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
if (out->ok()) {
return ds::Status(static_cast<ds::StatusCode>(reply->rc()), reply->msg());
} else {
auto error_code = out->error_code();
std::string errMsg = out->error_message() + ". GRPC Code " + std::to_string(error_code);
return ds::Status(ds::StatusCode::kNetWorkError, errMsg);
}
} else {
std::string errMsg = "Unexpected queue rc = " + std::to_string(r);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
} catch (const std::exception &e) {
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
}
}
/// Stop the server
/// \param argv
/// \return Status object
ds::Status StopServer(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;
std::string errMsg;
if (argc != 2) {
return ds::Status(ds::StatusCode::kSyntaxError);
}
int32_t port = strtol(argv[1], nullptr, 10);
// We will go through the builder to do some snaity check. We only need the port number
// to shut down the server. Null the root directory as we don't trigger the sanity code to write out anything
// to the spill directory.
builder.SetPort(port).SetRootDirectory("");
// Part of the sanity check is check the shared memory. If the server is up and running, we expect
// the return code is kDuplicate.
rc = builder.SanityCheck();
if (rc.IsOk()) {
errMsg = "Server is not up or has been shutdown already.";
return ds::Status(ds::StatusCode::kUnexpectedError, errMsg);
} else if (rc.get_code() != ds::StatusCode::kDuplicateKey) {
// Not OK, and no duplicate, just return the rc whatever it is.
return rc;
} else {
// Now we get some work to do. We will send a tcp/ip request to the given port.
// This binary is not linked with client side of code, so we will just call grpc directly.
ds::CacheRequest rq;
ds::CacheReply reply;
grpc::Status grpc_rc;
rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc);
// The request is like a self destruct message, the server will not send anything back and
// shutdown all incoming request. So we should expect some unexpected network error if
// all goes well and we expect to GRPC code 14.
auto err_code = grpc_rc.error_code();
if (rc.get_code() != ds::StatusCode::kNetWorkError || err_code != grpc::StatusCode::UNAVAILABLE) {
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__);
}
}
return ds::Status::OK();
}
// This executable is not to be called directly, and should be invoked by cache_admin executable.
if (argc != 7) {
rc = ds::Status(ds::StatusCode::kSyntaxError);
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
/// Start the server
/// \param argv
/// \return Status object
ds::Status StartServer(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;
if (argc != 8) {
return ds::Status(ds::StatusCode::kSyntaxError);
}
int32_t port = strtol(argv[3], nullptr, 10);
builder.SetRootDirectory(argv[1])
.SetNumWorkers(strtol(argv[2], nullptr, 10))
.SetPort(strtol(argv[3], nullptr, 10))
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10));
.SetPort(port)
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10))
.SetMemoryCapRatio(strtof(argv[7], nullptr));
#ifdef USE_GLOG
FLAGS_minloglevel = strtol(argv[5], nullptr, 10);
@ -52,36 +164,42 @@ int main(int argc, char **argv) {
// is called. This is a standard procedure for daemonize a process on unix.
if (chdir("/") == -1) {
std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
std::cerr << errMsg << std::endl;
return -1;
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
// Simple check of the parameters before we move on.
rc = builder.SanityCheck();
if (rc.IsError()) {
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}
#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif
// A message queue for communication between parent and child (if we fork).
ds::SharedMessage msg;
if (daemonize) {
// fork the child process to become the daemon
#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif
rc = msg.Create();
if (rc.IsError()) {
return rc;
}
pid_t pid = fork();
// failed to fork
if (pid < 0) {
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
std::cerr << err_msg << std::endl;
return errno;
std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else if (pid > 0) {
// Parent
// Parent and will be responsible for remove the queue on exit.
msg.RemoveResourcesOnExit();
// Sleep one second and we attach to the msg que
std::this_thread::sleep_for(std::chrono::seconds(1));
ds::Status child_rc;
rc = msg.ReceiveStatus(&child_rc);
if (rc.IsError()) {
return rc;
}
if (child_rc.IsError()) {
return child_rc;
}
std::cerr << "cache server daemon process has been created as process id: " << pid
<< "\nCheck log file for any start up error" << std::endl;
signal(SIGCHLD, SIG_IGN); // ignore sig child signal.
return 0;
return ds::Status::OK();
} else {
// Child process will continue from here if daemonize and parent has already exited.
// If we are running in the foreground, none of the code in block below will be run.
@ -89,8 +207,8 @@ int main(int argc, char **argv) {
umask(0);
sid = setsid();
if (sid < 0) {
MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno);
return errno;
std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
close(0);
close(1);
@ -100,22 +218,36 @@ int main(int argc, char **argv) {
// Dump the summary
MS_LOG(INFO) << builder << std::endl;
// Create the instance with some sanity checks built in
rc = builder.Build();
if (rc.IsOk()) {
// If all goes well, kick off the threads. Loop forever and never return unless error.
ds::CacheServer &cs = ds::CacheServer::GetInstance();
// Kick off the threads. Loop forever and never return unless error.
rc = cs.Run();
if (rc.get_code() == ds::StatusCode::kDuplicateKey) {
std::string errMsg = "Server is already started";
MS_LOG(ERROR) << errMsg;
std::cerr << errMsg << std::endl;
return 0;
}
rc = cs.Run(msg.GetMsgQueueId());
} else if (daemonize) {
// If we didn't pass the sanity check to at least create the instance, use
// the message queue to return the error message if this is the child daemon.
return msg.SendStatus(rc);
}
return rc;
}
int main(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;
// This executable is not to be called directly, and should be invoked by cache_admin executable.
if (strcmp(argv[0], "-") == 0) {
rc = StopServer(argc, argv);
} else {
rc = StartServer(argc, argv);
}
// Check result
if (rc.IsError()) {
MS_LOG(ERROR) << rc.ToString();
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
auto errCode = rc.get_code();
auto errMsg = rc.ToString();
std::cerr << errMsg << std::endl;
return static_cast<int>(errCode);
}
return 0;
}

View File

@ -250,5 +250,27 @@ Status GetStatRequest::PostReply() {
stat_.cache_service_state = msg->state();
return Status::OK();
}
Status ListSessionsRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
auto session_vector = msg->sessions();
for (auto i = 0; i < session_vector->size(); ++i) {
SessionCacheInfo current_info;
CacheServiceStat stats;
auto current_session_info = session_vector->Get(i);
current_info.session_id = current_session_info->session_id();
current_info.connection_id = current_session_info->connection_id();
stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
stats.min_row_id = current_session_info->stats()->min_row_id();
stats.max_row_id = current_session_info->stats()->max_row_id();
stats.cache_service_state = current_session_info->stats()->state();
current_info.stats = stats; // fixed length struct. = operator is safe
session_info_list_.push_back(current_info);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -46,6 +46,13 @@ struct CacheServiceStat {
int8_t cache_service_state;
};
/// \brief Info structure ListSessionsRequest
struct SessionCacheInfo {
session_id_type session_id;
connection_id_type connection_id;
CacheServiceStat stats;
};
/// \brief CacheClient communicates with CacheServer using Requests.
class BaseRequest {
public:
@ -54,7 +61,7 @@ class BaseRequest {
kCacheRow = 0,
kBatchFetchRows = 1,
kCreateCache = 2,
kPurgeCache = 3,
kGetCacheMissKeys = 3,
kDestroyCache = 4,
kGetStat = 5,
kCacheSchema = 6,
@ -65,6 +72,9 @@ class BaseRequest {
kAllocateSharedBlock = 11,
kFreeSharedBlock = 12,
kStopService = 13,
kHeartBeat = 14,
kToggleWriteMode = 15,
kListSessions = 16,
// Add new request before it.
kRequestUnknown = 32767
};
@ -73,6 +83,7 @@ class BaseRequest {
friend class CacheServerRequest;
friend class CacheClientGreeter;
friend class CacheClientRequestTag;
friend class CacheClient;
/// \brief Base class of a cache server request
/// \param type Type of the request
@ -119,7 +130,7 @@ class FreeSharedBlockRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(addr));
}
~FreeSharedBlockRequest() = default;
~FreeSharedBlockRequest() override = default;
};
/// \brief Request to cache a single TensorRow
@ -136,7 +147,7 @@ class CacheRowRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie);
}
~CacheRowRequest() = default;
~CacheRowRequest() override = default;
/// \brief Serialize a TensorRow for streaming to the cache server
/// \param row TensorRow
@ -183,7 +194,7 @@ class BatchFetchRequest : public BaseRequest {
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass);
~BatchFetchRequest() = default;
~BatchFetchRequest() override = default;
Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);
private:
@ -203,7 +214,7 @@ class CreateCacheRequest : public BaseRequest {
/// \param flag Attributes of the cache.
explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone);
~CreateCacheRequest() = default;
~CreateCacheRequest() override = default;
void ParseResult(connection_id_type *id, std::string *out) {
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
*id = p->connection_id();
@ -218,14 +229,15 @@ class CreateCacheRequest : public BaseRequest {
CreateCacheFlag flag_;
};
/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
/// \brief Request to get all the keys not present at the server.
/// \note Only applicable to mappable case
class GetCacheMissKeysRequest : public BaseRequest {
public:
friend class CacheServer;
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) {
explicit GetCacheMissKeysRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetCacheMissKeys) {
rq_.set_connection_id(connection_id);
}
~PurgeCacheRequest() = default;
~GetCacheMissKeysRequest() override = default;
};
/// \brief Request to destroy a cache
@ -235,7 +247,7 @@ class DestroyCacheRequest : public BaseRequest {
explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) {
rq_.set_connection_id(connection_id);
}
~DestroyCacheRequest() = default;
~DestroyCacheRequest() override = default;
};
/// \brief Obtain the statistics of the current connection
@ -247,7 +259,7 @@ class GetStatRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
}
~GetStatRequest() = default;
~GetStatRequest() override = default;
/// \brief Override base function to process the result.
Status PostReply() override;
@ -269,7 +281,7 @@ class CacheSchemaRequest : public BaseRequest {
explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) {
rq_.set_connection_id(connection_id);
}
~CacheSchemaRequest() = default;
~CacheSchemaRequest() override = default;
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
};
@ -281,7 +293,7 @@ class FetchSchemaRequest : public BaseRequest {
explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) {
rq_.set_connection_id(connection_id);
}
~FetchSchemaRequest() = default;
~FetchSchemaRequest() override = default;
Status PostReply() override;
@ -300,7 +312,7 @@ class BuildPhaseDoneRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie_);
}
~BuildPhaseDoneRequest() = default;
~BuildPhaseDoneRequest() override = default;
private:
std::string cookie_;
@ -313,7 +325,7 @@ class DropSessionRequest : public BaseRequest {
explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) {
rq_.mutable_connection_info()->operator=(cinfo);
}
~DropSessionRequest() = default;
~DropSessionRequest() override = default;
};
class GenerateSessionIdRequest : public BaseRequest {
@ -325,11 +337,36 @@ class GenerateSessionIdRequest : public BaseRequest {
rq_.set_connection_id(0);
}
~GenerateSessionIdRequest() = default;
~GenerateSessionIdRequest() override = default;
session_id_type GetSessionId() { return atoi(reply_.result().data()); }
};
class ListSessionsRequest : public BaseRequest {
public:
friend class CacheServer;
ListSessionsRequest() : BaseRequest(RequestType::kListSessions) {
// This request is not specific to any cache or session
rq_.set_connection_id(0);
}
~ListSessionsRequest() override = default;
/// \brief Override base function to process the result.
Status PostReply() override;
void GetSessionCacheInfo(std::vector<SessionCacheInfo> *info) {
if (info != nullptr) {
(*info) = session_info_list_;
}
}
std::vector<SessionCacheInfo> GetSessionCacheInfo() { return session_info_list_; }
private:
std::vector<SessionCacheInfo> session_info_list_;
};
class AllocateSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
@ -338,7 +375,7 @@ class AllocateSharedBlockRequest : public BaseRequest {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(requestedSz));
}
~AllocateSharedBlockRequest() = default;
~AllocateSharedBlockRequest() override = default;
/// \brief On return from the server, we get the (relative) address where
/// the free block is located.
@ -349,11 +386,15 @@ class AllocateSharedBlockRequest : public BaseRequest {
}
};
class ShutdownRequest : public BaseRequest {
class ToggleWriteModeRequest : public BaseRequest {
public:
friend class CacheServer;
ShutdownRequest() : BaseRequest(RequestType::kStopService) {}
~ShutdownRequest() = default;
explicit ToggleWriteModeRequest(connection_id_type connection_id, bool on_off)
: BaseRequest(RequestType::kToggleWriteMode) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(on_off ? "on" : "off");
}
~ToggleWriteModeRequest() override = default;
};
} // namespace dataset
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#include <functional>
#include <limits>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/util/bit.h"
@ -107,6 +108,8 @@ Status CacheServer::DoServiceStop() {
// First stop all the threads.
RETURN_IF_NOT_OK(vg_.ServiceStop());
// Clean up all the caches if any.
// Dump a message how much memory we have consumed in total.
MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes.";
UniqueLock lck(&rwLock_);
auto it = all_caches_.begin();
while (it != all_caches_.end()) {
@ -121,7 +124,6 @@ Status CacheServer::DoServiceStop() {
}
CacheService *CacheServer::GetService(connection_id_type id) const {
SharedLock lck(&rwLock_);
auto it = all_caches_.find(id);
if (it != all_caches_.end()) {
return it->second.get();
@ -134,6 +136,16 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
std::string cookie;
auto session_id = rq->connection_info().session_id();
auto crc = rq->connection_info().crc();
// Before allowing the creation, make sure the session had already been created by the user
// Our intention is to add this cache to the active sessions list so leave the list locked during
// this entire function.
UniqueLock lock(&sessions_lock_);
auto session_it = active_sessions_.find(session_id);
if (session_it == active_sessions_.end()) {
RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!");
}
// We concat both numbers to form the internal connection id.
auto connection_id = GetConnectionID(session_id, crc);
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache");
@ -172,10 +184,15 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
// Add the cache into the active session tracking.
// We have already validated that the session exists and that this is a new cache created.
session_it->second.insert(connection_id);
} else {
duplicate = true;
MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
}
off_cookie = fbb.CreateString(cookie);
CreateCacheReplyMsgBuilder bld(fbb);
bld.add_connection_id(connection_id);
@ -183,19 +200,18 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
auto off = bld.Finish();
fbb.Finish(off);
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
// Track the history of all the sessions that we have created so far.
history_sessions_.insert(session_id);
// We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
// treat it as OK.
return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK();
}
Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) {
Status CacheServer::DestroyCache(CacheRequest *rq) {
// We need a strong lock to protect the map.
UniqueLock lck(&rwLock_);
auto id = rq->connection_id();
CacheService *cs = GetService(id);
// it is already destroyed. Ignore it.
if (cs != nullptr) {
auto id = rq->connection_id();
MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
auto n = all_caches_.erase(id);
@ -204,11 +220,34 @@ Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) {
MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
}
}
// Now that this cache is removed, we need to also remove it's connection id from active session tracking
auto session_id = GetSessionID(id);
UniqueLock sess_lck(&sessions_lock_);
auto it = active_sessions_.find(session_id);
if (it == active_sessions_.end()) {
// The session was not found in the active sessions
RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!");
}
auto connection_it = it->second.find(id);
if (connection_it == it->second.end()) {
RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!");
}
// remove that connection id from the set
it->second.erase(connection_it);
MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id;
return Status::OK();
}
inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
@ -236,8 +275,11 @@ inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
return Status::OK();
}
Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
// Ensure we got 3 pieces of data coming in
@ -270,8 +312,11 @@ Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply
return rc;
}
Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
@ -325,8 +370,11 @@ Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheRepl
return Status::OK();
}
inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
@ -338,8 +386,8 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
bld.add_max_row_id(svc_stat.max_);
bld.add_min_row_id(svc_stat.min_);
bld.add_max_row_id(svc_stat.stat_.max_key);
bld.add_min_row_id(svc_stat.stat_.min_key);
bld.add_state(svc_stat.state_);
auto offset = bld.Finish();
fbb.Finish(offset);
@ -348,8 +396,11 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
return Status::OK();
}
inline Status CacheSchema(CacheService *cs, CacheRequest *rq) {
Status CacheServer::CacheSchema(CacheRequest *rq) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
@ -361,8 +412,11 @@ inline Status CacheSchema(CacheService *cs, CacheRequest *rq) {
return Status::OK();
}
inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) {
Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
@ -377,8 +431,11 @@ inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply)
return Status::OK();
}
inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) {
Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
@ -396,15 +453,24 @@ inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) {
return Status::OK();
}
Status CacheServer::PurgeCache(CacheService *cs) {
Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
// If shutdown in progress, ignore the command.
if (global_shutdown_) {
return Status::OK();
}
// it is already purged. Ignore it.
if (cs != nullptr) {
RETURN_IF_NOT_OK(cs->Purge());
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
std::vector<row_id_type> gap;
RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));
flatbuffers::FlatBufferBuilder fbb;
auto off_t = fbb.CreateVector(gap);
TensorRowIdsBuilder bld(fbb);
bld.add_row_id(off_t);
auto off = bld.Finish();
fbb.Finish(off);
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
}
return Status::OK();
}
@ -414,6 +480,72 @@ inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *re
return Status::OK();
}
Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
// First piece of data is the on/off flag
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");
const auto &action = rq->buf_data(0);
bool on_off = false;
if (strcmp(action.data(), "on") == 0) {
on_off = true;
} else if (strcmp(action.data(), "off") == 0) {
on_off = false;
} else {
RETURN_STATUS_UNEXPECTED("Unknown request: " + action);
}
RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off));
}
return Status::OK();
}
Status CacheServer::ListSessions(CacheReply *reply) {
SharedLock lck(&sessions_lock_);
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector;
for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) {
session_id_type current_session_id = it->first;
// Loop over each cache inside this session
if (!it->second.empty()) {
for (auto current_conn_id : it->second) {
CacheService *cs = GetService(current_conn_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions.";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
CacheService::ServiceStat svc_stat;
RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached,
svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key,
svc_stat.stat_.max_key, svc_stat.state_);
auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats);
session_msgs_vector.push_back(current_session_info);
}
}
} else {
// If there is no cache created yet, assign a connection id of 0 along with empty stats
auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0);
auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats);
session_msgs_vector.push_back(current_session_info);
}
}
auto session_msgs = fbb.CreateVector(session_msgs_vector);
ListSessionsMsgBuilder s_builder(fbb);
s_builder.add_sessions(session_msgs);
auto offset = s_builder.Finish();
fbb.Finish(offset);
reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
}
/// \brief This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and send the result back to the client using grpc
/// \return
@ -426,12 +558,6 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
CacheService *cs = nullptr;
// Request comes in roughly two sets. One set is at the cache level with a connection id.
// The other set is working at a high level and without a connection id
if (!rq.has_connection_info()) {
cs = GetService(rq.connection_id());
}
// Except for creating a new session, we expect cs is not null.
switch (cache_req->type_) {
case BaseRequest::RequestType::kCacheRow: {
@ -439,42 +565,42 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
// call the appropriate method.
auto flag = rq.flag();
if (BitTest(flag, kDataIsInSharedMemory)) {
cache_req->rc_ = FastCacheRow(cs, &rq, &reply);
cache_req->rc_ = FastCacheRow(&rq, &reply);
} else {
cache_req->rc_ = CacheRow(cs, &rq, &reply);
cache_req->rc_ = CacheRow(&rq, &reply);
}
break;
}
case BaseRequest::RequestType::kBatchFetchRows: {
cache_req->rc_ = BatchFetchRows(cs, &rq, &reply);
cache_req->rc_ = BatchFetchRows(&rq, &reply);
break;
}
case BaseRequest::RequestType::kCreateCache: {
cache_req->rc_ = CreateService(&rq, &reply);
break;
}
case BaseRequest::RequestType::kPurgeCache: {
cache_req->rc_ = PurgeCache(cs);
case BaseRequest::RequestType::kGetCacheMissKeys: {
cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
break;
}
case BaseRequest::RequestType::kDestroyCache: {
cache_req->rc_ = DestroyCache(cs, &rq);
cache_req->rc_ = DestroyCache(&rq);
break;
}
case BaseRequest::RequestType::kGetStat: {
cache_req->rc_ = GetStat(cs, &rq, &reply);
cache_req->rc_ = GetStat(&rq, &reply);
break;
}
case BaseRequest::RequestType::kCacheSchema: {
cache_req->rc_ = CacheSchema(cs, &rq);
cache_req->rc_ = CacheSchema(&rq);
break;
}
case BaseRequest::RequestType::kFetchSchema: {
cache_req->rc_ = FetchSchema(cs, &rq, &reply);
cache_req->rc_ = FetchSchema(&rq, &reply);
break;
}
case BaseRequest::RequestType::kBuildPhaseDone: {
cache_req->rc_ = BuildPhaseDone(cs, &rq);
cache_req->rc_ = BuildPhaseDone(&rq);
break;
}
case BaseRequest::RequestType::kDropSession: {
@ -498,6 +624,18 @@ Status CacheServer::ServerRequest(int32_t worker_id) {
cache_req->rc_ = GlobalShutdown();
break;
}
case BaseRequest::RequestType::kHeartBeat: {
cache_req->rc_ = Status::OK();
break;
}
case BaseRequest::RequestType::kToggleWriteMode: {
cache_req->rc_ = ToggleWriteMode(&rq);
break;
}
case BaseRequest::RequestType::kListSessions: {
cache_req->rc_ = ListSessions(&reply);
break;
}
default:
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
@ -526,18 +664,32 @@ session_id_type CacheServer::GetSessionID(connection_id_type connection_id) cons
}
CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
int32_t shared_meory_sz_in_gb)
int32_t shared_meory_sz_in_gb, float memory_cap_ratio)
: top_(spill_path),
num_workers_(num_workers),
port_(port),
shared_memory_sz_in_gb_(shared_meory_sz_in_gb),
global_shutdown_(false) {}
global_shutdown_(false),
memory_cap_ratio_(memory_cap_ratio),
cur_mem_usage_(0) {
memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_;
}
Status CacheServer::Run() {
RETURN_IF_NOT_OK(ServiceStart());
Status CacheServer::Run(int msg_qid) {
Status rc = ServiceStart();
// If there is a message que, return the status now before we call join_all which will never return
if (msg_qid != -1) {
SharedMessage msg(msg_qid);
RETURN_IF_NOT_OK(msg.SendStatus(rc));
}
if (rc.IsError()) {
return rc;
}
// This is called by the main function and we shouldn't exit. Otherwise the main thread
// will just shutdown. So we will call some function that never return unless error.
// One good case will be simply to wait for all threads to return.
// note that after we have sent the initial status using the msg_qid, parent process will exit and
// remove it. So we can't use it again.
RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
return Status::OK();
}
@ -567,32 +719,51 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
Status CacheServer::DestroySession(CacheRequest *rq) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
auto drop_session_id = rq->connection_info().session_id();
UniqueLock lck(&rwLock_);
for (auto &cs : all_caches_) {
auto connection_id = cs.first;
auto session_id = GetSessionID(connection_id);
// We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock.
// So we will just manually do it.
if (session_id == drop_session_id) {
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
auto n = all_caches_.erase(connection_id);
MS_LOG(INFO) << "Destroy " << n << " copies of cache with id " << connection_id;
UniqueLock lck(&sessions_lock_);
// First validate that this session exists
auto it = active_sessions_.find(drop_session_id);
if (it == active_sessions_.end()) {
RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!");
}
// Iterate over the set of connection id's for this session that we're dropping and erase each one.
{
UniqueLock rwlck(&rwLock_);
for (auto drop_connection_id : it->second) {
auto cache_drop_it = all_caches_.find(drop_connection_id);
if (cache_drop_it == all_caches_.end()) {
RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry.");
}
all_caches_.erase(cache_drop_it);
MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id;
// **Do not bother to remove the cache connection id from the active session because we will soon remove the
// entire session.
}
}
// Finally remove the session itself
active_sessions_.erase(it);
MS_LOG(INFO) << "Session destroyed with id " << drop_session_id;
return Status::OK();
}
session_id_type CacheServer::GenerateSessionID() const {
SharedLock lock(&rwLock_);
session_id_type CacheServer::GenerateSessionID() {
UniqueLock lock(&sessions_lock_);
auto mt = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
session_id_type session_id;
bool duplicate = false;
do {
session_id = distribution(mt);
auto it = history_sessions_.find(session_id);
duplicate = (it != history_sessions_.end());
auto it = active_sessions_.find(session_id);
duplicate = (it != active_sessions_.end());
} while (duplicate);
// Add this session to our tracking of active sessions with initialized empty set of connections.
active_sessions_[session_id] = std::set<connection_id_type>();
return session_id;
}
@ -637,19 +808,59 @@ Status CacheServer::GlobalShutdown() {
vg_.interrupt_all();
// The next thing to do drop all the caches.
UniqueLock lck(&rwLock_);
for (auto &it : all_caches_) {
auto id = it.first;
for (auto it = all_caches_.begin(); it != all_caches_.end();) {
auto id = it->first;
MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
// Wait for all outstanding work to be finished.
auto &cs = it.second;
auto &cs = it->second;
UniqueLock cs_lock(&cs->rw_lock_);
// std::map will invoke the destructor of CacheService. So we don't need to do anything here.
(void)all_caches_.erase(id);
it = all_caches_.erase(it);
}
}
return Status::OK();
}
int64_t CacheServer::GetTotalSystemMemory() {
auto pages = sysconf(_SC_PHYS_PAGES);
auto page_size = sysconf(_SC_PAGE_SIZE);
auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size);
MS_LOG(INFO) << "Total physical RAM in bytes: " << total;
return total;
}
Status CacheServer::Builder::IpcResourceCleanup() {
Status rc;
SharedMemory::shm_key_t shm_key;
auto unix_socket = PortToUnixSocketPath(port_);
rc = PortToFtok(port_, &shm_key);
// We are expecting the unix path doesn't exist.
if (rc.IsError()) {
return Status::OK();
}
// Attach to the shared memory which we expect don't exist
SharedMemory mem(shm_key);
rc = mem.Attach();
if (rc.IsError()) {
return Status::OK();
}
int32_t num_attached;
RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached));
if (num_attached == 0) {
// Stale shared memory from last time.
// Remove both the memory and the socket path
RETURN_IF_NOT_OK(mem.Destroy());
Path p(unix_socket);
(void)p.Remove();
} else {
// Server is already up.
std::string errMsg = "Cache server is already up and running";
// We return a duplicate error. The main() will intercept
// and output a proper message
return Status(StatusCode::kDuplicateKey, errMsg);
}
return Status::OK();
}
Status CacheServer::Builder::SanityCheck() {
if (shared_memory_sz_in_gb_ <= 0) {
RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive");
@ -673,6 +884,12 @@ Status CacheServer::Builder::SanityCheck() {
RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString());
}
}
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) {
RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1");
}
// Check if the shared memory.
RETURN_IF_NOT_OK(IpcResourceCleanup());
return Status::OK();
}
} // namespace dataset

View File

@ -17,6 +17,8 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_
#include <string.h>
#include <unistd.h>
#include <algorithm>
#include <atomic>
#include <memory>
@ -47,15 +49,16 @@ class CacheServer : public Service {
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
class Builder {
public:
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {}
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {}
~Builder() = default;
/// \brief Getter functions
const std::string &getTop() const { return top_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPort() const { return port_; }
int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }
const std::string &GetTop() const { return top_; }
int32_t GetNumWorkers() const { return num_workers_; }
int32_t GetPort() const { return port_; }
int32_t GetSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }
float GetMemoryCapRatio() const { return memory_cap_ratio_; }
Builder &SetRootDirectory(std::string root) {
top_ = std::move(root);
@ -73,15 +76,20 @@ class CacheServer : public Service {
shared_memory_sz_in_gb_ = sz;
return *this;
}
Builder &SetMemoryCapRatio(float ratio) {
memory_cap_ratio_ = ratio;
return *this;
}
Status SanityCheck();
void Print(std::ostream &out) const {
out << "Summary of the cache server configuration\n"
<< "Spill directory: " << getTop() << "\n"
<< "Number of parallel workers: " << getNumWorkers() << "\n"
<< "Tcp/ip port: " << getPort() << "\n"
<< "Shared memory size (in GB): " << getSharedMemorySzInGb();
<< "Spill directory: " << GetTop() << "\n"
<< "Number of parallel workers: " << GetNumWorkers() << "\n"
<< "Tcp/ip port: " << GetPort() << "\n"
<< "Shared memory size (in GB): " << GetSharedMemorySzInGb() << "\n"
<< "Memory cap ratio: " << GetMemoryCapRatio();
}
friend std::ostream &operator<<(std::ostream &out, const Builder &bld) {
@ -93,7 +101,8 @@ class CacheServer : public Service {
RETURN_IF_NOT_OK(SanityCheck());
// We need to bring up the Task Manager by bringing up the Services singleton.
RETURN_IF_NOT_OK(Services::CreateInstance());
RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_));
RETURN_IF_NOT_OK(
CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_, memory_cap_ratio_));
return Status::OK();
}
@ -102,20 +111,27 @@ class CacheServer : public Service {
int32_t num_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
float memory_cap_ratio_;
/// \brief Sanity checks on the shared memory.
/// \return Status object
Status IpcResourceCleanup();
};
CacheServer(const CacheServer &) = delete;
CacheServer &operator=(const CacheServer &) = delete;
CacheServer(CacheServer &&) = delete;
CacheServer &operator=(CacheServer &) = delete;
Status DoServiceStart() override;
Status DoServiceStop() override;
~CacheServer() { (void)ServiceStop(); }
~CacheServer() override { (void)ServiceStop(); }
static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port,
int32_t shared_memory_sz) {
int32_t shared_memory_sz, float memory_cap_ratio) {
std::call_once(init_instance_flag_, [&]() -> Status {
auto &svcManager = Services::GetInstance();
RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz));
auto &SvcManager = Services::GetInstance();
RETURN_IF_NOT_OK(
SvcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz, memory_cap_ratio));
return Status::OK();
});
return Status::OK();
@ -133,7 +149,7 @@ class CacheServer : public Service {
}
/// \\brief Kick off server threads. Never return unless error out.
Status Run();
Status Run(SharedMessage::queue_id_t msg_qid);
/// \brief Get a free tag
/// \param q[in] pointer to a pointer to a CacheServerRequest
@ -145,13 +161,35 @@ class CacheServer : public Service {
/// \return Status object
static Status ReturnRequestTag(CacheServerRequest *p);
/// \brief This returns the size (in bytes) of the physical RAM on the machine.
/// \return the size (in bytes) of the physical RAM on the machine.
static int64_t GetTotalSystemMemory();
/// \brief Internally this is how much we will try to use without exceeding the limit
/// \return Internal cap maximum
int64_t GetAvailableSystemMemory() { return memory_cap_; }
/// \brief Find out the current memory usage
int64_t GetMemoryUsage() { return cur_mem_usage_; }
/// \brief This updates our current memory usage.
enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 };
void UpdateMemoryUsage(int64_t sz, MemUsageOp op) {
if (op == MemUsageOp::kAllocate) {
cur_mem_usage_ += sz;
} else {
cur_mem_usage_ -= sz;
}
}
private:
static std::once_flag init_instance_flag_;
static CacheServer *instance_;
mutable RWLock rwLock_;
mutable RWLock sessions_lock_;
std::string top_;
cache_index all_caches_;
std::set<session_id_type> history_sessions_;
std::map<session_id_type, std::set<connection_id_type>> active_sessions_;
std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
@ -162,11 +200,15 @@ class CacheServer : public Service {
int32_t port_;
int32_t shared_memory_sz_in_gb_;
std::atomic<bool> global_shutdown_;
float memory_cap_ratio_;
int64_t memory_cap_;
std::atomic<int64_t> cur_mem_usage_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
/// \param num_workers Number of threads for handling requests.
explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb);
explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb,
float memory_cap_ratio);
/// \brief Locate a cache service from connection id.
/// \return Pointer to cache service. Null if not found
@ -179,11 +221,9 @@ class CacheServer : public Service {
Status CreateService(CacheRequest *rq, CacheReply *reply);
/// \brief Destroy a cache service
/// \param cs
/// \param rq
/// \return
Status DestroyCache(CacheService *cs, CacheRequest *rq);
Status PurgeCache(CacheService *cs);
/// \return Status object
Status DestroyCache(CacheRequest *rq);
/// \brief Entry point for all internal server threads.
Status ServerRequest(int32_t worker_id);
@ -207,7 +247,7 @@ class CacheServer : public Service {
/// \brief Generate a session ID for the client
/// \return Session ID
session_id_type GenerateSessionID() const;
session_id_type GenerateSessionID();
/// \brief Handle kAllocateSharedBlock request
/// \param rq CacheRequest
@ -220,20 +260,55 @@ class CacheServer : public Service {
/// \return Status object
Status FreeSharedMemory(CacheRequest *rq);
/// \brief Handle kFastCacheRow request
/// \brief Handle CacheRow request
/// \note There are two different implementation depends if shared memory is used for transportation.
/// \return Status object
Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply);
Status FastCacheRow(CacheRequest *rq, CacheReply *reply);
Status CacheRow(CacheRequest *rq, CacheReply *reply);
/// \brief Internal function to do row batch fetch
/// \param cs CacheService
/// \param rq Request
/// \param reply Reply
/// \return
Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply);
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
/// \brief Internal function to get statistics
/// \param rq
/// \param reply
/// \return Status object
Status GetStat(CacheRequest *rq, CacheReply *reply);
/// \brief Cache a schema request
/// \param rq
/// \return Status object
Status CacheSchema(CacheRequest *rq);
/// \brief Fetch a schema request
/// \param rq
/// \param reply
/// \return Status object
Status FetchSchema(CacheRequest *rq, CacheReply *reply);
/// \brief Mark Build phase done (for non-mappable case)
/// \param rq
/// \return Status object
Status BuildPhaseDone(CacheRequest *rq);
/// \brief A proper shutdown of the server
/// \return Status object
Status GlobalShutdown();
/// \brief Find keys that will be cache miss
/// \return Status object
Status GetCacheMissKeys(CacheRequest *rq, CacheReply *reply);
/// \brief Toggle write mode for a service
Status ToggleWriteMode(CacheRequest *rq);
/// \brief List the sessions and their caches
/// \param reply
/// \return Status object
Status ListSessions(CacheReply *reply);
};
} // namespace dataset
} // namespace mindspore

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/slice.h"
namespace mindspore {
@ -22,42 +23,62 @@ CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool genera
: root_(root),
cache_mem_sz_(mem_sz),
cp_(nullptr),
map_(nullptr),
next_id_(0),
generate_id_(generate_id),
schema_key_(-1),
st_(generate_id ? State::kBuildPhase : State::kNone) {}
st_(generate_id ? State::kBuildPhase : State::kNone),
cur_mem_usage_(0),
cur_disk_usage_(0) {}
CacheService::~CacheService() { (void)ServiceStop(); }
bool CacheService::UseArena() {
// If fixed size, use Arena instead of the pool from global context.
return (cache_mem_sz_ > 0);
}
Status CacheService::DoServiceStart() {
std::shared_ptr<MemoryPool> mp_;
CacheServer &cs = CacheServer::GetInstance();
if (UseArena()) {
auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L;
if (cache_mem_sz_ > avail_mem) {
// Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway.
MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem
<< " MB";
}
// Create a fixed size arena based on the parameter.
std::shared_ptr<Arena> arena;
RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_));
mp_ = std::move(arena);
// update the global usage only.
cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate);
} else {
// Unlimited size. Simply use a system pool. Another choice is CircularPool.
mp_ = std::make_shared<SystemPool>();
}
// Put together a CachePool for backing up the Tensor
cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), root_);
cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), UseArena(), root_);
RETURN_IF_NOT_OK(cp_->ServiceStart());
// Set up the B+ tree as well. But use the system pool instead.
map_ = std::make_shared<row_map>();
// Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name.
cookie_ = cp_->MyName();
return Status::OK();
}
Status CacheService::DoServiceStop() {
if (cp_ != nullptr) {
RETURN_IF_NOT_OK(cp_->ServiceStop());
}
CacheServer &cs = CacheServer::GetInstance();
if (UseArena()) {
cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree);
} else {
MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage()
<< " bytes.";
cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree);
}
return Status::OK();
}
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
@ -66,6 +87,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
if (st_ == State::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on.
return Status(StatusCode::kOutOfMemory);
}
try {
// The first buffer is a flatbuffer which describes the rest of the buffers follow
auto fb = buf.front();
@ -86,6 +112,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
*row_id_generated = msg->row_id();
}
auto size_of_this = msg->size_of_this();
size_t total_sz = size_of_this;
auto column_hdr = msg->column();
// Number of tensor buffer should match the number of columns plus one.
if (buf.size() != column_hdr->size() + 1) {
@ -99,16 +126,28 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
all_data.emplace_back(fb, size_of_this);
for (auto i = 0; i < column_hdr->size(); ++i) {
all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
total_sz += msg->data_sz()->Get(i);
}
// Now we cache the flat buffer.
CachePool::key_type key;
RETURN_IF_NOT_OK(cp_->Insert(all_data, &key));
Status rc = map_->DoInsert(*row_id_generated, key);
// Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it.
// Otherwise, we check how much (globally) how much we use and may simply spill to disk
// directly.
CacheServer &cs = CacheServer::GetInstance();
bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory();
Status rc = cp_->Insert(*row_id_generated, all_data, write_to_disk_directly);
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
RETURN_IF_NOT_OK(rc);
}
// All good, then update the memory usage local and global (if not using arena)
if (write_to_disk_directly) {
cur_disk_usage_ += total_sz;
} else {
cur_mem_usage_ += total_sz;
if (!UseArena()) {
cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate);
}
}
return Status::OK();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
@ -123,6 +162,11 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
if (st_ == State::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
// return out of memory from now on.
return Status(StatusCode::kOutOfMemory);
}
try {
// If we don't need to generate id, we need to find it from the buffer.
if (generate_id_) {
@ -139,20 +183,33 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
}
*row_id_generated = msg->row_id();
}
// Now we cache the flat buffer.
CachePool::key_type key;
RETURN_IF_NOT_OK(cp_->Insert({src}, &key));
Status rc = map_->DoInsert(*row_id_generated, key);
// Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it.
// Otherwise, we check how much (globally) how much we use and may simply spill to disk
// directly.
auto total_sz = src.GetSize();
CacheServer &cs = CacheServer::GetInstance();
bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory();
Status rc = cp_->Insert(*row_id_generated, {src}, write_to_disk_directly);
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
RETURN_IF_NOT_OK(rc);
}
// All good, then update the memory usage local and global (if not using arena)
if (write_to_disk_directly) {
cur_disk_usage_ += total_sz;
} else {
cur_mem_usage_ += total_sz;
if (!UseArena()) {
cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate);
}
}
return Status::OK();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
// Then show any custom derived-internal stuff
out << "\nCache memory size: " << cs.cache_mem_sz_;
@ -164,34 +221,29 @@ std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
}
return out;
}
Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); }
Status CacheService::Purge() {
// First we must lock exclusively. No one else can cache/restore anything.
UniqueLock rw(&rw_lock_);
RETURN_IF_NOT_OK(cp_->ServiceStop());
auto new_map = std::make_shared<row_map>();
map_.reset();
map_ = std::move(new_map);
next_id_ = 0;
RETURN_IF_NOT_OK(cp_->ServiceStart());
Status CacheService::FindKeysMiss(std::vector<row_id_type> *out) {
RETURN_UNEXPECTED_IF_NULL(out);
std::unique_lock<std::mutex> lock(get_key_miss_mux_);
if (key_miss_results_ == nullptr) {
// Just do it once.
key_miss_results_ = std::make_shared<std::vector<row_id_type>>();
auto stat = cp_->GetStat(true);
key_miss_results_->push_back(stat.min_key);
key_miss_results_->push_back(stat.max_key);
key_miss_results_->insert(key_miss_results_->end(), stat.gap.begin(), stat.gap.end());
}
out->insert(out->end(), key_miss_results_->begin(), key_miss_results_->end());
return Status::OK();
}
Status CacheService::GetStat(CacheService::ServiceStat *out) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(out);
if (st_ == State::kNone || st_ == State::kFetchPhase) {
out->stat_ = cp_->GetStat();
out->state_ = static_cast<ServiceStat::state_type>(st_);
auto it = map_->begin();
if (it != map_->end()) {
out->min_ = it.key();
auto end_it = map_->end();
--end_it;
out->max_ = end_it.key();
}
} else {
out->state_ = static_cast<ServiceStat::state_type>(st_);
}
out->stat_ = cp_->GetStat();
out->state_ = static_cast<ServiceStat::state_type>(st_);
return Status::OK();
}
@ -204,19 +256,12 @@ Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vecto
*mem_sz = (num_elements + 1) * sizeof(int64_t);
(*out).reserve(num_elements);
for (auto row_id : v) {
auto r = map_->Search(row_id);
if (r.second) {
auto &it = r.first;
CachePool::key_type key = it.value();
auto sz = cp_->GetSize(key);
if (sz == 0) {
std::string errMsg = "Key not found: ";
errMsg += std::to_string(key);
RETURN_STATUS_UNEXPECTED(errMsg);
}
(*out).emplace_back(key, sz);
auto sz = cp_->GetSize(row_id);
if (sz > 0) {
(*out).emplace_back(row_id, sz);
(*mem_sz) += sz;
} else {
// key not found
(*out).emplace_back(-1, 0);
}
}
@ -252,27 +297,19 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::ve
}
return Status::OK();
}
Status CacheService::CacheSchema(const void *buf, int64_t len) {
SharedLock rw(&rw_lock_);
if (st_ == State::kFetchPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
// This is a special request and we need to remember where we store it.
UniqueLock rw(&rw_lock_);
// In case we are calling the same function from multiple threads, only
// the first one is considered. Rest is ignored.
CachePool::key_type cur_key = schema_key_;
CachePool::key_type key;
if (cur_key < 0) {
RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key));
auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key);
MS_LOG(DEBUG) << "Caching Schema. Result = " << result;
if (schema_.empty()) {
schema_.assign(static_cast<const char *>(buf), len);
} else {
MS_LOG(DEBUG) << "Caching Schema already done";
}
return Status::OK();
}
Status CacheService::FetchSchema(std::string *out) const {
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
@ -283,32 +320,44 @@ Status CacheService::FetchSchema(std::string *out) const {
// We are going to use std::string to allocate and hold the result which will be eventually
// 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
// to minimize memory copy.
std::string mem;
if (schema_key_ >= 0) {
auto len = cp_->GetSize(schema_key_);
try {
mem.resize(len);
CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error");
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
auto slice = WritableSlice(mem.data(), len);
RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
std::string mem(schema_);
if (!mem.empty()) {
*out = std::move(mem);
} else {
return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached");
}
return Status::OK();
}
Status CacheService::BuildPhaseDone() {
if (HasBuildPhase()) {
// Exclusive lock to switch phase
UniqueLock rw(&rw_lock_);
st_ = State::kFetchPhase;
cp_->SetLocking(false);
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
}
}
Status CacheService::ToggleWriteMode(bool on_off) {
UniqueLock rw(&rw_lock_);
if (HasBuildPhase()) {
RETURN_STATUS_UNEXPECTED("Not applicable to non-mappable dataset");
} else {
// If we stop accepting write request, we turn off locking for the
// underlying B+ tree. All future write request we will return kOutOfMemory.
if (st_ == State::kNone && !on_off) {
st_ = State::kNoLocking;
cp_->SetLocking(on_off);
MS_LOG(WARNING) << "Locking mode is switched off.";
} else if (st_ == State::kNoLocking && on_off) {
st_ = State::kNone;
cp_->SetLocking(on_off);
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -20,6 +20,7 @@
#include <algorithm>
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <type_traits>
#include <utility>
@ -44,9 +45,8 @@ using key_size_pair = std::pair<CachePool::key_type, size_t>;
class CacheService : public Service {
public:
friend class CacheServer;
using row_map = BPlusTree<row_id_type, CachePool::key_type>;
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
/// \brief Constructor
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
@ -97,11 +97,9 @@ class CacheService : public Service {
class ServiceStat {
public:
using state_type = std::underlying_type<State>::type;
ServiceStat() : min_(0), max_(0), state_(0) {}
ServiceStat() : state_(0) {}
~ServiceStat() = default;
CachePool::CacheStat stat_{};
row_id_type min_;
row_id_type max_;
state_type state_;
};
/// \brief Statistics for the current service
@ -117,9 +115,9 @@ class CacheService : public Service {
/// \param out A contiguous memory that contains the serialized form of schema.
/// \return Status object
Status FetchSchema(std::string *out) const;
/// \brief Purge the content of a cache
/// \brief Return a set of keys that are definitely cache miss
/// \return Status object
Status Purge();
Status FindKeysMiss(std::vector<row_id_type> *out);
/// \brief Overload the << operator to print a cache service
/// \param out std::ostream
/// \param cs A cache service
@ -136,19 +134,33 @@ class CacheService : public Service {
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
/// \return Status object
Status BuildPhaseDone();
/// \brief Find out the current memory usage
int64_t GetMemoryUsage() { return cur_mem_usage_; }
/// \brief Find out the current disk usage
int64_t GetDiskUsage() { return cur_disk_usage_; }
/// \brief For kToggleWriteMode request
Status ToggleWriteMode(bool on_off);
private:
mutable RWLock rw_lock_;
std::string root_;
uint64_t cache_mem_sz_;
std::shared_ptr<CachePool> cp_;
std::shared_ptr<row_map> map_;
std::atomic<row_id_type> next_id_;
bool generate_id_;
std::atomic<CachePool::key_type> schema_key_;
std::string cookie_;
State st_;
std::string schema_;
// If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it.
// Otherwise we track how much is in memory and how much is on disk (if root_ is not empty).
// We use them to control when we should stop caching in memory in the case when there is no
// Arena.
std::atomic<int64_t> cur_mem_usage_;
std::atomic<int64_t> cur_disk_usage_;
// We also cache the result from calling FindKeysMiss because it is expensive. Besides user make
// this request after we hit memory full or disk full. So the result is unlikely to change.
std::mutex get_key_miss_mux_;
std::shared_ptr<std::vector<row_id_type>> key_miss_results_;
/// \brief Private function to generate a row id
/// \return Row id assigned.
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }

View File

@ -92,3 +92,13 @@ table CreateCacheReplyMsg {
connection_id:int64;
cookie:string;
}
table ListSessionMsg {
session_id:uint32;
connection_id:uint64;
stats:ServiceStatMsg;
}
table ListSessionsMsg {
sessions:[ListSessionMsg];
}

View File

@ -53,22 +53,35 @@ CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t row
num_cache_miss_(0),
cache_client_(std::move(cache_client)),
rows_per_buffer_(rows_per_buf),
// We can cause deadlock if this internal Connector size is too small.
keys_miss_(num_workers_, 1, connector_capacity_),
prefetch_size_(cache_client_->getPrefetchSize()) {
prefetch_size_(rows_per_buffer_),
num_prefetchers_(num_workers_) {
// Adjust the prefetch size based on the number of workers.
auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_;
if (prefetch_size_ < prefetch_sz_per_thread) {
prefetch_size_ = prefetch_sz_per_thread;
MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_;
}
io_block_queues_.Init(num_workers, op_connector_size);
prefetch_queues_.Init(num_workers, op_connector_size);
sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size);
prefetch_queues_.Init(num_prefetchers_, op_connector_size);
// We can cause deadlock if this internal Connector size is too small.
keys_miss_ = std::make_unique<Connector<std::vector<row_id_type>>>(num_prefetchers_, 1, connector_capacity_);
}
// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
int64_t buf_cnt = 0;
int64_t wait_cnt = 0;
int64_t prefetch_cnt = 0;
// Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_
// is too small (1 by default) and won't help performance.
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id,
std::vector<row_id_type> &keys) -> Status {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(qList[worker_id]->Add(std::move(blk)));
return Status::OK();
};
// Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
// to the WorkerEntry.
do {
@ -82,33 +95,54 @@ Status CacheBase::FetchSamplesToWorkers() {
++wait_cnt;
std::vector<row_id_type> keys;
keys.reserve(rows_per_buffer_);
std::vector<row_id_type> prefetch_keys;
prefetch_keys.reserve(prefetch_size_);
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
while (!sampler_buffer->eoe()) {
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
// Send the sampler tensor to other thread for prefetching. We are using shared pointer so it
// won't go out scope until it is really not in use.
RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids));
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++row_cnt_;
if (row_cnt_ % rows_per_buffer_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
prefetch_keys.push_back(*itr);
// Batch enough rows for performance reason.
if (row_cnt_ % prefetch_size_ == 0) {
RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
// Now we tell the WorkerEntry to wait for them to come back. If prefetch_size_ is a multiple
// of rows_per_buffer_, the keys vector will always be empty. But it can be partially filled.
// The only requirement we set up is rows_per_buffer_ is less than or equal to prefetch_size_.
for (auto row_id : prefetch_keys) {
keys.push_back(row_id);
if (keys.size() == rows_per_buffer_) {
RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
keys.clear();
}
}
prefetch_keys.clear();
}
}
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
// Deal with any partial keys left.
if (!prefetch_keys.empty()) {
RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
for (auto row_id : prefetch_keys) {
keys.push_back(row_id);
if (keys.size() == rows_per_buffer_) {
RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
keys.clear();
}
}
}
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
}
// send the eoe
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
RETURN_IF_NOT_OK(prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add(
std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
// If repeat but the not last repeat, wait for reset.
if (!IsLastIteration()) {
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
@ -123,8 +157,6 @@ Status CacheBase::FetchSamplesToWorkers() {
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
// Shutdown threads
std::shared_ptr<Tensor> empty;
RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty)));
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
@ -145,13 +177,6 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
if (blk->eof()) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
} else if (blk->eoe()) {
if (AllowCacheMiss()) {
// This code path is for CacheLookupOp acting as a sampler. If we get a eoe from
// a sampler, send a eoe to physical leaf op as well.
std::vector<row_id_type> eoe;
eoe.push_back(eoe_row_id);
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe));
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
} else {
std::vector<int64_t> keys;
@ -162,22 +187,21 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
}
std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
std::vector<row_id_type> cache_miss;
cache_miss.reserve(keys.size());
for (auto row_id : keys) {
TensorRow row;
// Block until the row shows up in the pool.
RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row));
RETURN_IF_NOT_OK(GetPrefetchRow(row_id, &row));
if (row.empty()) {
cache_miss.push_back(row_id);
if (AllowCacheMiss()) {
++num_cache_miss_;
} else {
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
que->push_back(std::move(row));
}
db->set_tensor_table(std::move(que));
if (AllowCacheMiss()) {
// Because of the way connector works, we push unconditionally even cache_miss can be empty.
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss));
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
buffer_id += num_workers_;
}
@ -189,7 +213,6 @@ Status CacheBase::RegisterResources() {
RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks()));
return Status::OK();
}
@ -208,73 +231,97 @@ Status CacheBase::UpdateColumnMapFromCache() {
return rc;
}
Status CacheBase::Dispatcher() {
TaskManager::FindMe()->Post();
int64_t buf_cnt = 0;
int64_t num_row = 0;
std::vector<row_id_type> keys;
keys.reserve(prefetch_size_);
do {
keys.clear();
std::shared_ptr<Tensor> sample_ids;
RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids));
if (sample_ids == nullptr) {
// A null shared pointer signal times to quit.
// Also signal all prefetchers to quit.
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
break;
}
// Now we distribute the sampler ids to each prefetcher according to the prefetch size.
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++num_row;
if (num_row % prefetch_size_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
}
}
// Send the remaining sample id
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
}
} while (true);
Status CacheBase::GetPrefetchRow(row_id_type row_id, TensorRow *out) {
RETURN_UNEXPECTED_IF_NULL(out);
CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0, "Expect positive row id");
RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, out));
return Status::OK();
}
Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss) {
RETURN_UNEXPECTED_IF_NULL(cache_miss);
std::vector<row_id_type> prefetch_keys;
prefetch_keys.reserve(keys.size());
// Filter out all those keys that unlikely we will find at the server
for (auto row_id : keys) {
if (cache_client_->KeyIsCacheMiss(row_id)) {
// Just put an empty row in the cache.
TensorRow row;
row.setId(row_id);
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
cache_miss->push_back(row_id);
} else {
prefetch_keys.push_back(row_id);
}
}
// Early exit if nothing to fetch
if (prefetch_keys.empty()) {
return Status::OK();
}
// Get the rows from the server
TensorTable ttbl;
Status rc = cache_client_->GetRows(prefetch_keys, &ttbl);
if (rc.IsOk()) {
auto row_it = ttbl.begin();
for (auto row_id : prefetch_keys) {
auto &row = *row_it;
if (row.empty()) {
cache_miss->push_back(row_id);
}
// Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
++row_it;
}
} else {
// In case any thread is waiting for the rows to come back and blocked on a semaphore,
// we will put an empty row in the local cache.
for (auto row_id : prefetch_keys) {
TensorRow row;
row.setId(row_id);
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
cache_miss->push_back(row_id);
}
}
return rc;
}
Status CacheBase::Prefetcher(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::vector<row_id_type> prefetch_keys;
prefetch_keys.reserve(prefetch_size_);
std::vector<row_id_type> cache_miss;
cache_miss.reserve(prefetch_size_);
do {
prefetch_keys.clear();
cache_miss.clear();
std::unique_ptr<IOBlock> blk;
RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
if (prefetch_keys.empty()) {
// Empty keys mean time to quit.
break;
}
TensorTable ttbl;
RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
auto row_it = ttbl.begin();
for (auto row_id : prefetch_keys) {
auto &row = *row_it;
if (row.empty()) {
if (AllowCacheMiss()) {
++num_cache_miss_;
} else {
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
RETURN_STATUS_UNEXPECTED(errMsg);
CHECK_FAIL_RETURN_UNEXPECTED(!blk->eof(), "Expect eoe or a regular io block");
if (!blk->eoe()) {
RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
Status rc;
const int32_t max_retries = 5;
int32_t retry_count = 0;
do {
rc = PrefetchRows(prefetch_keys, &cache_miss);
if (rc.IsNetWorkError() && retry_count < max_retries) {
// If we get some network error, we will attempt some retries
retry_count++;
} else if (rc.IsError()) {
return rc;
}
} while (rc.IsNetWorkError());
} else {
if (AllowCacheMiss()) {
// This code path is for CacheLookupOp acting as a sampler. If we get a eoe from
// a sampler, send a eoe to physical leaf op as well.
cache_miss.push_back(eoe_row_id);
}
// Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
++row_it;
}
if (AllowCacheMiss()) {
// Because of the way connector works, we push unconditionally even cache_miss can be empty.
RETURN_IF_NOT_OK(keys_miss_->Push(worker_id, cache_miss));
}
} while (true);
return Status::OK();

View File

@ -22,6 +22,7 @@
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/connector.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
@ -90,8 +91,7 @@ class CacheBase : public ParallelOp {
std::shared_ptr<CacheClient> cache_client_;
WaitPost epoch_sync_;
int32_t rows_per_buffer_;
Connector<std::vector<row_id_type>> keys_miss_;
QueueMap<row_id_type, TensorRow> prefetch_;
std::unique_ptr<Connector<std::vector<row_id_type>>> keys_miss_;
/// \brief Common function to register resources for interrupt
/// \note Derived should override this function for extra resources to be registered
@ -111,13 +111,16 @@ class CacheBase : public ParallelOp {
constexpr static int32_t connector_capacity_ = 1024;
int32_t prefetch_size_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
int32_t num_prefetchers_;
QueueList<std::unique_ptr<IOBlock>> prefetch_queues_;
std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_;
QueueMap<row_id_type, TensorRow> prefetch_;
Status Dispatcher();
/// \brief Prefetcher. It prefetch the rows from cache server
/// \return Status object.
Status Prefetcher(int32_t worker_id);
/// \brief Functions used by prefetcher and WorkerEntry
Status PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss);
Status GetPrefetchRow(row_id_type row_id, TensorRow *out);
};
} // namespace dataset
} // namespace mindspore

View File

@ -87,10 +87,10 @@ Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
std::vector<row_id_type> cache_miss;
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss));
// Ignore the case we have no cache miss, we can't return empty samples.
while (cache_miss.empty()) {
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss));
}
// Special code for eoe
if (cache_miss.at(0) == eoe_row_id) {

View File

@ -25,6 +25,7 @@
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/system_pool.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
@ -48,7 +49,8 @@ CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
: ParallelOp(numWorkers, opConnectorSize, sampler),
num_cleaners_(numCleaners),
cache_client_(std::move(cache_client)) {}
cache_client_(std::move(cache_client)),
cache_missing_rows_(true) {}
Status CacheMergeOp::operator()() {
// A queue of row id to let cleaner send cache miss rows to the cache server
@ -129,17 +131,19 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
std::string errMsg = "Expect positive row id: " + std::to_string(row_id);
RETURN_STATUS_UNEXPECTED(errMsg);
}
// Technically number of this row shows up in the cache miss stream is equal to the number
// of P() call. However the cleaner wants it too. So we need an extra copy.
TensorRowCacheRequest *rq;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) {
// We will send the request async. But any error we most
// likely ignore and continue.
Status rc;
rc = rq->AsyncSendCacheRequest(cache_client_, row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
if (cache_missing_rows_) {
// Technically number of this row shows up in the cache miss stream is equal to the number
// of P() call. However the cleaner wants it too. So we need an extra copy.
TensorRowCacheRequest *rq;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) {
// We will send the request async. But any error we most
// likely ignore and continue.
Status rc;
rc = rq->AsyncSendCacheRequest(cache_client_, row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
}
}
}
RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row)));
@ -168,13 +172,18 @@ Status CacheMergeOp::Cleaner() {
Status rc = rq->CheckCacheResult();
if (rc.IsError()) {
// If interrupt, time to quit.
if (rc.get_code() == StatusCode::kInterrupted) {
if (rc.IsInterrupted()) {
return Status::OK();
} else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
// The server is hitting some limit and we will turn off caching from now on.
cache_missing_rows_ = false;
cache_client_->ServerRunningOutOfResources();
} else {
MS_LOG(INFO) << "Cache row not successful: " << rc.ToString();
// Bad rc should not bring down the pipeline. We will simply continue and
// change the state back to empty. We don't need a CAS from CLEAN back to EMPTY.
rq->SetState(TensorRowCacheRequest::State::kEmpty);
}
MS_LOG(INFO) << "Cache row not successful: " << rc.ToString();
// Bad rc should not bring down the pipeline. We will simply continue and
// change the state back to empty. We don't need a CAS from CLEAN back to EMPTY.
rq->SetState(TensorRowCacheRequest::State::kEmpty);
}
}
return Status::OK();
@ -253,7 +262,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
// If we are in a repeat path, send the eoe up.
// Otherwise ignore it.
if (op_total_repeats_ > 1) {
if (op_total_repeats_ != 1) {
return DatasetOp::EoeReceived(worker_id);
}
return Status::OK();
@ -281,7 +290,7 @@ Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheReque
*out = it->second.GetMutablePointer();
} else {
// We will create a new one.
auto alloc = Services::GetAllocator<TensorRowCacheRequest>();
auto alloc = SystemPool::GetAllocator<TensorRowCacheRequest>();
auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc));
if (r.second) {
auto &mem = r.first->second;

View File

@ -202,6 +202,7 @@ class CacheMergeOp : public ParallelOp {
std::unique_ptr<Queue<row_id_type>> io_que_;
std::shared_ptr<CacheClient> cache_client_;
int32_t num_cleaners_;
std::atomic<bool> cache_missing_rows_;
/// \brief Locate the cache request from the io_request_ map
/// \param row_id

View File

@ -16,6 +16,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/constants.h"
@ -64,7 +65,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
// Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
: CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)),
num_guys_in_(0),
phase_(Phase::kBuildPhase) {}
@ -174,7 +175,7 @@ Status CacheOp::WorkerEntry(int32_t worker_id) {
Status CacheOp::RegisterResources() {
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks()));
return Status::OK();
}

View File

@ -20,6 +20,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
@ -188,5 +189,11 @@ Status ConcatOp::ComputeColMap() {
}
return Status::OK();
}
// Visitor pre-accept method for NodePass
Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -105,6 +105,12 @@ class ConcatOp : public PipelineOp {
// @return - Status
Status ComputeColMap() override;
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
private:
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);

View File

@ -243,6 +243,7 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const {
out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_
<< "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_;
if (sampler_) {
out << "\nSampler:\n";
sampler_->Print(out, show_all);
}
}

View File

@ -268,5 +268,11 @@ Status FilterOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<FilterOp>(), modified);
}
// Visitor pre-accept method for NodePass
Status FilterOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->PreRunOnNode(shared_from_base<FilterOp>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -121,6 +121,12 @@ class FilterOp : public ParallelOp {
// @param show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.

View File

@ -458,6 +458,12 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<MapOp>(), modified);
}
// Visitor pre-accept method for NodePass
Status MapOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->PreRunOnNode(shared_from_base<MapOp>(), modified);
}
Status MapOp::WaitForWorkers() {
// reset num_paused workers to 0
num_workers_paused_ = 0;

View File

@ -177,10 +177,16 @@ class MapOp : public ParallelOp {
// @return the number of threads consuming data from previous op's output Connector.
int32_t num_consumers() const override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor.
/// \param[in] p Pointer to the NodePass to be accepted.
/// \param[out] modified Whether this node visit modified the pipeline.
/// \return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
// Op name getter

View File

@ -52,9 +52,9 @@ Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) {
void ParallelOp::Print(std::ostream &out, bool show_all) const {
// Summary 1-liner print
if (!show_all) {
out << " [workers: " << num_workers_ << "]";
// Call super class printer
DatasetOp::Print(out, show_all);
out << " [workers: " << num_workers_ << "]";
} else {
// Detailed print
DatasetOp::Print(out, show_all);

View File

@ -27,14 +27,14 @@ PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampl
void PipelineOp::Print(std::ostream &out, bool show_all) const {
// Summary 1-liner print
if (!show_all) {
// Call super class printer
DatasetOp::Print(out, show_all);
out << " [workers: ";
if (this->inlined()) {
out << "0 (inlined)]";
} else {
out << "1]"; // Pipeline ops only have 1 worker
}
// Call super class printer
DatasetOp::Print(out, show_all);
} else {
// Detailed print
DatasetOp::Print(out, show_all);

View File

@ -235,6 +235,12 @@ Status ZipOp::EoeReceived(int32_t) {
return Status::OK();
}
// Visitor pre-accept method for NodePass
Status ZipOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->PreRunOnNode(shared_from_base<ZipOp>(), modified);
}
// Visitor accept method for NodePass
Status ZipOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor

View File

@ -104,10 +104,16 @@ class ZipOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor.
/// \param[in] p Pointer to the NodePass to be accepted.
/// \param[out] modified Whether this node visit modified the pipeline.
/// \return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
// Op name getter

View File

@ -26,6 +26,7 @@
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
@ -235,6 +236,7 @@ Status ExecutionTree::PrepareTreePreAction() {
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
MS_LOG(INFO) << "Running pre pass loops.";
pre_actions.push_back(std::make_unique<CacheErrorPass>());
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
#ifndef ENABLE_ANDROID

View File

@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_library(engine-opt OBJECT
pass.cc
post/repeat_pass.cc
pre/cache_error_pass.cc
pre/cache_transform_pass.cc
pre/epoch_injection_pass.cc
pre/removal_pass.cc

View File

@ -23,6 +23,7 @@
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#endif
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
@ -143,40 +144,6 @@ Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#endif
#ifdef ENABLE_PYTHON
Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#endif
Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@ -207,13 +174,6 @@ Status NodePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) {
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#endif
Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@ -239,18 +199,6 @@ Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#endif
Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@ -261,18 +209,6 @@ Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#endif
Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@ -283,12 +219,88 @@ Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#ifndef ENABLE_ANDROID
Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#endif
#ifdef ENABLE_PYTHON
Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -37,18 +37,6 @@ class SkipOp;
class ShuffleOp;
#ifndef ENABLE_ANDROID
class MindRecordOp;
class TFReaderOp;
#endif
#ifdef ENABLE_PYTHON
class FilterOp;
class GeneratorOp;
#endif
class AlbumOp;
class RandomDataOp;
@ -63,10 +51,6 @@ class DeviceQueueOp;
class ImageFolderOp;
#ifndef ENABLE_ANDROID
class CacheOp;
#endif
class MnistOp;
class ManifestOp;
@ -79,20 +63,32 @@ class CocoOp;
class CelebAOp;
#ifndef ENABLE_ANDROID
class CacheMergeOp;
class CacheLookupOp;
#endif
class EpochCtrlOp;
class BuildVocabOp;
class ConcatOp;
#ifndef ENABLE_ANDROID
class MindRecordOp;
class TFReaderOp;
class CacheOp;
class CacheMergeOp;
class CacheLookupOp;
class BuildSentencePieceVocabOp;
#endif
#ifdef ENABLE_PYTHON
class FilterOp;
class GeneratorOp;
#endif
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
@ -168,22 +164,6 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
#endif
#ifdef ENABLE_PYTHON
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
#endif
virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified);
@ -194,10 +174,6 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
#endif
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);
@ -210,32 +186,50 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
#endif
virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
#endif
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
#endif
virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified);
#ifndef ENABLE_ANDROID
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified);
#endif
#ifdef ENABLE_PYTHON
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
#endif
private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);

View File

@ -225,13 +225,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
// Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
// would not have been consumed yet. In that case, we need to set its total repeats for it.
if (cache_lookup_) {
cache_lookup_->set_total_repeats(num_repeats_);
cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
}
// Setting the flag is needed since we didn't call the base class DatasetOp version
if (is_repeated_) {
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
// would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
if (cache_lookup_) {
cache_lookup_->set_total_repeats(num_repeats_);
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
AddToEOEOpStack(std::move(cache_lookup_));
}
}

View File

@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheErrorPass::CacheErrorPass() : is_cached_(false) {}
// Identifies the subtree below this node as being cached
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Turn on the flag that we're under a merge op
is_cached_ = true;
return Status::OK();
}
// Returns an error if ZipOp exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("ZipOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
// Returns an error if MapOp with non-deterministic TensorOps exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
if (is_cached_) {
auto tfuncs = node->TFuncs();
for (size_t i = 0; i < tfuncs.size(); i++) {
if (!tfuncs[i]->Deterministic()) {
RETURN_STATUS_UNEXPECTED(
"MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache.");
}
}
}
return Status::OK();
}
// Returns an error if ConcatOp exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) {
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("ConcatOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
#ifdef ENABLE_PYTHON
// Returns an error if FilterOp exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
if (is_cached_) {
RETURN_STATUS_UNEXPECTED("FilterOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
#endif
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_
#include <memory>
#include <stack>
#include <utility>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
/// \class CacheErrorPass cache_error_pass.h
/// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures.
class CacheErrorPass : public NodePass {
public:
/// \brief Constructor
CacheErrorPass();
/// \brief Destructor
~CacheErrorPass() = default;
/// \brief Identifies the subtree below this node as being cached
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Returns an error if ZipOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;
/// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
/// \brief Returns an error if ConcatOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
#ifdef ENABLE_PYTHON
/// \brief Returns an error if FilterOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
#endif
private:
bool is_cached_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_

View File

@ -155,50 +155,77 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> n
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for AlbumOp under cache.");
}
return Status::OK();
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache.");
}
return Status::OK();
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache.");
}
return Status::OK();
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache.");
}
return Status::OK();
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache.");
}
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for MindRecordOp under cache.");
}
return Status::OK();
}
#endif
#ifdef ENABLE_PYTHON
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for GeneratorOp under cache.");
}
return Status::OK();
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache.");
}
return Status::OK();
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache.");
}
return Status::OK();
}
#endif

View File

@ -40,13 +40,6 @@ Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSe
injection_point_ = nullptr;
return Status::OK();
}
// Temporary code to prevent the injection of epoch control when cache op is present
// Remove this code in cache op phase 2
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}
#endif
Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {

View File

@ -54,13 +54,6 @@ class EpochInjectionPass : public TreePass {
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override;
/// \brief Temporary code to prevent the injection of epoch control when cache op is present.
/// Remove this code in cache op phase 2
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
#endif
/// \brief Register the DeviceQueueOp for further action.

View File

@ -62,6 +62,7 @@ Status RandomApplyOp::Compute(const TensorRow &input, TensorRow *output) {
RandomApplyOp::RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops)
: prob_(prob), gen_(GetSeed()), rand_double_(0, 1) {
compose_ = std::make_unique<ComposeOp>(ops);
is_deterministic_ = false;
}
} // namespace dataset

View File

@ -92,6 +92,7 @@ RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops
} else if (ops_.size() == 1) {
MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time.";
}
is_deterministic_ = false;
}
} // namespace dataset
} // namespace mindspore

View File

@ -44,6 +44,7 @@ RandomAffineOp::RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t
interpolation_ = interpolation;
fill_value_ = fill_value;
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {

View File

@ -36,6 +36,7 @@ RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_f
hue_factor_start_(s_hue_factor),
hue_factor_end_(e_hue_factor) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
Status RandomColorAdjustOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {

View File

@ -19,7 +19,9 @@
namespace mindspore {
namespace dataset {
RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {}
RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {
is_deterministic_ = false;
}
Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) {
IO_CHECK(in, out);

View File

@ -41,6 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
aspect_ub_(aspect_ub),
max_iter_(max_iter) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {

View File

@ -46,6 +46,7 @@ RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_
fill_g_(fill_g),
fill_b_(fill_b) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
Status RandomCropOp::ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image,

View File

@ -33,6 +33,7 @@ class RandomHorizontalFlipOp : public TensorOp {
static const float kDefProbability;
explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) {
is_deterministic_ = false;
rnd_.seed(GetSeed());
}

View File

@ -35,6 +35,7 @@ class RandomHorizontalFlipWithBBoxOp : public TensorOp {
explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
~RandomHorizontalFlipWithBBoxOp() override = default;

View File

@ -29,6 +29,7 @@ const std::vector<uint8_t> RandomPosterizeOp::kBitRange = {4, 8};
RandomPosterizeOp::RandomPosterizeOp(const std::vector<uint8_t> &bit_range)
: PosterizeOp(bit_range[0]), bit_range_(bit_range) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
Status RandomPosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {

View File

@ -35,6 +35,7 @@ class RandomResizeOp : public ResizeOp {
explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) {
random_generator_.seed(GetSeed());
is_deterministic_ = false;
}
~RandomResizeOp() = default;

View File

@ -36,6 +36,7 @@ class RandomResizeWithBBoxOp : public ResizeWithBBoxOp {
static const int32_t kDefTargetWidth;
explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) {
random_generator_.seed(GetSeed());
is_deterministic_ = false;
}
~RandomResizeWithBBoxOp() = default;

View File

@ -46,6 +46,7 @@ RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float c
fill_g_(fill_g),
fill_b_(fill_b) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
// main function call for random rotation : Generate the random degrees

View File

@ -90,6 +90,7 @@ RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &p
if (policy_.empty()) {
MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty.";
}
is_deterministic_ = false;
}
} // namespace dataset

View File

@ -31,6 +31,7 @@ const float RandomSharpnessOp::kDefEndDegree = 1.9;
RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree)
: start_degree_(start_degree), end_degree_(end_degree) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
/// main function call for random sharpness : Generate the random degrees

View File

@ -32,7 +32,10 @@ namespace dataset {
class RandomSolarizeOp : public SolarizeOp {
public:
// Pick a random threshold value to solarize the image with
explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); }
explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
~RandomSolarizeOp() = default;

View File

@ -34,6 +34,7 @@ class RandomVerticalFlipOp : public TensorOp {
explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
~RandomVerticalFlipOp() override = default;

View File

@ -34,6 +34,7 @@ class RandomVerticalFlipWithBBoxOp : public TensorOp {
// @param probability: Probablity of Image flipping, 0.5 by default
explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) {
rnd_.seed(GetSeed());
is_deterministic_ = false;
}
~RandomVerticalFlipWithBBoxOp() override = default;

View File

@ -168,6 +168,10 @@ class TensorOp {
// @return true/false
bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; }
// Returns true oif the TensorOp produces deterministic result.
// @return true/false
bool Deterministic() { return is_deterministic_; }
// Function to determine the number of inputs the TensorOp can take. 0: means undefined.
// @return uint32_t
virtual uint32_t NumInput() { return 1; }
@ -191,6 +195,9 @@ class TensorOp {
virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs);
virtual std::string Name() const = 0;
protected:
bool is_deterministic_{true};
};
} // namespace dataset
} // namespace mindspore

View File

@ -88,21 +88,21 @@ class Allocator {
std::shared_ptr<MemoryPool> pool_;
};
/// \brief It is a wrapper of unique_ptr with a custom Allocator class defined above
template <typename T, typename... Args>
Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, Allocator<T> alloc, size_t n, Args &&... args) {
template <typename T, typename C = std::allocator<T>, typename... Args>
Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc, size_t n, Args &&... args) {
RETURN_UNEXPECTED_IF_NULL(out);
CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive");
try {
T *data = alloc.allocate(n);
if (!std::is_arithmetic<T>::value) {
for (auto i = 0; i < n; i++) {
std::allocator_traits<Allocator<T>>::construct(alloc, &(data[i]), std::forward<Args>(args)...);
std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...);
}
}
auto deleter = [](T *p, Allocator<T> f_alloc, size_t f_n) {
auto deleter = [](T *p, C f_alloc, size_t f_n) {
if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
for (auto i = 0; i < f_n; ++i) {
std::allocator_traits<Allocator<T>>::destroy(f_alloc, &p[i]);
std::allocator_traits<C>::destroy(f_alloc, &p[i]);
}
}
f_alloc.deallocate(p, f_n);
@ -129,7 +129,7 @@ class MemGuard {
MemGuard(const MemGuard &) = delete;
MemGuard &operator=(const MemGuard &) = delete;
// On the other hand, We can support move constructor
MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {}
MemGuard(MemGuard &&lhs) noexcept : n_(lhs.n_), alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)) {}
MemGuard &operator=(MemGuard &&lhs) noexcept {
if (this != &lhs) {
this->deallocate();

View File

@ -37,7 +37,8 @@ struct MemHdr {
ArenaImpl::ArenaImpl(void *ptr, size_t sz) : size_in_bytes_(sz), ptr_(ptr) {
// Divide the memory into blocks. Ignore the last partial block.
uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ;
MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << ".";
MS_LOG(DEBUG) << "Arena memory pool is created. Number of blocks : " << num_blks << ". Block size : " << ARENA_BLK_SZ
<< ".";
tr_.Insert(0, num_blks);
}
@ -233,9 +234,9 @@ std::ostream &operator<<(std::ostream &os, const ArenaImpl &s) {
Status Arena::Init() {
try {
auto sz = size_in_MB_ * 1048576L;
mem_ = std::make_unique<uint8_t[]>(sz);
impl_ = std::make_unique<ArenaImpl>(mem_.get(), sz);
int64_t sz = size_in_MB_ * 1048576L;
RETURN_IF_NOT_OK(mem_.allocate(sz));
impl_ = std::make_unique<ArenaImpl>(mem_.GetMutablePointer(), sz);
} catch (std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}

View File

@ -19,6 +19,7 @@
#include <memory>
#include <mutex>
#include <utility>
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/memory_pool.h"
#include "minddata/dataset/util/treap.h"
@ -140,7 +141,7 @@ class Arena : public MemoryPool {
protected:
mutable std::mutex mux_;
std::unique_ptr<ArenaImpl> impl_;
std::unique_ptr<uint8_t[]> mem_;
MemGuard<uint8_t> mem_;
size_t size_in_MB_;
explicit Arena(size_t val_in_MB = 4096);

View File

@ -131,6 +131,23 @@ class BPlusTree {
tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {}
};
/// \brief Statistics functions
/// \return Return the height of the tree
auto GetHeight() const { return empty() ? 0 : stats_.level_ + 1; }
/// \return Order of the B+ tree
auto GetOrder() const { return traits::kLeafSlots; }
/// \return Number of leaves nodes
auto GetNumLeaves() const { return stats_.leaves_; }
/// \return Number of inner nodes
auto GetNumInnerNodes() const { return stats_.inner_nodes_; }
/// \brief Toggle locking
/// \note Once locking is off. It is user's responsibility to ensure concurrency
void SetLocking(bool on_off) {
UniqueLock lck(&rw_lock_);
acquire_lock_ = on_off;
}
private:
// Abstract class of a node (leaf or inner)
class BaseNode {
@ -288,6 +305,17 @@ class BPlusTree {
key_compare key_less_;
// Stat
tree_stats stats_;
// lock mode
bool acquire_lock_;
void Init() {
typename LeafNode::alloc_type alloc(alloc_);
auto *p = alloc.allocate(1);
root_ = new (p) LeafNode(alloc_);
all_.Prepend(p);
leaf_nodes_.Append(p);
stats_.leaves_++;
}
bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); }
@ -350,11 +378,11 @@ class BPlusTree {
~Iterator();
explicit Iterator(const Iterator &);
Iterator(const Iterator &);
Iterator &operator=(const Iterator &lhs);
explicit Iterator(Iterator &&);
Iterator(Iterator &&) noexcept;
Iterator &operator=(Iterator &&lhs);
@ -399,11 +427,11 @@ class BPlusTree {
ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false)
: cur_(leaf), slot_(slot), locked_(locked) {}
explicit ConstIterator(const ConstIterator &);
ConstIterator(const ConstIterator &);
ConstIterator &operator=(const ConstIterator &lhs);
explicit ConstIterator(ConstIterator &&);
ConstIterator(ConstIterator &&) noexcept;
ConstIterator &operator=(ConstIterator &&lhs);

View File

@ -413,11 +413,16 @@ typename BPlusTree<K, V, A, C, T>::IndexRc BPlusTree<K, V, A, C, T>::Locate(RWLo
}
template <typename K, typename V, typename A, typename C, typename T>
BPlusTree<K, V, A, C, T>::BPlusTree() : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {}
BPlusTree<K, V, A, C, T>::BPlusTree()
: leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) {
Init();
}
template <typename K, typename V, typename A, typename C, typename T>
BPlusTree<K, V, A, C, T>::BPlusTree(const Allocator<V> &alloc)
: alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {}
: alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) {
Init();
}
template <typename K, typename V, typename A, typename C, typename T>
BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept {
@ -446,20 +451,6 @@ BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept {
template <typename K, typename V, typename A, typename C, typename T>
Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<value_type> &&value) {
IndexRc rc;
if (root_ == nullptr) {
UniqueLock lck(&rw_lock_);
// Check again after we get the lock. Other thread may have created the root node already.
if (root_ == nullptr) {
LeafNode *leaf = nullptr;
rc = AllocateLeaf(&leaf);
if (rc != IndexRc::kOk) {
return IndexRc2Status(rc);
}
leaf_nodes_.Append(leaf);
root_ = leaf;
}
// lock will be unlocked when it goes out of scope.
}
bool retry = false;
do {
// Track all the paths to the target and lock each internal node in S.
@ -468,7 +459,7 @@ Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<v
retry = false;
BaseNode *new_child = nullptr;
key_type new_key = key_type();
rc = InsertKeyValue(&InsCB, root_, key, std::move(value), &new_key, &new_child);
rc = InsertKeyValue(acquire_lock_ ? &InsCB : nullptr, root_, key, std::move(value), &new_key, &new_child);
if (rc == IndexRc::kRetry) {
retry = true;
} else if (rc != IndexRc::kOk) {
@ -511,9 +502,12 @@ std::unique_ptr<V> BPlusTree<K, V, A, C, T>::DoUpdate(const key_type &key, std::
if (root_ != nullptr) {
LeafNode *leaf = nullptr;
slot_type slot;
RWLock *myLock = &this->rw_lock_;
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
myLock->LockShared();
RWLock *myLock = nullptr;
if (acquire_lock_) {
myLock = &this->rw_lock_;
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
myLock->LockShared();
}
IndexRc rc = Locate(myLock, true, root_, key, &leaf, &slot);
if (rc == IndexRc::kOk) {
// All locks from the tree to the parent of leaf are all gone. We still have a X lock
@ -521,7 +515,9 @@ std::unique_ptr<V> BPlusTree<K, V, A, C, T>::DoUpdate(const key_type &key, std::
// Swap out the old value and replace it with new value.
std::unique_ptr<value_type> old = std::move(leaf->data_[leaf->slot_dir_[slot]]);
leaf->data_[leaf->slot_dir_[slot]] = std::move(new_value);
leaf->rw_lock_.Unlock();
if (acquire_lock_) {
leaf->rw_lock_.Unlock();
}
return old;
} else {
MS_LOG(DEBUG) << "Key not found. rc = " << static_cast<int>(rc) << ".";

View File

@ -109,7 +109,7 @@ BPlusTree<K, V, A, C, T>::Iterator::Iterator(const BPlusTree<K, V, A, C, T>::Ite
}
template <typename K, typename V, typename A, typename C, typename T>
BPlusTree<K, V, A, C, T>::Iterator::Iterator(BPlusTree<K, V, A, C, T>::Iterator &&lhs) {
BPlusTree<K, V, A, C, T>::Iterator::Iterator(BPlusTree<K, V, A, C, T>::Iterator &&lhs) noexcept {
this->cur_ = lhs.cur_;
this->slot_ = lhs.slot_;
this->locked_ = lhs.locked_;
@ -241,7 +241,7 @@ BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(const BPlusTree<K, V, A,
}
template <typename K, typename V, typename A, typename C, typename T>
BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(BPlusTree<K, V, A, C, T>::ConstIterator &&lhs) {
BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(BPlusTree<K, V, A, C, T>::ConstIterator &&lhs) noexcept {
this->cur_ = lhs.cur_;
this->slot_ = lhs.slot_;
this->locked_ = lhs.locked_;
@ -290,9 +290,12 @@ std::pair<typename BPlusTree<K, V, A, C, T>::ConstIterator, bool> BPlusTree<K, V
if (root_ != nullptr) {
LeafNode *leaf = nullptr;
slot_type slot;
RWLock *myLock = &this->rw_lock_;
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
myLock->LockShared();
RWLock *myLock = nullptr;
if (acquire_lock_) {
myLock = &this->rw_lock_;
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
myLock->LockShared();
}
IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot);
bool find = (rc == IndexRc::kOk);
return std::make_pair(ConstIterator(leaf, slot, find), find);
@ -306,9 +309,12 @@ std::pair<typename BPlusTree<K, V, A, C, T>::Iterator, bool> BPlusTree<K, V, A,
if (root_ != nullptr) {
LeafNode *leaf = nullptr;
slot_type slot;
RWLock *myLock = &this->rw_lock_;
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
myLock->LockShared();
RWLock *myLock = nullptr;
if (acquire_lock_) {
myLock = &this->rw_lock_;
// Lock the tree in S, pass the lock to Locate which will unlock it for us underneath.
myLock->LockShared();
}
IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot);
bool find = (rc == IndexRc::kOk);
return std::make_pair(Iterator(leaf, slot, find), find);

View File

@ -69,7 +69,7 @@ Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) n
*p = addr;
return Status::OK();
} else {
return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore.");
return Status(StatusCode::kBuddySpaceFull, "BuddySpace full. Not an error. Please ignore.");
}
}

View File

@ -20,8 +20,13 @@
namespace mindspore {
namespace dataset {
CachePool::CachePool(const value_allocator &alloc, const std::string &root)
: alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {}
CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root)
: alloc_(alloc),
root_(root),
subfolder_(Services::GetUniqueID()),
sm_(nullptr),
tree_(nullptr),
custom_arena_(ourOwnArena) {}
Status CachePool::DoServiceStart() {
tree_ = std::make_shared<data_index>();
@ -45,9 +50,12 @@ Status CachePool::DoServiceStop() {
}
}
sm_.reset();
for (auto &bl : *tree_) {
if (bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, bl.sz);
// If it is our own arena, skip freeing individual pieces.
if (!custom_arena_) {
for (auto &bl : *tree_) {
if (bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, bl.sz);
}
}
}
tree_.reset();
@ -68,7 +76,7 @@ Status CachePool::DoServiceStop() {
return rc2;
}
CachePool::~CachePool() noexcept { (void)ServiceStop(); }
Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) {
Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly) {
DataLocator bl;
Status rc;
size_t sz = 0;
@ -78,22 +86,31 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t
}
bl.sz = sz;
try {
bl.ptr = alloc_.allocate(sz);
// We will do a piecewise copy.
WritableSlice dest(bl.ptr, bl.sz);
size_t pos = 0;
for (auto &v : buf) {
WritableSlice out(dest, pos);
rc = WritableSlice::Copy(&out, v);
if (rc.IsError()) {
break;
if (!writeToDiskDirectly) {
bl.ptr = alloc_.allocate(sz);
// We will do a piecewise copy.
WritableSlice dest(bl.ptr, bl.sz);
size_t pos = 0;
for (auto &v : buf) {
WritableSlice out(dest, pos);
rc = WritableSlice::Copy(&out, v);
if (rc.IsError()) {
break;
}
pos += v.GetSize();
}
pos += v.GetSize();
}
if (rc.IsError()) {
alloc_.deallocate(bl.ptr, sz);
bl.ptr = nullptr;
return rc;
if (rc.IsError()) {
alloc_.deallocate(bl.ptr, sz);
bl.ptr = nullptr;
return rc;
}
} else if (sm_ != nullptr) {
MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes.";
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
} else {
// If asked to spill to disk instead but there is no storage set up, simply return no memory
// instead.
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
} catch (std::bad_alloc &e) {
if (sm_ != nullptr) {
@ -102,7 +119,13 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
rc = tree_->insert(bl, key);
// Insert into the B+ tree. We may still get out of memory error. So need to catch it.
try {
rc = tree_->DoInsert(key, bl);
} catch (const std::bad_alloc &e) {
rc = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
// Duplicate key is treated as error and we will also free the memory.
if (rc.IsError() && bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, sz);
}
@ -138,15 +161,26 @@ Path CachePool::GetSpillPath() const {
auto spill = Path(root_) / subfolder_;
return spill;
}
CachePool::CacheStat CachePool::GetStat() const {
CacheStat cs{0};
CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
CacheStat cs{-1, -1, 0, 0, 0};
int64_t total_sz = 0;
for (auto &it : *tree_) {
total_sz += it.sz;
if (it.ptr != nullptr) {
++cs.num_mem_cached;
} else {
++cs.num_disk_cached;
if (tree_->begin() != tree_->end()) {
cs.min_key = tree_->begin().key();
cs.max_key = cs.min_key; // will adjust later.
for (auto it = tree_->begin(); it != tree_->end(); ++it) {
total_sz += it.value().sz;
if (it.value().ptr != nullptr) {
++cs.num_mem_cached;
} else {
++cs.num_disk_cached;
}
auto cur_key = it.key();
if (GetMissingKeys) {
for (auto i = cs.max_key + 1; i < cur_key; ++i) {
cs.gap.push_back((i));
}
}
cs.max_key = cur_key;
}
}
if (total_sz > 0) {

View File

@ -25,13 +25,13 @@
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/storage_manager.h"
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/util/btree.h"
namespace mindspore {
namespace dataset {
/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of
/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to
/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to
/// restore the buffer.
/// disk (if a disk directory is provided). User must provide a key to insert the buffer.
/// \see ReadableSlice
class CachePool : public Service {
public:
@ -73,22 +73,25 @@ class CachePool : public Service {
StorageManager::key_type storage_key;
};
using data_index = AutoIndexObj<DataLocator>;
using data_index = BPlusTree<int64_t, DataLocator>;
using key_type = data_index::key_type;
using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other;
/// \brief Simple statistics returned from CachePool like how many elements are cached in memory and
/// how many elements are spilled to disk.
struct CacheStat {
key_type min_key;
key_type max_key;
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t average_cache_sz;
std::vector<key_type> gap;
};
/// \brief Constructor
/// \param alloc Allocator to allocate memory from
/// \param root Optional disk folder to spill
explicit CachePool(const value_allocator &alloc, const std::string &root = "");
explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = "");
CachePool(const CachePool &) = delete;
CachePool(CachePool &&) = delete;
@ -103,10 +106,11 @@ class CachePool : public Service {
/// \brief Insert a sequence of ReadableSlice objects into the pool.
/// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk.
/// \param[in] key User supplied key
/// \param[in] buf A sequence of ReadableSlice objects.
/// \param[out] key Generated key
/// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory
/// \return Error code
Status Insert(const std::vector<ReadableSlice> &buf, key_type *key);
Status Insert(key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly);
/// \brief Restore a cached buffer (from memory or disk)
/// \param[in] key A previous key returned from Insert
/// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
@ -122,18 +126,23 @@ class CachePool : public Service {
/// \brief Get statistics.
/// \return CacheStat object
CacheStat GetStat() const;
CacheStat GetStat(bool GetMissingKeys = false) const;
const value_allocator &get_allocator() const;
std::string MyName() const { return subfolder_; }
/// \brief Toggle locking
/// \note Once locking is off. It is user's responsibility to ensure concurrency
void SetLocking(bool on_off) { tree_->SetLocking(on_off); }
private:
value_allocator alloc_;
Path root_;
const std::string subfolder_;
std::shared_ptr<StorageManager> sm_;
std::shared_ptr<data_index> tree_;
bool custom_arena_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -133,12 +133,13 @@ void CircularPool::Deallocate(void *p) {
// Lock in the chain in shared mode and find out which
// segment it comes from
SharedLock lock(&rw_lock_);
auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr<Arena> &b) -> bool {
auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool {
char *q = reinterpret_cast<char *>(p);
char *base = const_cast<char *>(reinterpret_cast<const char *>(b->get_base_addr()));
return (q > base && q < base + b->get_max_size());
auto *base = reinterpret_cast<const char *>(b->get_base_addr());
return (q > base && q < base + arena_size_ * 1048576L);
});
lock.Unlock();
MS_ASSERT(it != mem_segments_.end());
it->get()->Deallocate(p);
}
@ -150,10 +151,10 @@ Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) {
}
void *p = *pp;
SharedLock lock(&rw_lock_);
auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr<Arena> &b) -> bool {
auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool {
char *q = reinterpret_cast<char *>(p);
char *base = const_cast<char *>(reinterpret_cast<const char *>(b->get_base_addr()));
return (q > base && q < base + b->get_max_size());
auto *base = reinterpret_cast<const char *>(b->get_base_addr());
return (q > base && q < base + arena_size_ * 1048576L);
});
lock.Unlock();
MS_ASSERT(it != mem_segments_.end());

View File

@ -16,11 +16,14 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
#include <atomic>
#include <deque>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/system_pool.h"
#include "minddata/dataset/util/semaphore.h"
#include "minddata/dataset/util/services.h"
namespace mindspore {
@ -37,7 +40,7 @@ class QueueMap {
using key_type = K;
using value_type = T;
QueueMap() = default;
QueueMap() : num_rows_(0) {}
virtual ~QueueMap() = default;
/// Add an element <key, T> to the map and wake up any consumer that is waiting
@ -48,6 +51,7 @@ class QueueMap {
RequestQueue *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(key, &rq));
RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload)));
++num_rows_;
return Status::OK();
}
@ -56,9 +60,35 @@ class QueueMap {
RequestQueue *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(key, &rq));
RETURN_IF_NOT_OK(rq->Wait(out));
--num_rows_;
return Status::OK();
}
/// Get the number of elements in the container
/// \return The number of elements in the container
int64_t size() const { return num_rows_; }
/// \return if the container is empty
bool empty() const { return num_rows_ == 0; }
/// Print out some useful information about the container
friend std::ostream &operator<<(std::ostream &out, const QueueMap &qm) {
std::unique_lock<std::mutex> lck(qm.mux_);
out << "Number of elements: " << qm.num_rows_ << "\n";
out << "Dumping internal info:\n";
int64_t k = 0;
for (auto &it : qm.all_) {
auto key = it.first;
const RequestQueue *rq = it.second.GetPointer();
out << "(k:" << key << "," << *rq << ") ";
++k;
if (k % 6 == 0) {
out << "\n";
}
}
return out;
}
protected:
/// This is a handshake structure between producer and consumer
class RequestQueue {
@ -86,8 +116,13 @@ class QueueMap {
return Status::OK();
}
friend std::ostream &operator<<(std::ostream &out, const RequestQueue &rq) {
out << "sz:" << rq.row_.size() << ",uc:" << rq.use_count_.Peek();
return out;
}
private:
std::mutex dq_mux_;
mutable std::mutex dq_mux_;
Semaphore use_count_;
std::deque<T> row_;
};
@ -104,7 +139,7 @@ class QueueMap {
*out = it->second.GetMutablePointer();
} else {
// We will create a new one.
auto alloc = Services::GetAllocator<RequestQueue>();
auto alloc = SystemPool::GetAllocator<RequestQueue>();
auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc));
if (r.second) {
auto &mem = r.first->second;
@ -118,8 +153,9 @@ class QueueMap {
}
private:
std::mutex mux_;
mutable std::mutex mux_;
std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_;
std::atomic<int64_t> num_rows_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -29,10 +29,7 @@ void Semaphore::V() {
++value_;
wait_cond_.NotifyOne();
}
int Semaphore::Peek() {
std::unique_lock<std::mutex> lck(mutex_);
return value_;
}
int Semaphore::Peek() const { return value_; }
Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
Status Semaphore::Deregister() { return (wait_cond_.Deregister()); }
void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); }

View File

@ -38,7 +38,7 @@ class Semaphore {
void V();
/// \brief Peek the internal value
/// \return The internal value
int Peek();
int Peek() const;
Status Register(TaskGroup *vg);
Status Deregister();
void ResetIntrpState();

View File

@ -51,6 +51,12 @@ std::string CodeAsString(const StatusCode c) {
case StatusCode::kSyntaxError:
s = "Syntax error";
break;
case StatusCode::kBuddySpaceFull:
s = "BuddySpace full";
break;
case StatusCode::kNetWorkError:
s = "Network error";
break;
case StatusCode::kUnexpectedError:
default:
s = "Unexpected error";

View File

@ -82,6 +82,8 @@ enum class StatusCode : char {
kBoundingBoxInvalidShape = 12,
kSyntaxError = 13,
kTimeOut = 14,
kBuddySpaceFull = 14,
kNetWorkError = 15,
// Make this error code the last one. Add new error code above it.
kUnexpectedError = 127
};
@ -137,6 +139,8 @@ class Status {
bool IsNoSpace() const { return (get_code() == StatusCode::kNoSpace); }
bool IsNetWorkError() const { return (get_code() == StatusCode::kNetWorkError); }
private:
StatusCode code_;
std::string err_msg_;

View File

@ -99,7 +99,11 @@ Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const
#endif
if (r_sz != sz) {
errno_t err = (r_sz == 0) ? EOF : errno;
RETURN_STATUS_UNEXPECTED(strerror(err));
if (errno == ENOSPC) {
return Status(StatusCode::kNoSpace, __LINE__, __FILE__);
} else {
RETURN_STATUS_UNEXPECTED(strerror(err));
}
}
return Status::OK();
}

View File

@ -71,10 +71,11 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu
key_type out_key;
value_type out_value;
bool create_new_container = false;
size_t last_num_container = -1;
do {
SharedLock lock_s(&rw_lock_);
size_t num_containers = containers_.size();
if (create_new_container) {
if (create_new_container && (num_containers == last_num_container)) {
// Upgrade to exclusvie lock.
lock_s.Upgrade();
create_new_container = false;
@ -95,8 +96,11 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu
cont = containers_.at(num_containers - 1);
off64_t offset;
Status rc = cont->Insert(buf, &offset);
if (rc.IsNoSpace()) {
if (rc.get_code() == StatusCode::kBuddySpaceFull) {
create_new_container = true;
// Remember how many containers we saw. In the next iteration we will do a comparision to see
// if someone has already created it.
last_num_container = num_containers;
} else if (rc.IsOk()) {
out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz));
RETURN_IF_NOT_OK(index_.insert(out_value, &out_key));

View File

@ -15,6 +15,7 @@
"""Cache client
"""
import os
import copy
from ..core.validator_helpers import type_check, check_uint32, check_uint64
@ -25,11 +26,11 @@ class DatasetCache:
A client to interface with tensor caching service
"""
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20):
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None,
prefetch_size=None):
check_uint32(session_id, "session_id")
check_uint64(size, "size")
type_check(spilling, (bool,), "spilling")
check_uint32(prefetch_size, "prefetch size")
self.session_id = session_id
self.size = size
@ -37,8 +38,13 @@ class DatasetCache:
self.hostname = hostname
self.port = port
self.prefetch_size = prefetch_size
# temporary disable cache feature in the current release
self.cache_client = None
self.num_connections = num_connections
if os.getenv('MS_ENABLE_CACHE') != 'TRUE':
# temporary disable cache feature in the current release
self.cache_client = None
else:
from mindspore._c_dataengine import CacheClient
self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size)
def GetStat(self):
return self.cache_client.GetStat()
@ -55,5 +61,6 @@ class DatasetCache:
new_cache.hostname = copy.deepcopy(self.hostname, memodict)
new_cache.port = copy.deepcopy(self.port, memodict)
new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
new_cache.num_connections = copy.deepcopy(self.num_connections, memodict)
new_cache.cache_client = self.cache_client
return new_cache

View File

@ -1234,5 +1234,8 @@ def check_paddeddataset(method):
def check_cache_option(cache):
"""Sanity check for cache parameter"""
if cache is not None:
# temporary disable cache feature in the current release
raise ValueError("Caching is disabled in the current release")
if os.getenv('MS_ENABLE_CACHE') != 'TRUE':
# temporary disable cache feature in the current release
raise ValueError("Caching is disabled in the current release")
from . import cache_client
type_check(cache, (cache_client.DatasetCache,), "cache")

View File

@ -156,6 +156,24 @@ def update_permissions(path):
if filename == "ms_serving":
os.chmod(file_fullpath, stat.S_IREAD | stat.S_IEXEC)
def bin_files():
"""
Gets the binary files to be installed.
"""
data_files = []
binary_files = []
cache_server_bin = os.path.join('mindspore', 'bin', 'cache_server')
if not os.path.exists(cache_server_bin):
return data_files
binary_files.append(cache_server_bin)
cache_admin_bin = os.path.join('mindspore', 'bin', 'cache_admin')
if not os.path.exists(cache_admin_bin):
return data_files
binary_files.append(cache_admin_bin)
data_files.append(('bin', binary_files))
return data_files
class EggInfo(egg_info):
"""Egg info."""
@ -192,6 +210,7 @@ setup(
'framework that could be used for mobile, edge and cloud scenarios.',
long_description="\n\n".join([readme, release]),
long_description_content_type="text/markdown",
data_files=bin_files(),
packages=find_packages(),
package_data=package_data,
include_package_data=True,

View File

@ -24,27 +24,26 @@
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
// For testing purposes, we will make the branching factor very low.
struct mytraits {
using slot_type = uint16_t;
static const slot_type kLeafSlots = 6;
static const slot_type kInnerSlots = 3;
using slot_type = uint16_t;
static const slot_type kLeafSlots = 6;
static const slot_type kInnerSlots = 3;
};
class MindDataTestBPlusTree : public UT::Common {
public:
MindDataTestBPlusTree() = default;
MindDataTestBPlusTree() = default;
};
// Test serial insert.
TEST_F(MindDataTestBPlusTree, Test1) {
Allocator<std::string> alloc(std::make_shared<SystemPool>());
BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<uint64_t>, mytraits> btree(alloc);
BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<>, mytraits> btree(alloc);
Status rc;
for (int i = 0; i < 100; i++) {
uint64_t key = 2 * i;
@ -109,23 +108,24 @@ TEST_F(MindDataTestBPlusTree, Test1) {
// Test concurrent insert.
TEST_F(MindDataTestBPlusTree, Test2) {
Allocator<std::string> alloc(std::make_shared<SystemPool>());
BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<uint64_t>, mytraits> btree(alloc);
BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<>, mytraits> btree(alloc);
TaskGroup vg;
auto f = [&](int k) -> Status {
TaskManager::FindMe()->Post();
for (int i = 0; i < 100; i++) {
uint64_t key = k * 100 + i;
std::ostringstream oss;
oss << "Hello World. I am " << key;
Status rc = btree.DoInsert(key, oss.str());
EXPECT_TRUE(rc.IsOk());
}
return Status::OK();
for (int i = 0; i < 100; i++) {
uint64_t key = k * 100 + i;
std::ostringstream oss;
oss << "Hello World. I am " << key;
Status rc = btree.DoInsert(key, oss.str());
EXPECT_TRUE(rc.IsOk());
}
return Status::OK();
};
auto g = [&](int k) -> Status {
TaskManager::FindMe()->Post();
for (int i = 0; i < 1000; i++) {
uint64_t key = rand() % 10000;;
uint64_t key = rand() % 10000;
;
auto it = btree.Search(key);
}
return Status::OK();
@ -226,3 +226,22 @@ TEST_F(MindDataTestBPlusTree, Test4) {
EXPECT_EQ(cnt, 1000);
}
}
TEST_F(MindDataTestBPlusTree, TestPerfNoLocking) {
AutoIndexObj<int64_t> btree;
// No locking test
btree.SetLocking(false);
// Insert a million entries using the default traits.
for (auto i = 0; i < 1000000; ++i) {
ASSERT_TRUE(btree.insert(i));
}
std::cout << "Tree height : " << btree.GetHeight() << std::endl;
std::cout << "Tree Order : " << btree.GetOrder() << std::endl;
std::cout << "Number of leaves : " << btree.GetNumLeaves() << std::endl;
std::cout << "Number of inner nodes : " << btree.GetNumInnerNodes() << std::endl;
auto r = btree.Search(3);
EXPECT_TRUE(r.second);
r = btree.Search(999999);
EXPECT_TRUE(r.second);
}

View File

@ -35,6 +35,23 @@ using mindspore::dataset::TaskGroup;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
// Helper function to get the session id from SESSION_ID env variable
Status GetSessionFromEnv(session_id_type *session_id) {
RETURN_UNEXPECTED_IF_NULL(session_id);
if (const char *session_env = std::getenv("SESSION_ID")) {
std::string session_id_str(session_env);
try {
*session_id = std::stoul(session_id_str);
} catch (const std::exception &e) {
std::string err_msg = "Invalid numeric value for session id in env var: " + session_id_str;
return Status(StatusCode::kSyntaxError, err_msg);
}
} else {
RETURN_STATUS_UNEXPECTED("Test case requires a session id to be provided via SESSION_ID environment variable.");
}
return Status::OK();
}
class MindDataTestCacheOp : public UT::DatasetOpTesting {
public:
void SetUp() override {
@ -46,8 +63,12 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting {
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
Status rc;
CacheClient::Builder builder;
session_id_type env_session;
rc = GetSessionFromEnv(&env_session);
ASSERT_TRUE(rc.IsOk());
// use arbitrary session of 1, size of 0, spilling// is true
builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
@ -118,9 +139,6 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) {
cmp = (map_out == map);
ASSERT_TRUE(cmp);
// Test Purge and Destroy
rc = myClient->PurgeCache();
ASSERT_TRUE(rc.IsOk());
rc = myClient->DestroyCache();
ASSERT_TRUE(rc.IsOk());
}
@ -130,10 +148,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
(void)TaskManager::GetMasterThreadRc();
TaskGroup vg;
Status rc;
session_id_type env_session;
rc = GetSessionFromEnv(&env_session);
ASSERT_TRUE(rc.IsOk());
// use arbitrary session of 1, size 1, spilling is true
CacheClient::Builder builder;
// use arbitrary session of 1, size of 0, spilling// is true
builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true);
builder.SetSessionId(env_session).SetCacheMemSz(1).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
@ -199,8 +222,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) {
// RandomDataOp
//
TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
// Clear the rc of the master thread if any
(void)TaskManager::GetMasterThreadRc();
Status rc;
int32_t rank = 0; // not used
session_id_type env_session;
rc = GetSessionFromEnv(&env_session);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "UT test TestRandomDataCache1";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
@ -236,8 +266,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
// CacheOp
// size of 0, spilling is true
CacheClient::Builder builder;
// use arbitrary session of 1, size of 0, spilling// is true
builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
@ -273,7 +302,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
rc = myTree->Prepare(1);
ASSERT_TRUE(rc.IsOk());
// quick check to see what tree looks like
@ -314,9 +343,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
//// RandomDataOp
////
TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
// Clear the rc of the master thread if any
(void)TaskManager::GetMasterThreadRc();
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestRandomDataCacheSpill";
session_id_type env_session;
rc = GetSessionFromEnv(&env_session);
ASSERT_TRUE(rc.IsOk());
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
@ -353,8 +389,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
CacheClient::Builder builder;
// use arbitrary session of 1, size of 0, spilling// is true
builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true);
builder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
@ -386,7 +421,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
rc = myTree->Prepare(1);
ASSERT_TRUE(rc.IsOk());
std::cout << *myClient << std::endl;
@ -413,14 +448,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
}
TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
// Clear the rc of the master thread if any
(void)TaskManager::GetMasterThreadRc();
Status rc;
int64_t num_samples = 0;
int64_t start_index = 0;
session_id_type env_session;
rc = GetSessionFromEnv(&env_session);
ASSERT_TRUE(rc.IsOk());
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
CacheClient::Builder ccbuilder;
// use arbitrary session of 1, size of 0, spilling// is true
ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true);
ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = ccbuilder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
@ -468,7 +509,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
rc = myCacheOp->AddChild(so);
ASSERT_TRUE(rc.IsOk());
rc = myTree->Prepare();
rc = myTree->Prepare(1);
ASSERT_TRUE(rc.IsOk());
rc = myTree->Launch();
ASSERT_TRUE(rc.IsOk());
@ -507,10 +548,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
//// RandomDataOp
////
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
// Clear the rc of the master thread if any
(void)TaskManager::GetMasterThreadRc();
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test TestCacheInheritSampler";
session_id_type env_session;
rc = GetSessionFromEnv(&env_session);
ASSERT_TRUE(rc.IsOk());
int64_t num_samples = 0;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
@ -550,7 +597,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
// CacheOp
CacheClient::Builder ccbuilder;
// use arbitrary session of 1, size of 0, spilling// is true
ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true);
ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true);
std::shared_ptr<CacheClient> myClient;
rc = ccbuilder.Build(&myClient);
ASSERT_TRUE(rc.IsOk());
@ -577,7 +624,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) {
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
rc = myTree->Prepare(1);
ASSERT_TRUE(rc.IsOk());
std::cout << *myClient << std::endl;

View File

@ -25,13 +25,13 @@ using namespace mindspore::dataset;
class MindDataTestMemoryPool : public UT::Common {
public:
std::shared_ptr<MemoryPool> mp_;
MindDataTestMemoryPool() {}
std::shared_ptr<MemoryPool> mp_;
MindDataTestMemoryPool() {}
void SetUp() {
Status rc = CircularPool::CreateCircularPool(&mp_, 1, 1, true);
ASSERT_TRUE(rc.IsOk());
}
void SetUp() {
Status rc = CircularPool::CreateCircularPool(&mp_, 1, 1, true);
ASSERT_TRUE(rc.IsOk());
}
};
TEST_F(MindDataTestMemoryPool, DumpPoolInfo) {
@ -40,7 +40,7 @@ TEST_F(MindDataTestMemoryPool, DumpPoolInfo) {
TEST_F(MindDataTestMemoryPool, TestOperator1) {
Status rc;
int *p = new(&rc, mp_) int;
int *p = new (&rc, mp_) int;
ASSERT_TRUE(rc.IsOk());
*p = 2048;
::operator delete(p, mp_);
@ -61,12 +61,11 @@ TEST_F(MindDataTestMemoryPool, TestOperator3) {
TEST_F(MindDataTestMemoryPool, TestAllocator) {
class A {
public:
explicit A (int x) : a(x) {}
int val_a() const {
return a;
}
explicit A(int x) : a(x) {}
int val_a() const { return a; }
private:
int a;
int a;
};
Allocator<A> alloc(mp_);
std::shared_ptr<A> obj_a = std::allocate_shared<A>(alloc, 3);
@ -74,3 +73,16 @@ TEST_F(MindDataTestMemoryPool, TestAllocator) {
ASSERT_EQ(v, 3);
MS_LOG(DEBUG) << *(std::dynamic_pointer_cast<CircularPool>(mp_)) << std::endl;
}
TEST_F(MindDataTestMemoryPool, TestMemGuard) {
MemGuard<uint8_t> mem;
// Try some large value.
int64_t sz = 5LL * 1024LL * 1024LL * 1024LL;
Status rc = mem.allocate(sz);
ASSERT_TRUE(rc.IsOk() || rc.IsOutofMemory());
if (rc.IsOk()) {
// Try write a character half way.
auto *p = mem.GetMutablePointer();
p[sz / 2] = 'a';
}
}

View File

@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# This script is the driver of the individual test scenarios
CURRPATH=$(cd $(dirname $0); pwd)
echo "----------------------------------------------"
echo "Invalid syntax and cache_admin failure testing"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_args.sh
num_failures=$?
echo
echo "Invalid syntax and cache_admin failure testing complete. Number of failures: $num_failures"
echo
echo "----------------------------------------------"
echo "Test pipelines with cache (python)"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_py.sh
num_failures=$?
echo
echo "Test pipelines with cache complete. Number of failures: $num_failures"
echo
echo "----------------------------------------------"
echo "Cache cpp tests"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_cpp.sh
num_failures=$?
echo
echo "Cache cpp tests complete. Number of failures: $num_failures"
echo

View File

@ -0,0 +1,207 @@
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# source the globals and functions for use with cache testing
SKIP_ADMIN_COUNTER=false
. cachetest_lib.sh
echo
################################################################################
# Cache testing: cache_admin argument testing #
# Summary: Various tests that expect to get failure messages returned #
################################################################################
# Double-command test
cmd="${CACHE_ADMIN} --start --stop"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# missing command test
cmd="${CACHE_ADMIN} --port 50082"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# bad arg test
cmd="${CACHE_ADMIN} -p abc --start"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# missing arg test
cmd="${CACHE_ADMIN} -p --start"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# invalid command
cmd="${CACHE_ADMIN} -p 50082 --start --not_exist_cmd"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# spill directory does not exist
cmd="${CACHE_ADMIN} --start --spilldir /path_that_does_not_exist"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# start cache server twice
StartServer
HandleRcExit $? 1 1
# start the cache server again, however, this time we expect an error
cmd="${CACHE_ADMIN} --start"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
StopServer
HandleRcExit $? 1 1
# start cache server twice with different ports
# this one starts with the default port 50052
StartServer
HandleRcExit $? 1 1
# this one starts with port 50053
cmd="${CACHE_ADMIN} --start -p 50053"
CacheAdminCmd "${cmd}" 0
HandleRcExit $? 1 1
# stop the cache server with default port
StopServer
HandleRcExit $? 1 1
# stop the cache server with port 50053
cmd="${CACHE_ADMIN} --stop -p 50053"
CacheAdminCmd "${cmd}" 0
HandleRcExit $? 1 1
# stop the cache server without bringing it up
cmd="${CACHE_ADMIN} --stop"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
# start the cache server with illegal hostname
cmd="${CACHE_ADMIN} --start -h 0.0.0.0"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -h illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -h"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -h --hostname"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -h --hostname 127.0.0.1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
# start the cache server with illegal port
cmd="${CACHE_ADMIN} --start -p 0"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -p -1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -p 65536"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -p illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -p"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
# find a port that is occupied using netstat
if [ -x "$(command -v netstat)" ]; then
port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1)
if [ ${port} -gt 1025 ]; then
# start cache server with occupied port
cmd="${CACHE_ADMIN} --start -p ${port}"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
fi
fi
# generate session before starting the cache server
cmd="${CACHE_ADMIN} -g"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# illegal generate session command
StartServer
HandleRcExit $? 1 1
cmd="${CACHE_ADMIN} -g 1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# illegal destroy session command
cmd="${CACHE_ADMIN} -d -2"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} -d illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} -d"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# destroy a non-existing session
cmd="${CACHE_ADMIN} -d 99999"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# stop cache server at this point
StopServer
HandleRcExit $? 1 1
# illegal number of workers
cmd="${CACHE_ADMIN} --start -w 0"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w -1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 101"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 9999999"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# illegal spill path
cmd="${CACHE_ADMIN} --start -s"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# spill path without writing perm
if [ "$EUID" -ne 0 ]; then
cmd="${CACHE_ADMIN} --start -s /"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
fi
# illegal log level
cmd="${CACHE_ADMIN} --start -l 4"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -l -1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -l"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
exit ${failed_tests}

Some files were not shown because too many files have changed in this diff Show More