From 8c1b56114916352b6cc176472066cf638aab401e Mon Sep 17 00:00:00 2001 From: Lixia Chen Date: Fri, 28 May 2021 14:13:23 -0400 Subject: [PATCH] Fix cache code check warnings --- .../minddata/dataset/core/config_manager.cc | 11 ++-- .../dataset/engine/cache/cache_admin_arg.cc | 48 ++++++++------- .../dataset/engine/cache/cache_admin_arg.h | 6 +- .../dataset/engine/cache/cache_arena.cc | 3 +- .../dataset/engine/cache/cache_client.cc | 58 ++++++++++--------- .../dataset/engine/cache/cache_client.h | 2 + .../dataset/engine/cache/cache_common.h | 7 +++ .../dataset/engine/cache/cache_grpc_client.cc | 10 ++-- .../dataset/engine/cache/cache_grpc_client.h | 5 ++ .../dataset/engine/cache/cache_grpc_server.cc | 8 +-- .../dataset/engine/cache/cache_grpc_server.h | 1 + .../minddata/dataset/engine/cache/cache_hw.cc | 8 +-- .../dataset/engine/cache/cache_main.cc | 42 +++++++++----- .../dataset/engine/cache/cache_request.cc | 13 +++-- .../dataset/engine/cache/cache_request.h | 2 +- .../dataset/engine/cache/cache_server.cc | 46 +++++++++------ .../dataset/engine/cache/cache_server.h | 2 + .../dataset/engine/cache/cache_service.cc | 10 ++-- .../engine/cache/perf/cache_perf_run.cc | 14 ++--- .../engine/cache/perf/cache_pipeline.cc | 2 +- .../engine/cache/perf/cache_pipeline_run.cc | 8 +-- .../dataset/engine/cache/storage_manager.cc | 2 +- .../dataset/engine/cache/storage_manager.h | 1 + .../engine/cache/stub/cache_grpc_client.h | 2 + .../engine/datasetops/cache_merge_op.cc | 12 ++-- .../engine/datasetops/cache_merge_op.h | 1 + .../dataset/engine/datasetops/cache_op.cc | 2 +- .../dataset/engine/datasetops/cache_op.h | 1 + .../dataset/engine/datasetops/rename_op.cc | 3 +- .../engine/ir/cache/dataset_cache_impl.cc | 13 +++++ .../engine/ir/cache/dataset_cache_impl.h | 3 + .../ir/cache/pre_built_dataset_cache.cc | 47 --------------- .../engine/ir/cache/pre_built_dataset_cache.h | 24 +++----- .../engine/opt/pre/cache_transform_pass.cc | 10 ++-- .../engine/opt/pre/cache_transform_pass.h | 8 
+-- .../dataset/include/dataset/constants.h | 3 + mindspore/dataset/engine/cache_admin.py | 3 +- mindspore/dataset/engine/cache_client.py | 7 ++- mindspore/dataset/transforms/py_transforms.py | 4 ++ mindspore/dataset/vision/py_transforms.py | 4 ++ tests/ut/python/dataset/test_cache_map.py | 2 +- tests/ut/python/dataset/test_cache_nomap.py | 10 ++-- 42 files changed, 240 insertions(+), 228 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc index fedd82260be..dee5a6b276d 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc @@ -28,6 +28,7 @@ #include "mindspore/lite/src/common/log_adapter.h" #endif #include "minddata/dataset/util/system_pool.h" +#include "utils/ms_utils.h" namespace mindspore { namespace dataset { @@ -53,14 +54,14 @@ ConfigManager::ConfigManager() enable_shared_mem_(true) { num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits::max(); num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? 
num_parallel_workers_ : num_cpu_threads_; - auto env_cache_host = std::getenv("MS_CACHE_HOST"); - auto env_cache_port = std::getenv("MS_CACHE_PORT"); - if (env_cache_host != nullptr) { + std::string env_cache_host = common::GetEnv("MS_CACHE_HOST"); + std::string env_cache_port = common::GetEnv("MS_CACHE_PORT"); + if (!env_cache_host.empty()) { cache_host_ = env_cache_host; } - if (env_cache_port != nullptr) { + if (!env_cache_port.empty()) { char *end = nullptr; - cache_port_ = strtol(env_cache_port, &end, 10); + cache_port_ = static_cast(strtol(env_cache_port.c_str(), &end, kDecimal)); if (*end != '\0') { MS_LOG(WARNING) << "Cache port from env variable MS_CACHE_PORT is invalid\n"; cache_port_ = 0; // cause the port range validation to generate an error during the validation checks diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc index 7bfd9c8a3cf..f04dfa229fb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc @@ -33,37 +33,34 @@ namespace mindspore { namespace dataset { -const int32_t CacheAdminArgHandler::kDefaultNumWorkers = std::thread::hardware_concurrency() > 2 - ? 
std::thread::hardware_concurrency() / 2 - : 1; const char CacheAdminArgHandler::kServerBinary[] = "cache_server"; CacheAdminArgHandler::CacheAdminArgHandler() : port_(kCfgDefaultCachePort), num_workers_(kDefaultNumWorkers), - shm_mem_sz_(kDefaultSharedMemorySizeInGB), + shm_mem_sz_(kDefaultSharedMemorySize), log_level_(kDefaultLogLevel), - memory_cap_ratio_(kMemoryCapRatio), + memory_cap_ratio_(kDefaultMemoryCapRatio), hostname_(kCfgDefaultCacheHost), spill_dir_(""), command_id_(CommandId::kCmdUnknown) { - const char *env_cache_host = std::getenv("MS_CACHE_HOST"); - const char *env_cache_port = std::getenv("MS_CACHE_PORT"); - if (env_cache_host != nullptr) { + std::string env_cache_host = common::GetEnv("MS_CACHE_HOST"); + std::string env_cache_port = common::GetEnv("MS_CACHE_PORT"); + if (!env_cache_host.empty()) { hostname_ = env_cache_host; } - if (env_cache_port != nullptr) { + if (!env_cache_port.empty()) { char *end = nullptr; - port_ = strtol(env_cache_port, &end, 10); + port_ = static_cast(strtol(env_cache_port.c_str(), &end, kDecimal)); if (*end != '\0') { std::cerr << "Cache port from env variable MS_CACHE_PORT is invalid\n"; port_ = 0; // cause the port range validation to generate an error during the validation checks } } - const char *env_log_level = std::getenv("GLOG_v"); - if (env_log_level != nullptr) { + std::string env_log_level = common::GetEnv("GLOG_v"); + if (!env_log_level.empty()) { char *end = nullptr; - log_level_ = strtol(env_log_level, &end, 10); + log_level_ = static_cast(strtol(env_log_level.c_str(), &end, kDecimal)); if (*end != '\0') { std::cerr << "Log level from env variable GLOG_v is invalid\n"; log_level_ = -1; // cause the log level range validation to generate an error during the validation checks @@ -377,15 +374,17 @@ Status CacheAdminArgHandler::Validate() { } // Additional checks here - auto max_num_workers = std::max(std::thread::hardware_concurrency(), 100); + auto max_num_workers = 
std::max(std::thread::hardware_concurrency(), kMaxNumWorkers); if (used_args_[ArgValue::kArgNumWorkers] && (num_workers_ < 1 || num_workers_ > max_num_workers)) // Check the value of num_workers only if it's provided by users. return Status(StatusCode::kMDSyntaxError, "Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + "."); - if (log_level_ < 0 || log_level_ > 4) return Status(StatusCode::kMDSyntaxError, "Log level must be in range (0..4)."); + if (log_level_ < MsLogLevel::DEBUG || log_level_ > MsLogLevel::EXCEPTION) + return Status(StatusCode::kMDSyntaxError, "Log level must be in range (0..4)."); if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) return Status(StatusCode::kMDSyntaxError, "Memory cap ratio should be positive and no greater than 1"); - if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kMDSyntaxError, "Port must be in range (1025..65535)."); + if (port_ < kMinLegalPort || port_ > kMaxLegalPort) + return Status(StatusCode::kMDSyntaxError, "Port must be in range (1025..65535)."); return Status::OK(); } @@ -542,7 +541,7 @@ Status CacheAdminArgHandler::StopServer(CommandId command_id) { // The server will send a message back and remove the queue and we will then wake up. But on the safe // side, we will also set up an alarm and kill this process if we hang on // the message queue. - alarm(60); + (void)alarm(kAlarmDeadline); Status dummy_rc; (void)msg.ReceiveStatus(&dummy_rc); std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." 
<< std::endl; @@ -579,8 +578,7 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) { } // fork the child process to become the daemon - pid_t pid; - pid = fork(); + pid_t pid = fork(); // failed to fork if (pid < 0) { std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno); @@ -588,7 +586,7 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) { } else if (pid > 0) { // As a parent, we close the write end. We only listen. close(fd[1]); - dup2(fd[0], 0); + (void)dup2(fd[0], STDIN_FILENO); close(fd[0]); int status; if (waitpid(pid, &status, 0) == -1) { @@ -616,11 +614,11 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) { } else { // Child here ... // Close all stdin, redirect stdout and stderr to the write end of the pipe. - close(fd[0]); - dup2(fd[1], 1); - dup2(fd[1], 2); - close(0); - close(fd[1]); + (void)close(fd[0]); + (void)dup2(fd[1], STDOUT_FILENO); + (void)dup2(fd[1], STDERR_FILENO); + (void)close(STDIN_FILENO); + (void)close(fd[1]); // exec the cache server binary in this process // If the user did not provide the value of num_workers, we pass -1 to cache server to allow it assign the default. // So that the server knows if the number is provided by users or by default. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h index b5e83837f23..8a6dcf1365b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h @@ -31,10 +31,8 @@ namespace dataset { class CacheAdminArgHandler { public: - static const int32_t kDefaultNumWorkers; - static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; - static constexpr int32_t kDefaultLogLevel = 1; - static constexpr float kMemoryCapRatio = 0.8; + static constexpr int32_t kAlarmDeadline = 60; + static constexpr int32_t kMaxNumWorkers = 100; static const char kServerBinary[]; // These are the actual command types to execute diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc index 27cb0de8d97..49c724f25cb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc @@ -41,7 +41,8 @@ Status CachedSharedMemory::Init() { // We will create a number of sub pool out of shared memory to reduce latch contention int32_t num_of_pools = num_numa_nodes_; if (num_numa_nodes_ == 1) { - num_of_pools = shared_memory_sz_in_gb_ * 2; + constexpr int32_t kNumPoolMultiplier = 2; + num_of_pools = shared_memory_sz_in_gb_ * kNumPoolMultiplier; } sub_pool_sz_ = shm_mem_sz / num_of_pools; // If each subpool is too small, readjust the number of pools diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc index 7048dffad35..ac045b165c5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc @@ -46,8 +46,8 @@ Status CacheClient::Builder::SanityCheck() { CHECK_FAIL_RETURN_SYNTAX_ERROR(num_connections_ > 0, "number of 
tcp/ip connections must be positive."); CHECK_FAIL_RETURN_SYNTAX_ERROR(prefetch_size_ > 0, "prefetch size must be positive."); CHECK_FAIL_RETURN_SYNTAX_ERROR(!hostname_.empty(), "hostname must not be empty."); - CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ > 1024, "Port must be in range (1025..65535)."); - CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ <= 65535, "Port must be in range (1025..65535)."); + CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ >= kMinLegalPort, "Port must be in range (1025..65535)."); + CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ <= kMaxLegalPort, "Port must be in range (1025..65535)."); CHECK_FAIL_RETURN_SYNTAX_ERROR(hostname_ == "127.0.0.1", "now cache client has to be on the same host with cache server."); return Status::OK(); @@ -103,6 +103,9 @@ void CacheClient::Print(std::ostream &out) const { << SupportLocalClient(); } +std::string CacheClient::GetHostname() const { return comm_->GetHostname(); } +int32_t CacheClient::GetPort() const { return comm_->GetPort(); } + Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const { auto rq = std::make_shared(this); RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row)); @@ -426,36 +429,35 @@ Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) { asyncWriter->rq.reset( new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_)); flush_rc_ = cc_->PushRequest(asyncWriter->rq); - if (flush_rc_.IsOk()) { - // If we are asked to wait, say this is the final flush, just wait for its completion. - bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking; - if (blocking) { - // Make sure we are done with all the buffers - for (auto i = 0; i < kNumAsyncBuffer; ++i) { - if (buf_arr_[i].rq) { - Status rc = buf_arr_[i].rq->Wait(); - if (rc.IsError()) { - flush_rc_ = rc; - } - buf_arr_[i].rq.reset(); - } + RETURN_IF_NOT_OK(flush_rc_); + + // If we are asked to wait, say this is the final flush, just wait for its completion. 
+ bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking; + if (blocking) { + // Make sure we are done with all the buffers + for (auto i = 0; i < kNumAsyncBuffer; ++i) { + if (buf_arr_[i].rq) { + Status rc = buf_arr_[i].rq->Wait(); + if (rc.IsError()) flush_rc_ = rc; + buf_arr_[i].rq.reset(); } } - // Prepare for the next buffer. - cur_ = (cur_ + 1) % kNumAsyncBuffer; - asyncWriter = &buf_arr_[cur_]; - // Update the cur_ while we have the lock. - // Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content - // Also we can also pick up any error from previous flush. - if (asyncWriter->rq) { - // Save the result into a common area, so worker can see it and quit. - flush_rc_ = asyncWriter->rq->Wait(); - asyncWriter->rq.reset(); - } - asyncWriter->bytes_avail_ = kAsyncBufferSize; - asyncWriter->num_ele_ = 0; } + // Prepare for the next buffer. + cur_ = (cur_ + 1) % kNumAsyncBuffer; + asyncWriter = &buf_arr_[cur_]; + // Update the cur_ while we have the lock. + // Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content + // Also we can also pick up any error from previous flush. + if (asyncWriter->rq) { + // Save the result into a common area, so worker can see it and quit. 
+ flush_rc_ = asyncWriter->rq->Wait(); + asyncWriter->rq.reset(); + } + asyncWriter->bytes_avail_ = kAsyncBufferSize; + asyncWriter->num_ele_ = 0; } + return flush_rc_; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h index 7e906ad234a..e7ccd620817 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h @@ -232,6 +232,8 @@ class CacheClient { int32_t GetNumConnections() const { return num_connections_; } int32_t GetPrefetchSize() const { return prefetch_size_; } int32_t GetClientId() const { return client_id_; } + std::string GetHostname() const; + int32_t GetPort() const; /// MergeOp will notify us when the server can't cache any more rows. /// We will stop any attempt to fetch any rows that are most likely diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h index 53954c83266..e1373623ecc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h @@ -24,6 +24,7 @@ #include #endif #include +#include #ifdef ENABLE_CACHE #include "proto/cache_grpc.grpc.pb.h" #endif @@ -41,6 +42,12 @@ constexpr static int32_t kLocalByPassThreshold = 64 * 1024; constexpr static int32_t kDefaultSharedMemorySize = 4; /// \brief Memory Cap ratio used by the server constexpr static float kDefaultMemoryCapRatio = 0.8; +/// \brief Default log level of the server +constexpr static int32_t kDefaultLogLevel = 1; +/// \brief Set num workers to half of num_cpus as the default +static const int32_t kDefaultNumWorkers = std::thread::hardware_concurrency() > 2 + ? 
std::thread::hardware_concurrency() / 2 + : 1; /// \brief A flag used by the BatchFetch request (client side) if it can support local bypass constexpr static uint32_t kLocalClientSupport = 1; /// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc index 70a8ba56ed8..e1ce544f08c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc @@ -26,7 +26,7 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port // message limit is 4MB which is not big enough. args.SetMaxReceiveMessageSize(-1); MS_LOG(INFO) << "Hostname: " << hostname_ << ", port: " << std::to_string(port_); -#if CACHE_LOCAL_CLIENT +#ifdef CACHE_LOCAL_CLIENT // Try connect locally to the unix_socket first as the first preference // Need to resolve hostname to ip address rather than to do a string compare if (hostname == "127.0.0.1") { @@ -36,7 +36,7 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port #endif std::string target = hostname + ":" + std::to_string(port); channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); -#if CACHE_LOCAL_CLIENT +#ifdef CACHE_LOCAL_CLIENT } #endif stub_ = CacheServerGreeter::NewStub(channel_); @@ -44,7 +44,7 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port Status CacheClientGreeter::AttachToSharedMemory(bool *local_bypass) { *local_bypass = false; -#if CACHE_LOCAL_CLIENT +#ifdef CACHE_LOCAL_CLIENT SharedMemory::shm_key_t shm_key; RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key)); // Attach to the shared memory @@ -85,7 +85,7 @@ Status CacheClientGreeter::HandleRequest(std::shared_ptr rq) { auto seqNo = request_cnt_.fetch_add(1); auto tag = 
std::make_unique(std::move(rq), seqNo); // One minute timeout - auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(kRequestTimeoutDeadlineInSec); tag->ctx_.set_deadline(deadline); tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_); tag->rpc_->StartCall(); @@ -108,7 +108,7 @@ Status CacheClientGreeter::WorkerEntry() { do { bool success; void *tag; - auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1); + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(kWaitForNewEventDeadlineInSec); // Set a timeout for one second. Check for interrupt if we need to do early exit. auto r = cq_.AsyncNext(&tag, &success, deadline); if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h index 0d4d54e6faa..2521bfa5d1e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h @@ -59,6 +59,8 @@ class CacheClientGreeter : public Service { friend class CacheClient; public: + constexpr static int32_t kRequestTimeoutDeadlineInSec = 60; + constexpr static int32_t kWaitForNewEventDeadlineInSec = 1; explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections); ~CacheClientGreeter(); @@ -86,6 +88,9 @@ class CacheClientGreeter : public Service { /// \return Base address of the shared memory. 
const void *SharedMemoryBaseAddr() const { return mem_.SharedMemoryBaseAddr(); } + std::string GetHostname() const { return hostname_; } + int32_t GetPort() const { return port_; } + private: std::shared_ptr channel_; std::unique_ptr stub_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc index 663c93e2e8f..c7272d348ad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc @@ -56,7 +56,7 @@ Status CacheServerGreeterImpl::Run() { // Default message size for gRPC is 4MB. Increase it to 2g-1 builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); int port_tcpip = 0; -#if CACHE_LOCAL_CLIENT +#ifdef CACHE_LOCAL_CLIENT int port_local = 0; // We also optimize on local clients on the same machine using unix socket builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local); @@ -72,7 +72,7 @@ Status CacheServerGreeterImpl::Run() { if (port_tcpip != port_) { errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + "."; } -#if CACHE_LOCAL_CLIENT +#ifdef CACHE_LOCAL_CLIENT if (port_local == 0) { errMsg += " Unable to create unix socket " + unix_socket_ + "."; } @@ -176,7 +176,7 @@ void CacheServerRequest::Print(std::ostream &out) const { Status CacheServerGreeterImpl::MonitorUnixSocket() { TaskManager::FindMe()->Post(); -#if CACHE_LOCAL_CLIENT +#ifdef CACHE_LOCAL_CLIENT Path p(unix_socket_); do { RETURN_IF_INTERRUPTED(); @@ -197,7 +197,7 @@ Status CacheServerGreeterImpl::MonitorUnixSocket() { MS_LOG(WARNING) << "Unix socket is removed."; TaskManager::WakeUpWatchDog(); } - std::this_thread::sleep_for(std::chrono::seconds(5)); + std::this_thread::sleep_for(std::chrono::seconds(kMonitorIntervalInSec)); } while (true); #endif return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h 
b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h index 2e79b4fe82c..5d7779b4796 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h @@ -70,6 +70,7 @@ class CacheServerGreeterImpl final { friend class CacheServer; public: + constexpr static int32_t kMonitorIntervalInSec = 5; explicit CacheServerGreeterImpl(int32_t port); virtual ~CacheServerGreeterImpl(); /// \brief Brings up gRPC server diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc index 9a483dc09d3..28374842e3c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc @@ -101,7 +101,7 @@ Status CacheServerHW::GetNumaNodeInfo() { auto p = it->next(); const std::string entry = p.Basename(); const char *name = entry.data(); - if (strncmp(name, kNodeName, 4) == 0 && isdigit_string(name + strlen(kNodeName))) { + if (strncmp(name, kNodeName, strlen(kNodeName)) == 0 && isdigit_string(name + strlen(kNodeName))) { numa_nodes_.insert(p); } } @@ -116,7 +116,7 @@ Status CacheServerHW::GetNumaNodeInfo() { auto r = std::regex("[0-9]*-[0-9]*"); for (Path p : numa_nodes_) { auto node_dir = p.Basename(); - numa_id_t numa_node = strtol(node_dir.data() + strlen(kNodeName), nullptr, 10); + numa_id_t numa_node = static_cast(strtol(node_dir.data() + strlen(kNodeName), nullptr, kDecimal)); Path f = p / kCpuList; std::ifstream fs(f.toString()); CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + f.toString()); @@ -134,8 +134,8 @@ Status CacheServerHW::GetNumaNodeInfo() { CHECK_FAIL_RETURN_UNEXPECTED(pos != std::string::npos, "Failed to parse numa node file"); std::string min = match.substr(0, pos); std::string max = match.substr(pos + 1); - cpu_id_t cpu_min = strtol(min.data(), nullptr, 10); - cpu_id_t cpu_max = strtol(max.data(), nullptr, 10); + 
cpu_id_t cpu_min = static_cast(strtol(min.data(), nullptr, kDecimal)); + cpu_id_t cpu_max = static_cast(strtol(max.data(), nullptr, kDecimal)); MS_LOG(DEBUG) << "Numa node " << numa_node << " CPU(s) : " << cpu_min << "-" << cpu_max; for (int i = cpu_min; i <= cpu_max; ++i) { CPU_SET(i, &cpuset); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc index aa68880a377..877d3610372 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc @@ -22,6 +22,7 @@ #include #include "minddata/dataset/engine/cache/cache_common.h" #include "minddata/dataset/engine/cache/cache_ipc.h" +#include "minddata/dataset/include/dataset/constants.h" #include "mindspore/core/utils/log_adapter.h" namespace ms = mindspore; namespace ds = mindspore::dataset; @@ -32,19 +33,30 @@ namespace ds = mindspore::dataset; ms::Status StartServer(int argc, char **argv) { ms::Status rc; ds::CacheServer::Builder builder; - if (argc != 8) { + const int32_t kTotalArgs = 8; + enum { + kProcessNameIdx = 0, + kRootDirArgIdx = 1, + kNumWorkersArgIdx = 2, + kPortArgIdx = 3, + kSharedMemorySizeArgIdx = 4, + kLogLevelArgIdx = 5, + kDemonizeArgIdx = 6, + kMemoryCapRatioArgIdx = 7 + }; + if (argc != kTotalArgs) { return ms::Status(ms::StatusCode::kMDSyntaxError); } - int32_t port = strtol(argv[3], nullptr, 10); - builder.SetRootDirectory(argv[1]) - .SetNumWorkers(strtol(argv[2], nullptr, 10)) + int32_t port = static_cast(strtol(argv[kPortArgIdx], nullptr, ds::kDecimal)); + builder.SetRootDirectory(argv[kRootDirArgIdx]) + .SetNumWorkers(static_cast(strtol(argv[kNumWorkersArgIdx], nullptr, ds::kDecimal))) .SetPort(port) - .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)) - .SetLogLevel(strtol(argv[5], nullptr, 10)) - .SetMemoryCapRatio(strtof(argv[7], nullptr)); + .SetSharedMemorySizeInGB(static_cast(strtol(argv[kSharedMemorySizeArgIdx], nullptr, 
ds::kDecimal))) + .SetLogLevel(static_cast((strtol(argv[kLogLevelArgIdx], nullptr, ds::kDecimal)))) + .SetMemoryCapRatio(strtof(argv[kMemoryCapRatioArgIdx], nullptr)); - auto daemonize_string = argv[6]; + auto daemonize_string = argv[kDemonizeArgIdx]; bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 || strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0; @@ -68,8 +80,9 @@ ms::Status StartServer(int argc, char **argv) { if (rc.IsError()) { return rc; } - ms::g_ms_submodule_log_levels[SUBMODULE_ID] = strtol(argv[5], nullptr, 10); - google::InitGoogleLogging(argv[0]); + ms::g_ms_submodule_log_levels[SUBMODULE_ID] = + static_cast(strtol(argv[kLogLevelArgIdx], nullptr, ds::kDecimal)); + google::InitGoogleLogging(argv[kProcessNameIdx]); #undef google #endif rc = msg.Create(); @@ -94,9 +107,8 @@ ms::Status StartServer(int argc, char **argv) { } if (child_rc.IsError()) { return child_rc; - } else { - warning_string = child_rc.ToString(); } + warning_string = child_rc.ToString(); std::cout << "Cache server startup completed successfully!\n"; std::cout << "The cache server daemon has been created as process id " << pid << " and listening on port " << port << ".\n"; @@ -116,9 +128,9 @@ ms::Status StartServer(int argc, char **argv) { std::string errMsg = "Failed to setsid(). 
Errno = " + std::to_string(errno); return ms::Status(ms::StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg); } - close(0); - close(1); - close(2); + (void)close(STDIN_FILENO); + (void)close(STDOUT_FILENO); + (void)close(STDERR_FILENO); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc index 70f7983f48c..daa2ddbad69 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -69,8 +69,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te WritableSlice all(p, sz_); auto offset = fbb->GetSize(); ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize()); - Status copy_rc; - copy_rc = WritableSlice::Copy(&all, header); + Status copy_rc = WritableSlice::Copy(&all, header); if (copy_rc.IsOk()) { for (const auto &ts : row) { WritableSlice row_data(all, offset, ts->SizeInBytes()); @@ -108,7 +107,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te Status CacheRowRequest::PostReply() { if (!reply_.result().empty()) { - row_id_from_server_ = strtoll(reply_.result().data(), nullptr, 10); + row_id_from_server_ = strtoll(reply_.result().data(), nullptr, kDecimal); } return Status::OK(); } @@ -116,11 +115,13 @@ Status CacheRowRequest::PostReply() { Status CacheRowRequest::Prepare() { if (BitTest(rq_.flag(), kDataIsInSharedMemory)) { // First one is cookie, followed by address and then size. - CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == 3, "Incomplete rpc data"); + constexpr int32_t kExpectedBufDataSize = 3; + CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == kExpectedBufDataSize, "Incomplete rpc data"); } else { // First one is cookie. 2nd one is the google flat buffers followed by a number of buffers. // But we are not going to decode them to verify. 
- CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= 3, "Incomplete rpc data"); + constexpr int32_t kMinBufDataSize = 3; + CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= kMinBufDataSize, "Incomplete rpc data"); } return Status::OK(); } @@ -161,7 +162,7 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, in auto flag = reply_.flag(); bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false; if (dataOnSharedMemory) { - auto addr = strtoll(reply_.result().data(), nullptr, 10); + auto addr = strtoll(reply_.result().data(), nullptr, kDecimal); ptr = reinterpret_cast(reinterpret_cast(baseAddr) + addr); RETURN_UNEXPECTED_IF_NULL(out); *out_addr = addr; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h index fa961dff184..7747c7be4a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -446,7 +446,7 @@ class AllocateSharedBlockRequest : public BaseRequest { /// the free block is located. /// \return int64_t GetAddr() { - auto addr = strtoll(reply_.result().data(), nullptr, 10); + auto addr = strtoll(reply_.result().data(), nullptr, kDecimal); return addr; } }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc index 23fc7b7da7a..1b26aab4f2e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc @@ -74,7 +74,7 @@ Status CacheServer::DoServiceStart() { } catch (const std::exception &e) { RETURN_STATUS_UNEXPECTED(e.what()); } -#if CACHE_LOCAL_CLIENT +#ifdef CACHE_LOCAL_CLIENT RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_)); // Bring up a thread to monitor the unix socket in case it is removed. 
But it must be done // after we have created the unix socket. @@ -173,7 +173,7 @@ Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) { } else if (req_mem == 0) { // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than // 85% of our limit, fail this request. - if (static_cast(max_avail) / static_cast(avail_mem) <= 0.15) { + if (static_cast(max_avail) / static_cast(avail_mem) <= kMemoryBottomLineForNewService) { return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); } } @@ -331,14 +331,17 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { CacheService *cs = GetService(connection_id); auto *base = SharedMemoryBaseAddr(); // Ensure we got 3 pieces of data coming in - CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= 3, "Incomplete data"); + constexpr int32_t kMinBufDataSize = 3; + CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data"); + // First one is cookie, followed by data address and then size. + enum { kCookieIdx = 0, kAddrIdx = 1, kSizeIdx = 2 }; // First piece of data is the cookie and is required - auto &cookie = rq->buf_data(0); + auto &cookie = rq->buf_data(kCookieIdx); // Second piece of data is the address where we can find the serialized data - auto addr = strtoll(rq->buf_data(1).data(), nullptr, 10); + auto addr = strtoll(rq->buf_data(kAddrIdx).data(), nullptr, kDecimal); auto p = reinterpret_cast(reinterpret_cast(base) + addr); // Third piece of data is the size of the serialized data that we need to transfer - auto sz = strtoll(rq->buf_data(2).data(), nullptr, 10); + auto sz = strtoll(rq->buf_data(kSizeIdx).data(), nullptr, kDecimal); // Successful or not, we need to free the memory on exit. Status rc; if (cs == nullptr) { @@ -380,7 +383,9 @@ Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) { // This is an internal request and is not tied to rpc. 
But need to post because there // is a thread waiting on the completion of this request. try { - int64_t addr = strtol(rq->buf_data(3).data(), nullptr, 10); + constexpr int32_t kBatchWaitIdx = 3; + // Fourth piece of the data is the address of the BatchWait ptr + int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal); auto *bw = reinterpret_cast(addr); // Check if the object is still around. auto bwObj = bw->GetBatchWait(); @@ -405,11 +410,13 @@ Status CacheServer::InternalFetchRow(CacheRequest *rq) { std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg); } - rc = cs->InternalFetchRow(flatbuffers::GetRoot(rq->buf_data(0).data())); + // First piece is a flatbuffer containing row fetch information, second piece is the address of the BatchWait ptr + enum { kFetchRowMsgIdx = 0, kBatchWaitIdx = 1 }; + rc = cs->InternalFetchRow(flatbuffers::GetRoot(rq->buf_data(kFetchRowMsgIdx).data())); // This is an internal request and is not tied to rpc. But need to post because there // is a thread waiting on the completion of this request. try { - int64_t addr = strtol(rq->buf_data(1).data(), nullptr, 10); + int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal); auto *bw = reinterpret_cast(addr); // Check if the object is still around. auto bwObj = bw->GetBatchWait(); @@ -747,17 +754,20 @@ Status CacheServer::ConnectReset(CacheRequest *rq) { } Status CacheServer::BatchCacheRows(CacheRequest *rq) { - CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == 3, "Expect three pieces of data"); + // First one is cookie, followed by address and then size. 
+ enum { kCookieIdx = 0, kAddrIdx = 1, kSizeIdx = 2 }; + constexpr int32_t kExpectedBufDataSize = 3; + CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == kExpectedBufDataSize, "Expect three pieces of data"); try { - auto &cookie = rq->buf_data(0); + auto &cookie = rq->buf_data(kCookieIdx); auto connection_id = rq->connection_id(); auto client_id = rq->client_id(); int64_t offset_addr; int32_t num_elem; auto *base = SharedMemoryBaseAddr(); - offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10); + offset_addr = strtoll(rq->buf_data(kAddrIdx).data(), nullptr, kDecimal); auto p = reinterpret_cast(reinterpret_cast(base) + offset_addr); - num_elem = strtol(rq->buf_data(2).data(), nullptr, 10); + num_elem = static_cast(strtol(rq->buf_data(kSizeIdx).data(), nullptr, kDecimal)); auto batch_wait = std::make_shared(num_elem); // Get a set of free request and push into the queues. for (auto i = 0; i < num_elem; ++i) { @@ -1105,7 +1115,7 @@ Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) { auto client_id = rq->client_id(); CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set"); try { - auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10); + auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, kDecimal); void *p = nullptr; RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p)); auto *base = SharedMemoryBaseAddr(); @@ -1124,7 +1134,7 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) { CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set"); auto *base = SharedMemoryBaseAddr(); try { - auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10); + auto addr = strtoll(rq->buf_data(0).data(), nullptr, kDecimal); auto p = reinterpret_cast(reinterpret_cast(base) + addr); DeallocateSharedMemory(client_id, p); } catch (const std::exception &e) { @@ -1291,11 +1301,11 @@ int32_t CacheServer::Builder::AdjustNumWorkers(int32_t num_workers) { CacheServer::Builder::Builder() : top_(""), - 
num_workers_(std::thread::hardware_concurrency() / 2), - port_(50052), + num_workers_(kDefaultNumWorkers), + port_(kCfgDefaultCachePort), shared_memory_sz_in_gb_(kDefaultSharedMemorySize), memory_cap_ratio_(kDefaultMemoryCapRatio), - log_level_(1) { + log_level_(kDefaultLogLevel) { if (num_workers_ == 0) { num_workers_ = 1; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h index 3b251681ae3..cdfcde71a8a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h @@ -57,6 +57,8 @@ class CacheServer : public Service { public: friend class Services; using cache_index = std::map>; + // Only allow new service to be created if left memory is more than 15% of our hard memory cap + constexpr static float kMemoryBottomLineForNewService = 0.15; class Builder { public: diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc index 790832e36ff..b4065677519 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc @@ -90,8 +90,9 @@ Status CacheService::CacheRow(const std::vector &buf, row_id_type if (generate_id_) { *row_id_generated = GetNextRowId(); // Some debug information on how many rows we have generated so far. 
- if ((*row_id_generated) % 1000 == 0) { - MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1; + constexpr int32_t kDisplayInterval = 1000; + if ((*row_id_generated) % kDisplayInterval == 0) { + MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1); } } else { if (msg->row_id() < 0) { @@ -159,8 +160,9 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ if (generate_id_) { *row_id_generated = GetNextRowId(); // Some debug information on how many rows we have generated so far. - if ((*row_id_generated) % 1000 == 0) { - MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1; + constexpr int32_t kDisplayInterval = 1000; + if ((*row_id_generated) % kDisplayInterval == 0) { + MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1); } } else { auto msg = GetTensorRowHeaderMsg(src.GetPointer()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc index 97edf22cba6..b6219ce3f6b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc @@ -15,7 +15,6 @@ */ #include "minddata/dataset/engine/cache/perf/cache_perf_run.h" -#include #include #include #include @@ -24,6 +23,7 @@ #include #include #include +#include #include #include #include "minddata/dataset/util/random.h" @@ -344,11 +344,11 @@ void CachePerfRun::PrintEpochSummary() const { << std::setw(14) << "buffer count" << std::setw(18) << "Elapsed time (s)" << std::endl; for (auto &it : epoch_results_) { auto epoch_worker_summary = it.second; - std::cout << std::setw(12) << epoch_worker_summary.pipeline() + 1 << std::setw(10) << epoch_worker_summary.worker() - << std::setw(10) << epoch_worker_summary.min() << std::setw(10) << epoch_worker_summary.max() - << std::setw(10) << epoch_worker_summary.avg() << std::setw(13) << 
epoch_worker_summary.med() - << std::setw(14) << epoch_worker_summary.cnt() << std::setw(18) << epoch_worker_summary.elapse() - << std::endl; + std::cout << std::setw(12) << (epoch_worker_summary.pipeline() + 1) << std::setw(10) + << epoch_worker_summary.worker() << std::setw(10) << epoch_worker_summary.min() << std::setw(10) + << epoch_worker_summary.max() << std::setw(10) << epoch_worker_summary.avg() << std::setw(13) + << epoch_worker_summary.med() << std::setw(14) << epoch_worker_summary.cnt() << std::setw(18) + << epoch_worker_summary.elapse() << std::endl; } } @@ -463,7 +463,7 @@ Status CachePerfRun::StartPipelines() { // Call _exit instead of exit because we will hang TaskManager destructor for a forked child process. _exit(-1); } else if (pid > 0) { - std::cout << "Pipeline number " << i + 1 << " has been created with process id: " << pid << std::endl; + std::cout << "Pipeline number " << (i + 1) << " has been created with process id: " << pid << std::endl; pid_lists_.push_back(pid); } else { std::string errMsg = "Failed to fork process for cache pipeline: " + std::to_string(errno); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc index 6780e27b038..ffca56103a6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc @@ -15,7 +15,7 @@ */ #include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h" -#include +#include #include "mindspore/core/utils/log_adapter.h" namespace ms = mindspore; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc index 245a08dd70b..1316eb5f2ca 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc @@ -15,10 
+15,10 @@ */ #include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h" -#include #include #include #include +#include #include #include #include "minddata/dataset/core/tensor.h" @@ -182,7 +182,7 @@ Status CachePipelineRun::Run() { } // Log a warning level message so we can see it in the log file when it starts. - MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " successfully creating cache service." << std::endl; + MS_LOG(WARNING) << "Pipeline number " << (my_pipeline_ + 1) << " successfully creating cache service." << std::endl; // Spawn a thread to listen to the parent process RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(&CachePipelineRun::ListenToParent, this))); @@ -213,7 +213,7 @@ Status CachePipelineRun::RunFirstEpoch() { if (my_pipeline_ + 1 == num_pipelines_) { end_row_ = num_rows_ - 1; } - std::cout << "Pipeline number " << my_pipeline_ + 1 << " row id range: [" << start_row_ << "," << end_row_ << "]" + std::cout << "Pipeline number " << (my_pipeline_ + 1) << " row id range: [" << start_row_ << "," << end_row_ << "]" << std::endl; // Spawn the worker threads. 
@@ -305,7 +305,7 @@ Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) { auto end_tick = std::chrono::steady_clock::now(); if (rc.IsError()) { if (rc == StatusCode::kMDOutOfMemory || rc == StatusCode::kMDNoSpace) { - MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " worker id " << worker_id << ": " + MS_LOG(WARNING) << "Pipeline number " << (my_pipeline_ + 1) << " worker id " << worker_id << ": " << rc.ToString(); resource_err = true; cc_->ServerRunningOutOfResources(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc index 941dfb225d7..7f38d4f5d98 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc @@ -55,7 +55,7 @@ Status StorageManager::AddOneContainer(int replaced_container_pos) { } Status StorageManager::DoServiceStart() { - containers_.reserve(1000); + containers_.reserve(kMaxNumContainers); writable_containers_pool_.reserve(pool_size_); if (root_.IsDirectory()) { // create multiple containers and store their index in a pool diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h index a66eef92738..bd9fccb0179 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h @@ -47,6 +47,7 @@ class StorageManager : public Service { using value_type = std::pair>; using storage_index = AutoIndexObj, StorageBPlusTreeTraits>; using key_type = storage_index::key_type; + constexpr static int32_t kMaxNumContainers = 1000; explicit StorageManager(const Path &); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h index fcd992797da..250dc6394d8 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h @@ -35,6 +35,8 @@ class CacheClientGreeter : public Service { void *SharedMemoryBaseAddr() { return nullptr; } Status HandleRequest(std::shared_ptr rq) { RETURN_STATUS_UNEXPECTED("Not supported"); } Status AttachToSharedMemory(bool *local_bypass) { RETURN_STATUS_UNEXPECTED("Not supported"); } + std::string GetHostname() const { return "Not supported"; } + int32_t GetPort() const { return 0; } protected: private: diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index 30898ea27f5..45b061c320a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -115,8 +115,8 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { MS_LOG(DEBUG) << "Ignore eoe"; // However we need to flush any left over from the async write buffer. But any error // we are getting will just to stop caching but the pipeline will continue - Status rc; - if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) { + Status rc = cache_client_->FlushAsyncWriteBuffer(); + if (rc.IsError()) { cache_missing_rows_ = false; if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) { cache_client_->ServerRunningOutOfResources(); @@ -138,8 +138,7 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { // We will send the request async. But any error we most // likely ignore and continue. 
- Status rc; - rc = rq->AsyncSendCacheRequest(cache_client_, new_row); + Status rc = rq->AsyncSendCacheRequest(cache_client_, new_row); if (rc.IsOk()) { RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); } else if (rc == StatusCode::kMDOutOfMemory || rc == kMDNoSpace) { @@ -192,7 +191,7 @@ Status CacheMergeOp::Cleaner() { Status CacheMergeOp::PrepareOperator() { // Run any common code from super class first before adding our own // specific logic - CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); + CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == kNumChildren, "Incorrect number of children"); RETURN_IF_NOT_OK(DatasetOp::PrepareOperator()); // Get the computed check sum from all ops in the cache miss class uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]); @@ -287,8 +286,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha auto expected = State::kEmpty; if (st_.compare_exchange_strong(expected, State::kDirty)) { // We will do a deep copy but write directly into CacheRequest protobuf or shared memory - Status rc; - rc = cc->AsyncWriteRow(row); + Status rc = cc->AsyncWriteRow(row); if (rc.StatusCode() == StatusCode::kMDNotImplementedYet) { cleaner_copy_ = std::make_shared(cc.get()); rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h index 42f33dd1eca..6054af14ff2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h @@ -68,6 +68,7 @@ class CacheMergeOp : public ParallelOp { std::shared_ptr cleaner_copy_; }; + constexpr static int kNumChildren = 2; // CacheMergeOp has 2 children constexpr static int kCacheHitChildIdx = 0; // Cache hit stream constexpr static int kCacheMissChildIdx = 1; // Cache miss stream diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index 76ede843ddc..97220e8b84e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -161,7 +161,7 @@ Status CacheOp::WaitForCachingAllRows() { case CacheServiceState::kBuildPhase: // Do nothing. Continue to wait. BuildPhaseDone = false; - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(std::chrono::milliseconds(kPhaseCheckIntervalInMilliSec)); break; case CacheServiceState::kFetchPhase: BuildPhaseDone = true; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h index 3017b2ffb5a..3d85fe6ea77 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h @@ -34,6 +34,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { // assigns row id). No read access in the first phase. Once the cache is fully built, // we switch to second phase and fetch requests from the sampler. enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 }; + constexpr static int32_t kPhaseCheckIntervalInMilliSec = 100; /// \brief The nested builder class inside of the CacheOp is used to help manage all of /// the arguments for constructing it. 
Use the builder by setting each argument diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc index 00a4c5cf55b..66225df681a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc @@ -82,8 +82,7 @@ Status RenameOp::ComputeColMap() { std::string name = pair.first; int32_t id = pair.second; // find name - std::vector::iterator it; - it = std::find(in_columns_.begin(), in_columns_.end(), name); + std::vector::iterator it = std::find(in_columns_.begin(), in_columns_.end(), name); // for c input checks here we have to count the number of times we find the stuff in in_columns_ // because we iterate over the mInputList n times if (it != in_columns_.end()) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc index 86301a03b89..7b8c0203306 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc @@ -71,5 +71,18 @@ Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr return Status::OK(); } + +Status DatasetCacheImpl::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["session_id"] = session_id_; + args["cache_memory_size"] = cache_mem_sz_; + args["spill"] = spill_; + if (hostname_) args["hostname"] = hostname_.value(); + if (port_) args["port"] = port_.value(); + if (num_connections_) args["num_connections"] = num_connections_.value(); + if (prefetch_sz_) args["prefetch_size"] = prefetch_sz_.value(); + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h 
index f4f7c7c2450..287d026ef17 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h @@ -32,6 +32,7 @@ namespace dataset { /// DatasetCache is the IR of CacheClient class DatasetCacheImpl : public DatasetCache { public: + friend class PreBuiltDatasetCache; /// /// \brief Constructor /// \param id A user assigned session id for the current pipeline. @@ -66,6 +67,8 @@ class DatasetCacheImpl : public DatasetCache { Status ValidateParams() override { return Status::OK(); } + Status to_json(nlohmann::json *out_json) override; + ~DatasetCacheImpl() = default; private: diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc index 87a7eb889e7..26781e9861e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc @@ -28,52 +28,5 @@ Status PreBuiltDatasetCache::Build() { // we actually want to keep a reference of the runtime object so it can be shared by different pipelines return Status::OK(); } - -Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr *const ds) { - CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op)); - *ds = cache_op; - - return Status::OK(); -} - -Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) { - nlohmann::json args; - args["session_id"] = cache_client_->session_id(); - args["cache_memory_size"] = cache_client_->GetCacheMemSz(); - args["spill"] = cache_client_->isSpill(); - args["num_connections"] = cache_client_->GetNumConnections(); - args["prefetch_size"] = cache_client_->GetPrefetchSize(); - 
*out_json = args; - return Status::OK(); -} - -Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, - std::shared_ptr sampler) { - CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); - std::shared_ptr lookup_op = nullptr; - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt)); - - RETURN_IF_NOT_OK(CacheLookupOp::Builder() - .SetNumWorkers(num_workers) - .SetClient(cache_client_) - .SetSampler(sampler_rt) - .Build(&lookup_op)); - *ds = lookup_op; - - return Status::OK(); -} - -Status PreBuiltDatasetCache::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) { - CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); - std::shared_ptr merge_op = nullptr; - RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op)); - *ds = merge_op; - - return Status::OK(); -} - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h index 83faa7e37c7..bf48efb0001 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h @@ -21,37 +21,27 @@ #include #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/datasetops/cache_op.h" -#include "minddata/dataset/engine/ir/cache/dataset_cache.h" +#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" namespace mindspore { namespace dataset { /// DatasetCache is the IR of CacheClient -class PreBuiltDatasetCache : public DatasetCache { +class PreBuiltDatasetCache : public DatasetCacheImpl { public: /// \brief Constructor /// \param cc a 
pre-built cache client - explicit PreBuiltDatasetCache(std::shared_ptr cc) : cache_client_(std::move(cc)) {} + explicit PreBuiltDatasetCache(std::shared_ptr cc) + : DatasetCacheImpl(cc->session_id(), cc->GetCacheMemSz(), cc->isSpill(), StringToChar(cc->GetHostname()), + cc->GetPort(), cc->GetNumConnections(), cc->GetPrefetchSize()) { + cache_client_ = std::move(cc); + } ~PreBuiltDatasetCache() = default; /// Method to initialize the DatasetCache by creating an instance of a CacheClient /// \return Status Error code Status Build() override; - - Status CreateCacheOp(int32_t num_workers, std::shared_ptr *const ds) override; - - Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, - std::shared_ptr sampler) override; - - Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) override; - - Status ValidateParams() override { return Status::OK(); } - - Status to_json(nlohmann::json *out_json) override; - - private: - std::shared_ptr cache_client_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc index 75245fe99fe..0a73e399729 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc @@ -202,12 +202,11 @@ Status CacheTransformPass::RunOnTree(std::shared_ptr root_ir, bool } // Helper function to execute mappable cache transformation. 
-// Input: +// Input tree: // Sampler // | // LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) -// -// Transformed: +// Transformed tree: // Sampler --> CacheLookupNode -------------------------> // | | // | CacheMergeNode @@ -232,10 +231,9 @@ Status CacheTransformPass::InjectMappableCacheNode(std::shared_ptr OtherNodes --> CachedNode (cache_ = DatasetCache) -// -// Transformed: +// Transformed tree: // Sampler // | // LeafNode --> OtherNodes --> CachedNode --> CacheNode diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h index a785042fd0e..20707862bfc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h @@ -133,12 +133,12 @@ class CacheTransformPass : public IRTreePass { private: /// \brief Helper function to execute mappable cache transformation. /// - /// Input: + /// Input tree: /// Sampler /// | /// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) /// - /// Transformed: + /// Transformed tree: /// Sampler --> CacheLookupNode -------------------------> /// | | /// | CacheMergeNode @@ -153,10 +153,10 @@ class CacheTransformPass : public IRTreePass { /// \brief Helper function to execute non-mappable cache transformation. 
/// - /// Input: + /// Input tree: /// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) /// - /// Transformed: + /// Transformed tree: /// Sampler /// | /// LeafNode --> OtherNodes --> CachedNode --> CacheNode diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h b/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h index 64c7cfcbd0e..a6176769ec5 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h @@ -110,6 +110,9 @@ constexpr int32_t kDftPrefetchSize = 20; constexpr int32_t kDftNumConnections = 12; constexpr int32_t kDftAutoNumWorkers = false; constexpr char kDftMetaColumnPrefix[] = "_meta-"; +constexpr int32_t kDecimal = 10; // used in strtol() to convert a string value according to decimal numeral system +constexpr int32_t kMinLegalPort = 1025; +constexpr int32_t kMaxLegalPort = 65535; // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) constexpr uint8_t kCVInvalidType = 255; diff --git a/mindspore/dataset/engine/cache_admin.py b/mindspore/dataset/engine/cache_admin.py index b2b8aa51af5..2eac94005dd 100644 --- a/mindspore/dataset/engine/cache_admin.py +++ b/mindspore/dataset/engine/cache_admin.py @@ -30,5 +30,4 @@ def main(): cache_server = os.path.join(cache_admin_dir, "cache_server") os.chmod(cache_admin, stat.S_IRWXU) os.chmod(cache_server, stat.S_IRWXU) - cmd = cache_admin + " " + " ".join(sys.argv[1:]) - sys.exit(subprocess.call(cmd, shell=True)) + sys.exit(subprocess.call([cache_admin] + sys.argv[1:], shell=False)) diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py index c5cb54c9f01..5d031376836 100644 --- a/mindspore/dataset/engine/cache_client.py +++ b/mindspore/dataset/engine/cache_client.py @@ -26,8 +26,8 @@ class DatasetCache: """ A client to interface with tensor caching service. 
- For details, please check `Chinese tutorial `_, - `Chinese programming guide `_. + For details, please check `Tutorial `_, `Programming guide `_. Args: session_id (int): A user assigned session id for the current pipeline. @@ -76,7 +76,8 @@ class DatasetCache: self.num_connections = num_connections self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size) - def GetStat(self): + def get_stat(self): + """Get the statistics from a cache.""" return self.cache_client.GetStat() def __deepcopy__(self, memodict): diff --git a/mindspore/dataset/transforms/py_transforms.py b/mindspore/dataset/transforms/py_transforms.py index 575ea07fd55..ec96ca3f981 100644 --- a/mindspore/dataset/transforms/py_transforms.py +++ b/mindspore/dataset/transforms/py_transforms.py @@ -22,6 +22,10 @@ from . import py_transforms_util as util from .c_transforms import TensorOperation def not_random(function): + """ + Specify the function as "not random", i.e., it produces a deterministic result. + A Python function can only be cached after it is specified as "not random". + """ function.random = False return function diff --git a/mindspore/dataset/vision/py_transforms.py b/mindspore/dataset/vision/py_transforms.py index f50e1116ecb..d5ba44e934c 100644 --- a/mindspore/dataset/vision/py_transforms.py +++ b/mindspore/dataset/vision/py_transforms.py @@ -47,6 +47,10 @@ DE_PY_BORDER_TYPE = {Border.CONSTANT: 'constant', def not_random(function): + """ + Specify the function as "not random", i.e., it produces a deterministic result. + A Python function can only be cached after it is specified as "not random". 
+ """ function.random = False return function diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 197051b5d71..56b9241bc45 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -2221,7 +2221,7 @@ def test_cache_map_interrupt_and_rerun(): assert num_iter == 10000 epoch_count += 1 - cache_stat = some_cache.GetStat() + cache_stat = some_cache.get_stat() assert cache_stat.num_mem_cached == 10000 logger.info("test_cache_map_interrupt_and_rerun Ended.\n") diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index 0e2375ccad9..9096a278f51 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -153,7 +153,7 @@ def test_cache_nomap_basic3(): assert num_iter == 12 # Contact the server to get the statistics - stat = some_cache.GetStat() + stat = some_cache.get_stat() cache_sz = stat.avg_cache_sz num_mem_cached = stat.num_mem_cached num_disk_cached = stat.num_disk_cached @@ -366,7 +366,7 @@ def test_cache_nomap_basic8(): @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_basic9(): """ - Testing the GetStat interface for getting some info from server, but this should fail if the cache is not created + Testing the get_stat interface for getting some info from server, but this should fail if the cache is not created in a pipeline. """ @@ -381,7 +381,7 @@ def test_cache_nomap_basic9(): # Contact the server to get the statistics, this should fail because we have not used this cache in any pipeline # so there will not be any cache to get stats on. 
with pytest.raises(RuntimeError) as e: - stat = some_cache.GetStat() + stat = some_cache.get_stat() cache_sz = stat.avg_cache_sz logger.info("Average row cache size: {}".format(cache_sz)) assert "Unexpected error" in str(e.value) @@ -1239,7 +1239,7 @@ def test_cache_nomap_interrupt_and_rerun(): assert num_iter == 10000 epoch_count += 1 - cache_stat = some_cache.GetStat() + cache_stat = some_cache.get_stat() assert cache_stat.num_mem_cached == 10000 logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n") @@ -2349,7 +2349,7 @@ def test_cache_nomap_all_rows_cached(): logger.info("Number of data in ds1: {} ".format(num_iter)) assert num_iter == num_total_rows - cache_stat = some_cache.GetStat() + cache_stat = some_cache.get_stat() assert cache_stat.num_mem_cached == num_total_rows logger.info("test_cache_nomap_all_rows_cached Ended.\n")