!17541 Fix MD Cache code check warnings [master]

From: @lixiachen
Reviewed-by: @pandoublefeng, @robingrosman
Signed-off-by: @robingrosman
mindspore-ci-bot 2021-06-02 04:53:54 +08:00 committed by Gitee
commit ae2e7c1288
42 changed files with 240 additions and 228 deletions

View File

@ -28,6 +28,7 @@
#include "mindspore/lite/src/common/log_adapter.h"
#endif
#include "minddata/dataset/util/system_pool.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
@ -53,14 +54,14 @@ ConfigManager::ConfigManager()
enable_shared_mem_(true) {
num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits<uint16_t>::max();
num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? num_parallel_workers_ : num_cpu_threads_;
auto env_cache_host = std::getenv("MS_CACHE_HOST");
auto env_cache_port = std::getenv("MS_CACHE_PORT");
if (env_cache_host != nullptr) {
std::string env_cache_host = common::GetEnv("MS_CACHE_HOST");
std::string env_cache_port = common::GetEnv("MS_CACHE_PORT");
if (!env_cache_host.empty()) {
cache_host_ = env_cache_host;
}
if (env_cache_port != nullptr) {
if (!env_cache_port.empty()) {
char *end = nullptr;
cache_port_ = strtol(env_cache_port, &end, 10);
cache_port_ = static_cast<int32_t>(strtol(env_cache_port.c_str(), &end, kDecimal));
if (*end != '\0') {
MS_LOG(WARNING) << "Cache port from env variable MS_CACHE_PORT is invalid\n";
cache_port_ = 0; // cause the port range validation to generate an error during the validation checks
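
The new pattern — fetch the variable as a std::string via common::GetEnv, then validate the whole numeric conversion by checking strtol's end pointer — recurs throughout this commit. A minimal standalone sketch of the same idea (ParsePortFromEnv is a hypothetical helper, not part of the commit; kDecimal mirrors the constant the diff adds to constants.h):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

constexpr int32_t kDecimal = 10;

// Return fallback unless the env variable holds a fully numeric value.
int32_t ParsePortFromEnv(const char *name, int32_t fallback) {
  const char *raw = std::getenv(name);
  std::string value = (raw != nullptr) ? raw : "";
  if (value.empty()) return fallback;
  char *end = nullptr;
  long port = strtol(value.c_str(), &end, kDecimal);
  if (*end != '\0') {  // reject partial parses such as "5005x"
    std::cerr << "Env variable " << name << " is not a valid number\n";
    return fallback;
  }
  return static_cast<int32_t>(port);
}

int main() {
  std::cout << ParsePortFromEnv("MS_CACHE_PORT", 50052) << "\n";
  return 0;
}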

View File

@ -33,37 +33,34 @@
namespace mindspore {
namespace dataset {
const int32_t CacheAdminArgHandler::kDefaultNumWorkers = std::thread::hardware_concurrency() > 2
? std::thread::hardware_concurrency() / 2
: 1;
const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
CacheAdminArgHandler::CacheAdminArgHandler()
: port_(kCfgDefaultCachePort),
num_workers_(kDefaultNumWorkers),
shm_mem_sz_(kDefaultSharedMemorySizeInGB),
shm_mem_sz_(kDefaultSharedMemorySize),
log_level_(kDefaultLogLevel),
memory_cap_ratio_(kMemoryCapRatio),
memory_cap_ratio_(kDefaultMemoryCapRatio),
hostname_(kCfgDefaultCacheHost),
spill_dir_(""),
command_id_(CommandId::kCmdUnknown) {
const char *env_cache_host = std::getenv("MS_CACHE_HOST");
const char *env_cache_port = std::getenv("MS_CACHE_PORT");
if (env_cache_host != nullptr) {
std::string env_cache_host = common::GetEnv("MS_CACHE_HOST");
std::string env_cache_port = common::GetEnv("MS_CACHE_PORT");
if (!env_cache_host.empty()) {
hostname_ = env_cache_host;
}
if (env_cache_port != nullptr) {
if (!env_cache_port.empty()) {
char *end = nullptr;
port_ = strtol(env_cache_port, &end, 10);
port_ = static_cast<int32_t>(strtol(env_cache_port.c_str(), &end, kDecimal));
if (*end != '\0') {
std::cerr << "Cache port from env variable MS_CACHE_PORT is invalid\n";
port_ = 0; // cause the port range validation to generate an error during the validation checks
}
}
const char *env_log_level = std::getenv("GLOG_v");
if (env_log_level != nullptr) {
std::string env_log_level = common::GetEnv("GLOG_v");
if (!env_log_level.empty()) {
char *end = nullptr;
log_level_ = strtol(env_log_level, &end, 10);
log_level_ = static_cast<int32_t>(strtol(env_log_level.c_str(), &end, kDecimal));
if (*end != '\0') {
std::cerr << "Log level from env variable GLOG_v is invalid\n";
log_level_ = -1; // cause the log level range validation to generate an error during the validation checks
@ -377,15 +374,17 @@ Status CacheAdminArgHandler::Validate() {
}
// Additional checks here
auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100);
auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), kMaxNumWorkers);
if (used_args_[ArgValue::kArgNumWorkers] && (num_workers_ < 1 || num_workers_ > max_num_workers))
// Check the value of num_workers only if it's provided by users.
return Status(StatusCode::kMDSyntaxError,
"Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + ".");
if (log_level_ < 0 || log_level_ > 4) return Status(StatusCode::kMDSyntaxError, "Log level must be in range (0..4).");
if (log_level_ < MsLogLevel::DEBUG || log_level_ > MsLogLevel::EXCEPTION)
return Status(StatusCode::kMDSyntaxError, "Log level must be in range (0..4).");
if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
return Status(StatusCode::kMDSyntaxError, "Memory cap ratio should be positive and no greater than 1");
if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kMDSyntaxError, "Port must be in range (1025..65535).");
if (port_ < kMinLegalPort || port_ > kMaxLegalPort)
return Status(StatusCode::kMDSyntaxError, "Port must be in range (1025..65535).");
return Status::OK();
}
@ -542,7 +541,7 @@ Status CacheAdminArgHandler::StopServer(CommandId command_id) {
// The server will send a message back and remove the queue and we will then wake up. But on the safe
// side, we will also set up an alarm and kill this process if we hang on
// the message queue.
alarm(60);
(void)alarm(kAlarmDeadline);
Status dummy_rc;
(void)msg.ReceiveStatus(&dummy_rc);
std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl;
@ -579,8 +578,7 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) {
}
// fork the child process to become the daemon
pid_t pid;
pid = fork();
pid_t pid = fork();
// failed to fork
if (pid < 0) {
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
@ -588,7 +586,7 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) {
} else if (pid > 0) {
// As a parent, we close the write end. We only listen.
close(fd[1]);
dup2(fd[0], 0);
(void)dup2(fd[0], STDIN_FILENO);
close(fd[0]);
int status;
if (waitpid(pid, &status, 0) == -1) {
@ -616,11 +614,11 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) {
} else {
// Child here ...
// Close all stdin, redirect stdout and stderr to the write end of the pipe.
close(fd[0]);
dup2(fd[1], 1);
dup2(fd[1], 2);
close(0);
close(fd[1]);
(void)close(fd[0]);
(void)dup2(fd[1], STDOUT_FILENO);
(void)dup2(fd[1], STDERR_FILENO);
(void)close(STDIN_FILENO);
(void)close(fd[1]);
// exec the cache server binary in this process
// If the user did not provide the value of num_workers, we pass -1 to the cache server to let it assign the default,
// so that the server knows whether the number was provided by the user or defaulted.
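
The fork/dup2 sequence above is a standard daemon-launch pattern: the parent keeps the read end of a pipe to collect the child's startup output, while the child routes stdout/stderr into the write end and detaches stdin. A minimal self-contained sketch of the same plumbing (an illustration only, not the server's actual code; the (void) casts match the style the code checker demands):

#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cstdio>

int main() {
  int fd[2];
  if (pipe(fd) == -1) return 1;
  pid_t pid = fork();
  if (pid < 0) return 1;
  if (pid > 0) {
    // Parent: close the write end and only listen.
    (void)close(fd[1]);
    char buf[256];
    ssize_t n;
    while ((n = read(fd[0], buf, sizeof(buf) - 1)) > 0) {
      buf[n] = '\0';
      (void)fprintf(stdout, "child says: %s", buf);
    }
    (void)close(fd[0]);
    int status;
    (void)waitpid(pid, &status, 0);
  } else {
    // Child: redirect stdout/stderr to the write end, close stdin.
    (void)close(fd[0]);
    (void)dup2(fd[1], STDOUT_FILENO);
    (void)dup2(fd[1], STDERR_FILENO);
    (void)close(STDIN_FILENO);
    (void)close(fd[1]);
    (void)printf("daemon started\n");
  }
  return 0;
}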

View File

@ -31,10 +31,8 @@ namespace dataset {
class CacheAdminArgHandler {
public:
static const int32_t kDefaultNumWorkers;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1;
static constexpr float kMemoryCapRatio = 0.8;
static constexpr int32_t kAlarmDeadline = 60;
static constexpr int32_t kMaxNumWorkers = 100;
static const char kServerBinary[];
// These are the actual command types to execute

View File

@ -41,7 +41,8 @@ Status CachedSharedMemory::Init() {
// We will create a number of sub-pools out of shared memory to reduce latch contention
int32_t num_of_pools = num_numa_nodes_;
if (num_numa_nodes_ == 1) {
num_of_pools = shared_memory_sz_in_gb_ * 2;
constexpr int32_t kNumPoolMultiplier = 2;
num_of_pools = shared_memory_sz_in_gb_ * kNumPoolMultiplier;
}
sub_pool_sz_ = shm_mem_sz / num_of_pools;
// If each subpool is too small, readjust the number of pools

View File

@ -46,8 +46,8 @@ Status CacheClient::Builder::SanityCheck() {
CHECK_FAIL_RETURN_SYNTAX_ERROR(num_connections_ > 0, "number of tcp/ip connections must be positive.");
CHECK_FAIL_RETURN_SYNTAX_ERROR(prefetch_size_ > 0, "prefetch size must be positive.");
CHECK_FAIL_RETURN_SYNTAX_ERROR(!hostname_.empty(), "hostname must not be empty.");
CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ > 1024, "Port must be in range (1025..65535).");
CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ <= 65535, "Port must be in range (1025..65535).");
CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ >= kMinLegalPort, "Port must be in range (1025..65535).");
CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ <= kMaxLegalPort, "Port must be in range (1025..65535).");
CHECK_FAIL_RETURN_SYNTAX_ERROR(hostname_ == "127.0.0.1",
"now cache client has to be on the same host with cache server.");
return Status::OK();
@ -103,6 +103,9 @@ void CacheClient::Print(std::ostream &out) const {
<< SupportLocalClient();
}
std::string CacheClient::GetHostname() const { return comm_->GetHostname(); }
int32_t CacheClient::GetPort() const { return comm_->GetPort(); }
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
auto rq = std::make_shared<CacheRowRequest>(this);
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
@ -426,36 +429,35 @@ Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
asyncWriter->rq.reset(
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
flush_rc_ = cc_->PushRequest(asyncWriter->rq);
if (flush_rc_.IsOk()) {
// If we are asked to wait, say this is the final flush, just wait for its completion.
bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
if (blocking) {
// Make sure we are done with all the buffers
for (auto i = 0; i < kNumAsyncBuffer; ++i) {
if (buf_arr_[i].rq) {
Status rc = buf_arr_[i].rq->Wait();
if (rc.IsError()) {
flush_rc_ = rc;
}
buf_arr_[i].rq.reset();
}
RETURN_IF_NOT_OK(flush_rc_);
// If we are asked to wait, say this is the final flush, just wait for its completion.
bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
if (blocking) {
// Make sure we are done with all the buffers
for (auto i = 0; i < kNumAsyncBuffer; ++i) {
if (buf_arr_[i].rq) {
Status rc = buf_arr_[i].rq->Wait();
if (rc.IsError()) flush_rc_ = rc;
buf_arr_[i].rq.reset();
}
}
// Prepare for the next buffer.
cur_ = (cur_ + 1) % kNumAsyncBuffer;
asyncWriter = &buf_arr_[cur_];
// Update the cur_ while we have the lock.
// Before we do anything, make sure the cache server is done with this buffer, or we will corrupt its content.
// We can also pick up any error from a previous flush.
if (asyncWriter->rq) {
// Save the result into a common area, so worker can see it and quit.
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
asyncWriter->bytes_avail_ = kAsyncBufferSize;
asyncWriter->num_ele_ = 0;
}
// Prepare for the next buffer.
cur_ = (cur_ + 1) % kNumAsyncBuffer;
asyncWriter = &buf_arr_[cur_];
// Update the cur_ while we have the lock.
// Before we do anything, make sure the cache server is done with this buffer, or we will corrupt its content.
// We can also pick up any error from a previous flush.
if (asyncWriter->rq) {
// Save the result into a common area, so worker can see it and quit.
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
asyncWriter->bytes_avail_ = kAsyncBufferSize;
asyncWriter->num_ele_ = 0;
}
return flush_rc_;
}
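
The SyncFlush rewrite is mostly a de-nesting: wrapping the whole body in if (flush_rc_.IsOk()) { ... } becomes a single RETURN_IF_NOT_OK early return, which drops one indentation level. The same transformation in miniature (Status and the macro here are simplified stand-ins for the MindSpore types):

#include <iostream>

struct Status {
  bool ok = true;
  static Status OK() { return {}; }
  bool IsOk() const { return ok; }
};

// Stand-in for MindSpore's macro of the same name.
#define RETURN_IF_NOT_OK(s)      \
  do {                           \
    Status _rc = (s);            \
    if (!_rc.IsOk()) return _rc; \
  } while (false)

Status DoFlush() { return Status::OK(); }
Status WaitForAllBuffers() { return Status::OK(); }

// Before the change the whole body sat inside if (flush_rc_.IsOk()) { ... };
// after it, one early return keeps the rest of the function flat.
Status SyncFlushSketch() {
  RETURN_IF_NOT_OK(DoFlush());
  RETURN_IF_NOT_OK(WaitForAllBuffers());
  return Status::OK();
}

int main() {
  std::cout << (SyncFlushSketch().IsOk() ? "flushed" : "error") << "\n";
  return 0;
}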

View File

@ -232,6 +232,8 @@ class CacheClient {
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
int32_t GetClientId() const { return client_id_; }
std::string GetHostname() const;
int32_t GetPort() const;
/// MergeOp will notify us when the server can't cache any more rows.
/// We will stop any attempt to fetch any rows that are most likely

View File

@ -24,6 +24,7 @@
#include <grpcpp/grpcpp.h>
#endif
#include <string>
#include <thread>
#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
@ -41,6 +42,12 @@ constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
constexpr static int32_t kDefaultSharedMemorySize = 4;
/// \brief Memory Cap ratio used by the server
constexpr static float kDefaultMemoryCapRatio = 0.8;
/// \brief Default log level of the server
constexpr static int32_t kDefaultLogLevel = 1;
/// \brief Set num workers to half of num_cpus as the default
static const int32_t kDefaultNumWorkers = std::thread::hardware_concurrency() > 2
? std::thread::hardware_concurrency() / 2
: 1;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is

View File

@ -26,7 +26,7 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port
// message limit is 4MB which is not big enough.
args.SetMaxReceiveMessageSize(-1);
MS_LOG(INFO) << "Hostname: " << hostname_ << ", port: " << std::to_string(port_);
#if CACHE_LOCAL_CLIENT
#ifdef CACHE_LOCAL_CLIENT
// Try connecting to the unix_socket locally as the first preference
// Need to resolve hostname to ip address rather than to do a string compare
if (hostname == "127.0.0.1") {
@ -36,7 +36,7 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port
#endif
std::string target = hostname + ":" + std::to_string(port);
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
#if CACHE_LOCAL_CLIENT
#ifdef CACHE_LOCAL_CLIENT
}
#endif
stub_ = CacheServerGreeter::NewStub(channel_);
@ -44,7 +44,7 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port
Status CacheClientGreeter::AttachToSharedMemory(bool *local_bypass) {
*local_bypass = false;
#if CACHE_LOCAL_CLIENT
#ifdef CACHE_LOCAL_CLIENT
SharedMemory::shm_key_t shm_key;
RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key));
// Attach to the shared memory
@ -85,7 +85,7 @@ Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
auto seqNo = request_cnt_.fetch_add(1);
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq), seqNo);
// One minute timeout
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(kRequestTimeoutDeadlineInSec);
tag->ctx_.set_deadline(deadline);
tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_);
tag->rpc_->StartCall();
@ -108,7 +108,7 @@ Status CacheClientGreeter::WorkerEntry() {
do {
bool success;
void *tag;
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(kWaitForNewEventDeadlineInSec);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto r = cq_.AsyncNext(&tag, &success, deadline);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
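
Switching #if CACHE_LOCAL_CLIENT to #ifdef matters when the macro is defined flag-style, with no value: #if needs an integer expression after it, while #ifdef only tests whether the name is defined. A minimal illustration under that assumption:

#include <iostream>

#define CACHE_LOCAL_CLIENT  // defined with no value, as a build flag might be

// #if CACHE_LOCAL_CLIENT   // ill-formed here: expands to "#if" with no expression
#ifdef CACHE_LOCAL_CLIENT   // well-formed: tests definedness only
constexpr bool kLocalClient = true;
#else
constexpr bool kLocalClient = false;
#endif

int main() {
  std::cout << "local client support: " << kLocalClient << "\n";
  return 0;
}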

View File

@ -59,6 +59,8 @@ class CacheClientGreeter : public Service {
friend class CacheClient;
public:
constexpr static int32_t kRequestTimeoutDeadlineInSec = 60;
constexpr static int32_t kWaitForNewEventDeadlineInSec = 1;
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections);
~CacheClientGreeter();
@ -86,6 +88,9 @@ class CacheClientGreeter : public Service {
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return mem_.SharedMemoryBaseAddr(); }
std::string GetHostname() const { return hostname_; }
int32_t GetPort() const { return port_; }
private:
std::shared_ptr<grpc::Channel> channel_;
std::unique_ptr<CacheServerGreeter::Stub> stub_;

View File

@ -56,7 +56,7 @@ Status CacheServerGreeterImpl::Run() {
// Default message size for gRPC is 4MB. Increase it to 2g-1
builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
int port_tcpip = 0;
#if CACHE_LOCAL_CLIENT
#ifdef CACHE_LOCAL_CLIENT
int port_local = 0;
// We also optimize on local clients on the same machine using unix socket
builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
@ -72,7 +72,7 @@ Status CacheServerGreeterImpl::Run() {
if (port_tcpip != port_) {
errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + ".";
}
#if CACHE_LOCAL_CLIENT
#ifdef CACHE_LOCAL_CLIENT
if (port_local == 0) {
errMsg += " Unable to create unix socket " + unix_socket_ + ".";
}
@ -176,7 +176,7 @@ void CacheServerRequest::Print(std::ostream &out) const {
Status CacheServerGreeterImpl::MonitorUnixSocket() {
TaskManager::FindMe()->Post();
#if CACHE_LOCAL_CLIENT
#ifdef CACHE_LOCAL_CLIENT
Path p(unix_socket_);
do {
RETURN_IF_INTERRUPTED();
@ -197,7 +197,7 @@ Status CacheServerGreeterImpl::MonitorUnixSocket() {
MS_LOG(WARNING) << "Unix socket is removed.";
TaskManager::WakeUpWatchDog();
}
std::this_thread::sleep_for(std::chrono::seconds(5));
std::this_thread::sleep_for(std::chrono::seconds(kMonitorIntervalInSec));
} while (true);
#endif
return Status::OK();

View File

@ -70,6 +70,7 @@ class CacheServerGreeterImpl final {
friend class CacheServer;
public:
constexpr static int32_t kMonitorIntervalInSec = 5;
explicit CacheServerGreeterImpl(int32_t port);
virtual ~CacheServerGreeterImpl();
/// \brief Brings up gRPC server

View File

@ -101,7 +101,7 @@ Status CacheServerHW::GetNumaNodeInfo() {
auto p = it->next();
const std::string entry = p.Basename();
const char *name = entry.data();
if (strncmp(name, kNodeName, 4) == 0 && isdigit_string(name + strlen(kNodeName))) {
if (strncmp(name, kNodeName, strlen(kNodeName)) == 0 && isdigit_string(name + strlen(kNodeName))) {
numa_nodes_.insert(p);
}
}
@ -116,7 +116,7 @@ Status CacheServerHW::GetNumaNodeInfo() {
auto r = std::regex("[0-9]*-[0-9]*");
for (Path p : numa_nodes_) {
auto node_dir = p.Basename();
numa_id_t numa_node = strtol(node_dir.data() + strlen(kNodeName), nullptr, 10);
numa_id_t numa_node = static_cast<numa_id_t>(strtol(node_dir.data() + strlen(kNodeName), nullptr, kDecimal));
Path f = p / kCpuList;
std::ifstream fs(f.toString());
CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + f.toString());
@ -134,8 +134,8 @@ Status CacheServerHW::GetNumaNodeInfo() {
CHECK_FAIL_RETURN_UNEXPECTED(pos != std::string::npos, "Failed to parse numa node file");
std::string min = match.substr(0, pos);
std::string max = match.substr(pos + 1);
cpu_id_t cpu_min = strtol(min.data(), nullptr, 10);
cpu_id_t cpu_max = strtol(max.data(), nullptr, 10);
cpu_id_t cpu_min = static_cast<cpu_id_t>(strtol(min.data(), nullptr, kDecimal));
cpu_id_t cpu_max = static_cast<cpu_id_t>(strtol(max.data(), nullptr, kDecimal));
MS_LOG(DEBUG) << "Numa node " << numa_node << " CPU(s) : " << cpu_min << "-" << cpu_max;
for (int i = cpu_min; i <= cpu_max; ++i) {
CPU_SET(i, &cpuset);
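
The numa discovery walks /sys/devices/system/node/nodeN/cpulist entries of the form "min-max". A hypothetical standalone helper following the same substr/strtol steps as the code above:

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

constexpr int32_t kDecimal = 10;

// Parse a cpulist entry such as "0-15" into [min, max].
bool ParseCpuRange(const std::string &entry, int32_t *cpu_min, int32_t *cpu_max) {
  auto pos = entry.find('-');
  if (pos == std::string::npos) return false;
  *cpu_min = static_cast<int32_t>(strtol(entry.substr(0, pos).c_str(), nullptr, kDecimal));
  *cpu_max = static_cast<int32_t>(strtol(entry.substr(pos + 1).c_str(), nullptr, kDecimal));
  return *cpu_min <= *cpu_max;
}

int main() {
  int32_t lo = 0, hi = 0;
  if (ParseCpuRange("0-15", &lo, &hi)) std::cout << lo << ".." << hi << "\n";
  return 0;
}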

View File

@ -22,6 +22,7 @@
#include <chrono>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "mindspore/core/utils/log_adapter.h"
namespace ms = mindspore;
namespace ds = mindspore::dataset;
@ -32,19 +33,30 @@ namespace ds = mindspore::dataset;
ms::Status StartServer(int argc, char **argv) {
ms::Status rc;
ds::CacheServer::Builder builder;
if (argc != 8) {
const int32_t kTotalArgs = 8;
enum {
kProcessNameIdx = 0,
kRootDirArgIdx = 1,
kNumWorkersArgIdx = 2,
kPortArgIdx = 3,
kSharedMemorySizeArgIdx = 4,
kLogLevelArgIdx = 5,
kDemonizeArgIdx = 6,
kMemoryCapRatioArgIdx = 7
};
if (argc != kTotalArgs) {
return ms::Status(ms::StatusCode::kMDSyntaxError);
}
int32_t port = strtol(argv[3], nullptr, 10);
builder.SetRootDirectory(argv[1])
.SetNumWorkers(strtol(argv[2], nullptr, 10))
int32_t port = static_cast<int32_t>(strtol(argv[kPortArgIdx], nullptr, ds::kDecimal));
builder.SetRootDirectory(argv[kRootDirArgIdx])
.SetNumWorkers(static_cast<int32_t>(strtol(argv[kNumWorkersArgIdx], nullptr, ds::kDecimal)))
.SetPort(port)
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10))
.SetLogLevel(strtol(argv[5], nullptr, 10))
.SetMemoryCapRatio(strtof(argv[7], nullptr));
.SetSharedMemorySizeInGB(static_cast<int32_t>(strtol(argv[kSharedMemorySizeArgIdx], nullptr, ds::kDecimal)))
.SetLogLevel(static_cast<int8_t>((strtol(argv[kLogLevelArgIdx], nullptr, ds::kDecimal))))
.SetMemoryCapRatio(strtof(argv[kMemoryCapRatioArgIdx], nullptr));
auto daemonize_string = argv[6];
auto daemonize_string = argv[kDemonizeArgIdx];
bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 ||
strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0;
@ -68,8 +80,9 @@ ms::Status StartServer(int argc, char **argv) {
if (rc.IsError()) {
return rc;
}
ms::g_ms_submodule_log_levels[SUBMODULE_ID] = strtol(argv[5], nullptr, 10);
google::InitGoogleLogging(argv[0]);
ms::g_ms_submodule_log_levels[SUBMODULE_ID] =
static_cast<int>(strtol(argv[kLogLevelArgIdx], nullptr, ds::kDecimal));
google::InitGoogleLogging(argv[kProcessNameIdx]);
#undef google
#endif
rc = msg.Create();
@ -94,9 +107,8 @@ ms::Status StartServer(int argc, char **argv) {
}
if (child_rc.IsError()) {
return child_rc;
} else {
warning_string = child_rc.ToString();
}
warning_string = child_rc.ToString();
std::cout << "Cache server startup completed successfully!\n";
std::cout << "The cache server daemon has been created as process id " << pid << " and listening on port " << port
<< ".\n";
@ -116,9 +128,9 @@ ms::Status StartServer(int argc, char **argv) {
std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno);
return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
close(0);
close(1);
close(2);
(void)close(STDIN_FILENO);
(void)close(STDOUT_FILENO);
(void)close(STDERR_FILENO);
}
}
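
Naming argv positions through an enum, as cache_main.cc now does, keeps each index and its meaning in one place and makes the argc check self-documenting. A trimmed sketch of the same idea:

#include <cstdint>
#include <cstdlib>
#include <iostream>

constexpr int32_t kDecimal = 10;
enum ArgIdx { kProcessName = 0, kPort = 1, kNumWorkers = 2, kTotalArgs = 3 };

int main(int argc, char **argv) {
  if (argc != kTotalArgs) {
    std::cerr << "usage: " << argv[kProcessName] << " <port> <num_workers>\n";
    return 1;
  }
  auto port = static_cast<int32_t>(strtol(argv[kPort], nullptr, kDecimal));
  auto workers = static_cast<int32_t>(strtol(argv[kNumWorkers], nullptr, kDecimal));
  std::cout << "port=" << port << " workers=" << workers << "\n";
  return 0;
}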

View File

@ -69,8 +69,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te
WritableSlice all(p, sz_);
auto offset = fbb->GetSize();
ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
Status copy_rc;
copy_rc = WritableSlice::Copy(&all, header);
Status copy_rc = WritableSlice::Copy(&all, header);
if (copy_rc.IsOk()) {
for (const auto &ts : row) {
WritableSlice row_data(all, offset, ts->SizeInBytes());
@ -108,7 +107,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te
Status CacheRowRequest::PostReply() {
if (!reply_.result().empty()) {
row_id_from_server_ = strtoll(reply_.result().data(), nullptr, 10);
row_id_from_server_ = strtoll(reply_.result().data(), nullptr, kDecimal);
}
return Status::OK();
}
@ -116,11 +115,13 @@ Status CacheRowRequest::PostReply() {
Status CacheRowRequest::Prepare() {
if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
// First one is cookie, followed by address and then size.
CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == 3, "Incomplete rpc data");
constexpr int32_t kExpectedBufDataSize = 3;
CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == kExpectedBufDataSize, "Incomplete rpc data");
} else {
// First one is cookie. 2nd one is the google flat buffers followed by a number of buffers.
// But we are not going to decode them to verify.
CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= 3, "Incomplete rpc data");
constexpr int32_t kMinBufDataSize = 3;
CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= kMinBufDataSize, "Incomplete rpc data");
}
return Status::OK();
}
@ -161,7 +162,7 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, in
auto flag = reply_.flag();
bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false;
if (dataOnSharedMemory) {
auto addr = strtoll(reply_.result().data(), nullptr, 10);
auto addr = strtoll(reply_.result().data(), nullptr, kDecimal);
ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr);
RETURN_UNEXPECTED_IF_NULL(out);
*out_addr = addr;

View File

@ -446,7 +446,7 @@ class AllocateSharedBlockRequest : public BaseRequest {
/// the free block is located.
/// \return
int64_t GetAddr() {
auto addr = strtoll(reply_.result().data(), nullptr, 10);
auto addr = strtoll(reply_.result().data(), nullptr, kDecimal);
return addr;
}
};

View File

@ -74,7 +74,7 @@ Status CacheServer::DoServiceStart() {
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
#if CACHE_LOCAL_CLIENT
#ifdef CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
// Bring up a thread to monitor the unix socket in case it is removed. But it must be done
// after we have created the unix socket.
@ -173,7 +173,7 @@ Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
} else if (req_mem == 0) {
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
// 85% of our limit, fail this request.
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= kMemoryBottomLineForNewService) {
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
}
@ -331,14 +331,17 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
auto *base = SharedMemoryBaseAddr();
// Ensure we got 3 pieces of data coming in
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= 3, "Incomplete data");
constexpr int32_t kMinBufDataSize = 3;
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
// First one is cookie, followed by data address and then size.
enum { kCookieIdx = 0, kAddrIdx = 1, kSizeIdx = 2 };
// First piece of data is the cookie and is required
auto &cookie = rq->buf_data(0);
auto &cookie = rq->buf_data(kCookieIdx);
// Second piece of data is the address where we can find the serialized data
auto addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
auto addr = strtoll(rq->buf_data(kAddrIdx).data(), nullptr, kDecimal);
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
// Third piece of data is the size of the serialized data that we need to transfer
auto sz = strtoll(rq->buf_data(2).data(), nullptr, 10);
auto sz = strtoll(rq->buf_data(kSizeIdx).data(), nullptr, kDecimal);
// Successful or not, we need to free the memory on exit.
Status rc;
if (cs == nullptr) {
@ -380,7 +383,9 @@ Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) {
// This is an internal request and is not tied to rpc. But need to post because there
// is a thread waiting on the completion of this request.
try {
int64_t addr = strtol(rq->buf_data(3).data(), nullptr, 10);
constexpr int32_t kBatchWaitIdx = 3;
// Fourth piece of the data is the address of the BatchWait ptr
int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal);
auto *bw = reinterpret_cast<BatchWait *>(addr);
// Check if the object is still around.
auto bwObj = bw->GetBatchWait();
@ -405,11 +410,13 @@ Status CacheServer::InternalFetchRow(CacheRequest *rq) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(0).data()));
// First piece is a flatbuffer containing row fetch information, second piece is the address of the BatchWait ptr
enum { kFetchRowMsgIdx = 0, kBatchWaitIdx = 1 };
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(kFetchRowMsgIdx).data()));
// This is an internal request and is not tied to rpc. But need to post because there
// is a thread waiting on the completion of this request.
try {
int64_t addr = strtol(rq->buf_data(1).data(), nullptr, 10);
int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal);
auto *bw = reinterpret_cast<BatchWait *>(addr);
// Check if the object is still around.
auto bwObj = bw->GetBatchWait();
@ -747,17 +754,20 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
}
Status CacheServer::BatchCacheRows(CacheRequest *rq) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == 3, "Expect three pieces of data");
// First one is cookie, followed by address and then size.
enum { kCookieIdx = 0, kAddrIdx = 1, kSizeIdx = 2 };
constexpr int32_t kExpectedBufDataSize = 3;
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == kExpectedBufDataSize, "Expect three pieces of data");
try {
auto &cookie = rq->buf_data(0);
auto &cookie = rq->buf_data(kCookieIdx);
auto connection_id = rq->connection_id();
auto client_id = rq->client_id();
int64_t offset_addr;
int32_t num_elem;
auto *base = SharedMemoryBaseAddr();
offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
offset_addr = strtoll(rq->buf_data(kAddrIdx).data(), nullptr, kDecimal);
auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
num_elem = strtol(rq->buf_data(2).data(), nullptr, 10);
num_elem = static_cast<int32_t>(strtol(rq->buf_data(kSizeIdx).data(), nullptr, kDecimal));
auto batch_wait = std::make_shared<BatchWait>(num_elem);
// Get a set of free request and push into the queues.
for (auto i = 0; i < num_elem; ++i) {
@ -1105,7 +1115,7 @@ Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
auto client_id = rq->client_id();
CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
try {
auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
void *p = nullptr;
RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p));
auto *base = SharedMemoryBaseAddr();
@ -1124,7 +1134,7 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
auto *base = SharedMemoryBaseAddr();
try {
auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
auto addr = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
DeallocateSharedMemory(client_id, p);
} catch (const std::exception &e) {
@ -1291,11 +1301,11 @@ int32_t CacheServer::Builder::AdjustNumWorkers(int32_t num_workers) {
CacheServer::Builder::Builder()
: top_(""),
num_workers_(std::thread::hardware_concurrency() / 2),
port_(50052),
num_workers_(kDefaultNumWorkers),
port_(kCfgDefaultCachePort),
shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
memory_cap_ratio_(kDefaultMemoryCapRatio),
log_level_(1) {
log_level_(kDefaultLogLevel) {
if (num_workers_ == 0) {
num_workers_ = 1;
}
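
FastCacheRow and BatchCacheRows both turn an offset received over RPC into a local pointer by adding it to the mapped base address; only the offset is portable, since each process maps the shared segment at a different virtual address. A minimal sketch of that addressing scheme, with plain heap memory standing in for the shm segment:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  std::vector<char> segment(1024);  // stand-in for an attached shm segment
  void *base = segment.data();

  // "Server" side: write a row somewhere and ship only its offset.
  int64_t offset = 128;
  const char msg[] = "row-42";
  std::memcpy(static_cast<char *>(base) + offset, msg, sizeof(msg));

  // "Client" side: rebuild the pointer from its own base plus the offset.
  auto *p = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(base) + offset);
  std::cout << p << "\n";  // prints row-42
  return 0;
}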

View File

@ -57,6 +57,8 @@ class CacheServer : public Service {
public:
friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
// Only allow new service to be created if left memory is more than 15% of our hard memory cap
constexpr static float kMemoryBottomLineForNewService = 0.15;
class Builder {
public:

View File

@ -90,8 +90,9 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
if (generate_id_) {
*row_id_generated = GetNextRowId();
// Some debug information on how many rows we have generated so far.
if ((*row_id_generated) % 1000 == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
constexpr int32_t kDisplayInterval = 1000;
if ((*row_id_generated) % kDisplayInterval == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1);
}
} else {
if (msg->row_id() < 0) {
@ -159,8 +160,9 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
if (generate_id_) {
*row_id_generated = GetNextRowId();
// Some debug information on how many rows we have generated so far.
if ((*row_id_generated) % 1000 == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
constexpr int32_t kDisplayInterval = 1000;
if ((*row_id_generated) % kDisplayInterval == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1);
}
} else {
auto msg = GetTensorRowHeaderMsg(src.GetPointer());

View File

@ -15,7 +15,6 @@
*/
#include "minddata/dataset/engine/cache/perf/cache_perf_run.h"
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
@ -24,6 +23,7 @@
#include <unistd.h>
#include <algorithm>
#include <chrono>
#include <cstring>
#include <iomanip>
#include <sstream>
#include "minddata/dataset/util/random.h"
@ -344,11 +344,11 @@ void CachePerfRun::PrintEpochSummary() const {
<< std::setw(14) << "buffer count" << std::setw(18) << "Elapsed time (s)" << std::endl;
for (auto &it : epoch_results_) {
auto epoch_worker_summary = it.second;
std::cout << std::setw(12) << epoch_worker_summary.pipeline() + 1 << std::setw(10) << epoch_worker_summary.worker()
<< std::setw(10) << epoch_worker_summary.min() << std::setw(10) << epoch_worker_summary.max()
<< std::setw(10) << epoch_worker_summary.avg() << std::setw(13) << epoch_worker_summary.med()
<< std::setw(14) << epoch_worker_summary.cnt() << std::setw(18) << epoch_worker_summary.elapse()
<< std::endl;
std::cout << std::setw(12) << (epoch_worker_summary.pipeline() + 1) << std::setw(10)
<< epoch_worker_summary.worker() << std::setw(10) << epoch_worker_summary.min() << std::setw(10)
<< epoch_worker_summary.max() << std::setw(10) << epoch_worker_summary.avg() << std::setw(13)
<< epoch_worker_summary.med() << std::setw(14) << epoch_worker_summary.cnt() << std::setw(18)
<< epoch_worker_summary.elapse() << std::endl;
}
}
@ -463,7 +463,7 @@ Status CachePerfRun::StartPipelines() {
// Call _exit instead of exit because we will hang TaskManager destructor for a forked child process.
_exit(-1);
} else if (pid > 0) {
std::cout << "Pipeline number " << i + 1 << " has been created with process id: " << pid << std::endl;
std::cout << "Pipeline number " << (i + 1) << " has been created with process id: " << pid << std::endl;
pid_lists_.push_back(pid);
} else {
std::string errMsg = "Failed to fork process for cache pipeline: " + std::to_string(errno);

View File

@ -15,7 +15,7 @@
*/
#include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h"
#include <string.h>
#include <cstring>
#include "mindspore/core/utils/log_adapter.h"
namespace ms = mindspore;

View File

@ -15,10 +15,10 @@
*/
#include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h"
#include <string.h>
#include <sys/types.h>
#include <algorithm>
#include <chrono>
#include <cstring>
#include <iomanip>
#include <sstream>
#include "minddata/dataset/core/tensor.h"
@ -182,7 +182,7 @@ Status CachePipelineRun::Run() {
}
// Log a warning level message so we can see it in the log file when it starts.
MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " successfully creating cache service." << std::endl;
MS_LOG(WARNING) << "Pipeline number " << (my_pipeline_ + 1) << " successfully creating cache service." << std::endl;
// Spawn a thread to listen to the parent process
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(&CachePipelineRun::ListenToParent, this)));
@ -213,7 +213,7 @@ Status CachePipelineRun::RunFirstEpoch() {
if (my_pipeline_ + 1 == num_pipelines_) {
end_row_ = num_rows_ - 1;
}
std::cout << "Pipeline number " << my_pipeline_ + 1 << " row id range: [" << start_row_ << "," << end_row_ << "]"
std::cout << "Pipeline number " << (my_pipeline_ + 1) << " row id range: [" << start_row_ << "," << end_row_ << "]"
<< std::endl;
// Spawn the worker threads.
@ -305,7 +305,7 @@ Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) {
auto end_tick = std::chrono::steady_clock::now();
if (rc.IsError()) {
if (rc == StatusCode::kMDOutOfMemory || rc == StatusCode::kMDNoSpace) {
MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " worker id " << worker_id << ": "
MS_LOG(WARNING) << "Pipeline number " << (my_pipeline_ + 1) << " worker id " << worker_id << ": "
<< rc.ToString();
resource_err = true;
cc_->ServerRunningOutOfResources();

View File

@ -55,7 +55,7 @@ Status StorageManager::AddOneContainer(int replaced_container_pos) {
}
Status StorageManager::DoServiceStart() {
containers_.reserve(1000);
containers_.reserve(kMaxNumContainers);
writable_containers_pool_.reserve(pool_size_);
if (root_.IsDirectory()) {
// create multiple containers and store their index in a pool

View File

@ -47,6 +47,7 @@ class StorageManager : public Service {
using value_type = std::pair<int, std::pair<off_t, size_t>>;
using storage_index = AutoIndexObj<value_type, std::allocator<value_type>, StorageBPlusTreeTraits>;
using key_type = storage_index::key_type;
constexpr static int32_t kMaxNumContainers = 1000;
explicit StorageManager(const Path &);

View File

@ -35,6 +35,8 @@ class CacheClientGreeter : public Service {
void *SharedMemoryBaseAddr() { return nullptr; }
Status HandleRequest(std::shared_ptr<BaseRequest> rq) { RETURN_STATUS_UNEXPECTED("Not supported"); }
Status AttachToSharedMemory(bool *local_bypass) { RETURN_STATUS_UNEXPECTED("Not supported"); }
std::string GetHostname() const { return "Not supported"; }
int32_t GetPort() const { return 0; }
protected:
private:

View File

@ -115,8 +115,8 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
MS_LOG(DEBUG) << "Ignore eoe";
// However, we need to flush any leftover rows from the async write buffer. Any error
// we get will just stop caching, but the pipeline will continue
Status rc;
if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) {
Status rc = cache_client_->FlushAsyncWriteBuffer();
if (rc.IsError()) {
cache_missing_rows_ = false;
if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
cache_client_->ServerRunningOutOfResources();
@ -138,8 +138,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) {
// We will send the request async, but we will most likely
// ignore any error and continue.
Status rc;
rc = rq->AsyncSendCacheRequest(cache_client_, new_row);
Status rc = rq->AsyncSendCacheRequest(cache_client_, new_row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
} else if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) {
@ -192,7 +191,7 @@ Status CacheMergeOp::Cleaner() {
Status CacheMergeOp::PrepareOperator() { // Run any common code from super class first before adding our own
// specific logic
CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == kNumChildren, "Incorrect number of children");
RETURN_IF_NOT_OK(DatasetOp::PrepareOperator());
// Get the computed check sum from all ops in the cache miss class
uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]);
@ -287,8 +286,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha
auto expected = State::kEmpty;
if (st_.compare_exchange_strong(expected, State::kDirty)) {
// We will do a deep copy but write directly into CacheRequest protobuf or shared memory
Status rc;
rc = cc->AsyncWriteRow(row);
Status rc = cc->AsyncWriteRow(row);
if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) {
cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);

View File

@ -68,6 +68,7 @@ class CacheMergeOp : public ParallelOp {
std::shared_ptr<CacheRowRequest> cleaner_copy_;
};
constexpr static int kNumChildren = 2; // CacheMergeOp has 2 children
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream

View File

@ -161,7 +161,7 @@ Status CacheOp::WaitForCachingAllRows() {
case CacheServiceState::kBuildPhase:
// Do nothing. Continue to wait.
BuildPhaseDone = false;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::this_thread::sleep_for(std::chrono::milliseconds(kPhaseCheckIntervalInMilliSec));
break;
case CacheServiceState::kFetchPhase:
BuildPhaseDone = true;

View File

@ -34,6 +34,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
// assigns row id). No read access in the first phase. Once the cache is fully built,
// we switch to second phase and fetch requests from the sampler.
enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
constexpr static int32_t kPhaseCheckIntervalInMilliSec = 100;
/// \brief The nested builder class inside of the CacheOp is used to help manage all of
/// the arguments for constructing it. Use the builder by setting each argument

View File

@ -82,8 +82,7 @@ Status RenameOp::ComputeColMap() {
std::string name = pair.first;
int32_t id = pair.second;
// find name
std::vector<std::string>::iterator it;
it = std::find(in_columns_.begin(), in_columns_.end(), name);
std::vector<std::string>::iterator it = std::find(in_columns_.begin(), in_columns_.end(), name);
// for c input checks here we have to count the number of times we find the stuff in in_columns_
// because we iterate over the mInputList n times
if (it != in_columns_.end()) {

View File

@ -71,5 +71,18 @@ Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr
return Status::OK();
}
Status DatasetCacheImpl::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["session_id"] = session_id_;
args["cache_memory_size"] = cache_mem_sz_;
args["spill"] = spill_;
if (hostname_) args["hostname"] = hostname_.value();
if (port_) args["port"] = port_.value();
if (num_connections_) args["num_connections"] = num_connections_.value();
if (prefetch_sz_) args["prefetch_size"] = prefetch_sz_.value();
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
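
Writing only the optional fields that are actually set keeps the serialized form faithful to the IR. A compact sketch of the same pattern with std::optional (assumes nlohmann/json is available, as in MindSpore):

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <nlohmann/json.hpp>

int main() {
  std::optional<std::string> hostname = "127.0.0.1";
  std::optional<int32_t> port;  // unset, so it is omitted below

  nlohmann::json args;
  args["session_id"] = 1;
  if (hostname) args["hostname"] = hostname.value();
  if (port) args["port"] = port.value();
  std::cout << args.dump() << "\n";  // {"hostname":"127.0.0.1","session_id":1}
  return 0;
}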

View File

@ -32,6 +32,7 @@ namespace dataset {
/// DatasetCache is the IR of CacheClient
class DatasetCacheImpl : public DatasetCache {
public:
friend class PreBuiltDatasetCache;
///
/// \brief Constructor
/// \param id A user assigned session id for the current pipeline.
@ -66,6 +67,8 @@ class DatasetCacheImpl : public DatasetCache {
Status ValidateParams() override { return Status::OK(); }
Status to_json(nlohmann::json *out_json) override;
~DatasetCacheImpl() = default;
private:

View File

@ -28,52 +28,5 @@ Status PreBuiltDatasetCache::Build() {
// we actually want to keep a reference of the runtime object so it can be shared by different pipelines
return Status::OK();
}
Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *const ds) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheOp> cache_op = nullptr;
RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
*ds = cache_op;
return Status::OK();
}
Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["session_id"] = cache_client_->session_id();
args["cache_memory_size"] = cache_client_->GetCacheMemSz();
args["spill"] = cache_client_->isSpill();
args["num_connections"] = cache_client_->GetNumConnections();
args["prefetch_size"] = cache_client_->GetPrefetchSize();
*out_json = args;
return Status::OK();
}
Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
.SetNumWorkers(num_workers)
.SetClient(cache_client_)
.SetSampler(sampler_rt)
.Build(&lookup_op));
*ds = lookup_op;
return Status::OK();
}
Status PreBuiltDatasetCache::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheMergeOp> merge_op = nullptr;
RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
*ds = merge_op;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -21,37 +21,27 @@
#include <utility>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
namespace mindspore {
namespace dataset {
/// DatasetCache is the IR of CacheClient
class PreBuiltDatasetCache : public DatasetCache {
class PreBuiltDatasetCache : public DatasetCacheImpl {
public:
/// \brief Constructor
/// \param cc a pre-built cache client
explicit PreBuiltDatasetCache(std::shared_ptr<CacheClient> cc) : cache_client_(std::move(cc)) {}
explicit PreBuiltDatasetCache(std::shared_ptr<CacheClient> cc)
: DatasetCacheImpl(cc->session_id(), cc->GetCacheMemSz(), cc->isSpill(), StringToChar(cc->GetHostname()),
cc->GetPort(), cc->GetNumConnections(), cc->GetPrefetchSize()) {
cache_client_ = std::move(cc);
}
~PreBuiltDatasetCache() = default;
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
/// \return Status Error code
Status Build() override;
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *const ds) override;
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) override;
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
Status ValidateParams() override { return Status::OK(); }
Status to_json(nlohmann::json *out_json) override;
private:
std::shared_ptr<CacheClient> cache_client_;
};
} // namespace dataset
} // namespace mindspore
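
The net effect of this file's change: PreBuiltDatasetCache stops re-implementing CreateCacheOp/CreateCacheLookupOp/CreateCacheMergeOp/to_json and inherits them from DatasetCacheImpl, forwarding the pre-built client's settings to the base constructor. A toy model of that refactor (class names here are stand-ins, not the real types):

#include <iostream>
#include <memory>
#include <string>
#include <utility>

class CacheImplBase {
 public:
  CacheImplBase(int session_id, std::string hostname)
      : session_id_(session_id), hostname_(std::move(hostname)) {}
  void Describe() const { std::cout << "session " << session_id_ << " @ " << hostname_ << "\n"; }

 protected:
  int session_id_;
  std::string hostname_;
};

struct Client {  // stand-in for CacheClient
  int session_id() const { return 7; }
  std::string GetHostname() const { return "127.0.0.1"; }
};

class PreBuiltCache : public CacheImplBase {
 public:
  // Pull the settings out of the pre-built client, then keep a reference to it.
  explicit PreBuiltCache(std::shared_ptr<Client> cc)
      : CacheImplBase(cc->session_id(), cc->GetHostname()), client_(std::move(cc)) {}

 private:
  std::shared_ptr<Client> client_;
};

int main() {
  PreBuiltCache cache(std::make_shared<Client>());
  cache.Describe();
  return 0;
}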

View File

@ -202,12 +202,11 @@ Status CacheTransformPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool
}
// Helper function to execute mappable cache transformation.
// Input:
// Input tree:
// Sampler
// |
// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
//
// Transformed:
// Transformed tree:
// Sampler --> CacheLookupNode ------------------------->
// | |
// | CacheMergeNode
@ -232,10 +231,9 @@ Status CacheTransformPass::InjectMappableCacheNode(std::shared_ptr<MappableSourc
}
// Helper function to execute non-mappable cache transformation.
// Input:
// Input tree:
// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
//
// Transformed:
// Transformed tree:
// Sampler
// |
// LeafNode --> OtherNodes --> CachedNode --> CacheNode

View File

@ -133,12 +133,12 @@ class CacheTransformPass : public IRTreePass {
private:
/// \brief Helper function to execute mappable cache transformation.
///
/// Input:
/// Input tree:
/// Sampler
/// |
/// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
///
/// Transformed:
/// Transformed tree:
/// Sampler --> CacheLookupNode ------------------------->
/// | |
/// | CacheMergeNode
@ -153,10 +153,10 @@ class CacheTransformPass : public IRTreePass {
/// \brief Helper function to execute non-mappable cache transformation.
///
/// Input:
/// Input tree:
/// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache)
///
/// Transformed:
/// Transformed tree:
/// Sampler
/// |
/// LeafNode --> OtherNodes --> CachedNode --> CacheNode

View File

@ -110,6 +110,9 @@ constexpr int32_t kDftPrefetchSize = 20;
constexpr int32_t kDftNumConnections = 12;
constexpr int32_t kDftAutoNumWorkers = false;
constexpr char kDftMetaColumnPrefix[] = "_meta-";
constexpr int32_t kDecimal = 10; // used in strtol() to convert a string value according to decimal numeral system
constexpr int32_t kMinLegalPort = 1025;
constexpr int32_t kMaxLegalPort = 65535;
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;

View File

@ -31,5 +31,4 @@ def main():
cache_server = os.path.join(cache_admin_dir, "cache_server")
os.chmod(cache_admin, stat.S_IRWXU)
os.chmod(cache_server, stat.S_IRWXU)
cmd = cache_admin + " " + " ".join(sys.argv[1:])
sys.exit(subprocess.call(cmd, shell=True))
sys.exit(subprocess.call([cache_admin] + sys.argv[1:], shell=False))

View File

@ -26,8 +26,8 @@ class DatasetCache:
"""
A client to interface with tensor caching service.
For details, please check `Chinese tutorial <https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html>`_,
`Chinese programming guide <https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache>`_.
For details, please check `Tutorial <https://www.mindspore.cn/tutorial/training/en/master/advanced_use/
enable_cache.html>`_, `Programming guide <https://www.mindspore.cn/doc/programming_guide/en/master/cache.html>`_.
Args:
session_id (int): A user assigned session id for the current pipeline.
@ -76,7 +76,8 @@ class DatasetCache:
self.num_connections = num_connections
self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size)
def GetStat(self):
def get_stat(self):
"""Get the statistics from a cache."""
return self.cache_client.GetStat()
def __deepcopy__(self, memodict):

View File

@ -22,6 +22,10 @@ from . import py_transforms_util as util
from .c_transforms import TensorOperation
def not_random(function):
"""
Specify the function as "not random", i.e., it produces a deterministic result.
A Python function can only be cached after it is specified as "not random".
"""
function.random = False
return function

View File

@ -47,6 +47,10 @@ DE_PY_BORDER_TYPE = {Border.CONSTANT: 'constant',
def not_random(function):
"""
Specify the function as "not random", i.e., it produces a deterministic result.
A Python function can only be cached after it is specified as "not random".
"""
function.random = False
return function

View File

@ -2221,7 +2221,7 @@ def test_cache_map_interrupt_and_rerun():
assert num_iter == 10000
epoch_count += 1
cache_stat = some_cache.GetStat()
cache_stat = some_cache.get_stat()
assert cache_stat.num_mem_cached == 10000
logger.info("test_cache_map_interrupt_and_rerun Ended.\n")

View File

@ -153,7 +153,7 @@ def test_cache_nomap_basic3():
assert num_iter == 12
# Contact the server to get the statistics
stat = some_cache.GetStat()
stat = some_cache.get_stat()
cache_sz = stat.avg_cache_sz
num_mem_cached = stat.num_mem_cached
num_disk_cached = stat.num_disk_cached
@ -366,7 +366,7 @@ def test_cache_nomap_basic8():
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic9():
"""
Testing the GetStat interface for getting some info from server, but this should fail if the cache is not created
Testing the get_stat interface for getting some info from server, but this should fail if the cache is not created
in a pipeline.
"""
@ -381,7 +381,7 @@ def test_cache_nomap_basic9():
# Contact the server to get the statistics, this should fail because we have not used this cache in any pipeline
# so there will not be any cache to get stats on.
with pytest.raises(RuntimeError) as e:
stat = some_cache.GetStat()
stat = some_cache.get_stat()
cache_sz = stat.avg_cache_sz
logger.info("Average row cache size: {}".format(cache_sz))
assert "Unexpected error" in str(e.value)
@ -1239,7 +1239,7 @@ def test_cache_nomap_interrupt_and_rerun():
assert num_iter == 10000
epoch_count += 1
cache_stat = some_cache.GetStat()
cache_stat = some_cache.get_stat()
assert cache_stat.num_mem_cached == 10000
logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n")
@ -2349,7 +2349,7 @@ def test_cache_nomap_all_rows_cached():
logger.info("Number of data in ds1: {} ".format(num_iter))
assert num_iter == num_total_rows
cache_stat = some_cache.GetStat()
cache_stat = some_cache.get_stat()
assert cache_stat.num_mem_cached == num_total_rows
logger.info("test_cache_nomap_all_rows_cached Ended.\n")