!11559 [MD] Change the default spill location of cache server

From: @lixiachen
mindspore-ci-bot 2021-01-27 03:12:24 +08:00 committed by Gitee
commit 3708624a25
12 changed files with 454 additions and 292 deletions
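
In effect, the cache server no longer creates or assumes a default spill directory: spilling is enabled only when a directory is passed explicitly at server start (e.g. cache_admin --start -s /tmp, as the updated test scripts below do), and a client that requests spilling against a memory-only server now fails. A minimal client-side sketch of the new contract, assuming a running server and a valid session id (session_id=1 is illustrative):

    import mindspore.dataset as ds

    # New default: a memory-only cache; no spill directory is assumed
    mem_only_cache = ds.DatasetCache(session_id=1, size=0)

    # Requesting spilling now only works if the server was started with -s; otherwise
    # iterating the pipeline fails with "Server is not set up with spill support."
    spill_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)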

View File

@ -42,14 +42,6 @@ int main(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
#endif
// Create default spilling dir
ds::Path spill_dir = ds::Path(ds::DefaultSpillDir());
rc = spill_dir.CreateDirectories();
if (!rc.IsOk()) {
std::cerr << rc.ToString() << std::endl;
return 1;
}
if (argc == 1) {
args.Help();
return 0;

View File

@ -48,7 +48,7 @@ CacheAdminArgHandler::CacheAdminArgHandler()
log_level_(kDefaultLogLevel),
memory_cap_ratio_(kMemoryCapRatio),
hostname_(kCfgDefaultCacheHost),
spill_dir_(DefaultSpillDir()),
spill_dir_(""),
command_id_(CommandId::kCmdUnknown) {
// Initialize the command mappings
arg_map_["-h"] = ArgValue::kArgHost;
@ -334,32 +334,7 @@ Status CacheAdminArgHandler::RunCommand() {
break;
}
case CommandId::kCmdStop: {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
SharedMessage msg;
RETURN_IF_NOT_OK(msg.Create());
auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId());
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
Status rc = rq->Wait();
if (rc.IsError()) {
msg.RemoveResourcesOnExit();
if (rc.IsNetWorkError()) {
std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already.";
return Status(StatusCode::kNetWorkError, errMsg);
}
return rc;
}
// An OK return code only means the server acknowledged our request, but we still
// have to wait for its complete shutdown because the server will shut down
// the comm layer as soon as the request is received, and we need to wait
// on the message queue instead.
// The server will send a message back and remove the queue, and we will then wake up. But to be on
// the safe side, we will also set up an alarm and kill this process if we hang on
// the message queue.
alarm(60);
Status dummy_rc;
(void)msg.ReceiveStatus(&dummy_rc);
std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl;
RETURN_IF_NOT_OK(StopServer(command_id_));
break;
}
case CommandId::kCmdGenerateSession: {
@ -430,6 +405,36 @@ Status CacheAdminArgHandler::RunCommand() {
return Status::OK();
}
Status CacheAdminArgHandler::StopServer(CommandId command_id) {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
SharedMessage msg;
RETURN_IF_NOT_OK(msg.Create());
auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId());
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
Status rc = rq->Wait();
if (rc.IsError()) {
msg.RemoveResourcesOnExit();
if (rc.IsNetWorkError()) {
std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already.";
return Status(StatusCode::kNetWorkError, errMsg);
}
return rc;
}
// An OK return code only means the server acknowledged our request, but we still
// have to wait for its complete shutdown because the server will shut down
// the comm layer as soon as the request is received, and we need to wait
// on the message queue instead.
// The server will send a message back and remove the queue, and we will then wake up. But to be on
// the safe side, we will also set up an alarm and kill this process if we hang on
// the message queue.
alarm(60);
Status dummy_rc;
(void)msg.ReceiveStatus(&dummy_rc);
std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl;
return Status::OK();
}
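
The shutdown logic factored out into StopServer follows a fixed shape: send the stop request, treat the rpc reply as a mere acknowledgement, then block on the message queue with an alarm as a safety net. A hedged Python analogue of that shape, with queue.Queue and a timeout standing in for the SysV message queue and alarm(60):

    import queue

    def stop_server(send_stop_request, done_queue, timeout_s=60):
        send_stop_request()  # server acks, then tears down its comm layer
        try:
            # wake up when the server posts back and removes the queue
            done_queue.get(timeout=timeout_s)
        except queue.Empty:
            raise TimeoutError("hung waiting for server shutdown")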
Status CacheAdminArgHandler::StartServer(CommandId command_id) {
// There is currently no "install path" or method to identify the path in which the installed binaries will
// exist. As a temporary approach, we will assume that the server binary shall exist in the same path as the
@ -462,7 +467,6 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) {
// fork the child process to become the daemon
pid_t pid;
pid = fork();
// failed to fork
if (pid < 0) {
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
@ -538,7 +542,7 @@ void CacheAdminArgHandler::Help() {
std::cerr << " [[-h | --hostname] <hostname>] Default is " << kCfgDefaultCacheHost << ".\n";
std::cerr << " [[-p | --port] <port number>] Default is " << kCfgDefaultCachePort << ".\n";
std::cerr << " [[-w | --workers] <number of workers>] Default is " << kDefaultNumWorkers << ".\n";
std::cerr << " [[-s | --spilldir] <spilling directory>] Default is " << DefaultSpillDir() << ".\n";
std::cerr << " [[-s | --spilldir] <spilling directory>] Default is no spilling.\n";
std::cerr << " [[-l | --loglevel] <log level>] Default is 1 (warning level).\n";
std::cerr << " [--destroy_session | -d] <session id>\n";
std::cerr << " [[-p | --port] <port number>]\n";

View File

@ -79,6 +79,8 @@ class CacheAdminArgHandler {
Status StartServer(CommandId command_id);
Status StopServer(CommandId command_id);
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);

View File

@ -91,9 +91,6 @@ using worker_id_t = int32_t;
using numa_id_t = int32_t;
using cpu_id_t = int32_t;
/// Return the default spill dir for cache
inline std::string DefaultSpillDir() { return kDefaultPathPrefix; }
/// Return the default log dir for cache
inline std::string DefaultLogDir() { return kDefaultPathPrefix + std::string("/log"); }

View File

@ -125,6 +125,33 @@ class BaseRequest {
/// \return Status object
Status Wait();
/// \brief Return whether the request is of row request type
/// \return True if the request is a row-related request
bool IsRowRequest() const {
return type_ == RequestType::kBatchCacheRows || type_ == RequestType::kBatchFetchRows ||
type_ == RequestType::kInternalCacheRow || type_ == RequestType::kInternalFetchRow ||
type_ == RequestType::kCacheRow;
}
/// \brief Return whether the request is of admin request type
/// \return True if the request is an admin-related request
bool IsAdminRequest() const {
return type_ == RequestType::kCreateCache || type_ == RequestType::kDestroyCache ||
type_ == RequestType::kGetStat || type_ == RequestType::kGetCacheState ||
type_ == RequestType::kAllocateSharedBlock || type_ == RequestType::kFreeSharedBlock ||
type_ == RequestType::kCacheSchema || type_ == RequestType::kFetchSchema ||
type_ == RequestType::kBuildPhaseDone || type_ == RequestType::kToggleWriteMode ||
type_ == RequestType::kConnectReset || type_ == RequestType::kStopService ||
type_ == RequestType::kHeartBeat || type_ == RequestType::kGetCacheMissKeys;
}
/// \brief Return whether the request is of session request type
/// \return True if the request is a session-related request
bool IsSessionRequest() const {
return type_ == RequestType::kGenerateSessionId || type_ == RequestType::kDropSession ||
type_ == RequestType::kListSessions;
}
protected:
CacheRequest rq_; // This is what we send to the server
CacheReply reply_; // This is what the server sends back

View File

@ -155,6 +155,42 @@ CacheService *CacheServer::GetService(connection_id_type id) const {
return nullptr;
}
// We would like to protect ourselves from over-allocating. We will go over the existing caches
// and calculate how much memory we have consumed so far.
Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
auto end = all_caches_.end();
auto it = all_caches_.begin();
auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
int64_t max_avail = avail_mem;
while (it != end) {
auto &cs = it->second;
CacheService::ServiceStat stat;
RETURN_IF_NOT_OK(cs->GetStat(&stat));
int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
max_avail -= mem_consumed;
if (max_avail <= 0) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
++it;
}
// If some caches are already using memory, make a reasonable decision on whether we should return
// out of memory.
if (max_avail < avail_mem) {
int64_t req_mem = cache_mem_sz * 1048576L; // cache_mem_sz is in MB.
if (req_mem > max_avail) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
} else if (req_mem == 0) {
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
// 85% of our limit, fail this request.
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
}
}
return Status::OK();
}
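
The factored-out check is plain arithmetic over the memory cap; a hedged Python mirror of the same logic, with illustrative inputs (1048576 converts the MB-denominated request to bytes):

    def global_memory_check(req_mb, consumed_bytes, total_bytes, cap_ratio):
        avail = total_bytes * cap_ratio          # overall budget under the cap
        max_avail = avail - sum(consumed_bytes)  # budget left after existing caches
        if max_avail <= 0:
            raise MemoryError("Please destroy some sessions")
        if max_avail < avail:                    # some memory is already in use
            req = req_mb * 1048576               # the request is given in MB
            if req > max_avail:
                raise MemoryError("Please destroy some sessions")
            if req == 0 and max_avail / avail <= 0.15:
                # an "unlimited" request while >85% of the budget is already consumed
                raise MemoryError("Please destroy some sessions")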
Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
std::string cookie;
@ -186,55 +222,25 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
if (spill && top_.empty()) {
RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
}
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<flatbuffers::String> off_cookie;
flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
// Before creating the cache, first check if this is a request for shared usage of an existing cache.
// If two CreateService requests come in with identical connection_id, we need to serialize the creates.
// The first create will be successful and will be given a special cookie.
UniqueLock lck(&rwLock_);
bool duplicate = false;
CacheService *curr_cs = GetService(connection_id);
if (curr_cs != nullptr) {
duplicate = true;
client_id = curr_cs->num_clients_.fetch_add(1);
MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " +
std::to_string(connection_id) + " to create cache service";
}
// Early exit if we are doing global shutdown
if (global_shutdown_) {
return Status::OK();
}
// We would like to protect ourselves from over-allocating. We will go over the existing caches
// and calculate how much memory we have consumed so far.
auto end = all_caches_.end();
auto it = all_caches_.begin();
bool duplicate = false;
auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
int64_t max_avail = avail_mem;
while (it != end) {
if (it->first == connection_id) {
duplicate = true;
break;
} else {
auto &cs = it->second;
CacheService::ServiceStat stat;
RETURN_IF_NOT_OK(cs->GetStat(&stat));
int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
max_avail -= mem_consumed;
if (max_avail <= 0) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
}
++it;
}
if (it == end) {
// If some caches are already using memory, make a reasonable decision on whether we should return
// out of memory.
if (max_avail < avail_mem) {
int64_t req_mem = cache_mem_sz * 1048576L; // cache_mem_sz is in MB.
if (req_mem > max_avail) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
} else if (req_mem == 0) {
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
// 85% of our limit, fail this request.
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
}
}
}
if (!duplicate) {
RETURN_IF_NOT_OK(GlobalMemoryCheck(cache_mem_sz));
std::unique_ptr<CacheService> cs;
try {
cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
@ -245,12 +251,8 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
} else {
duplicate = true;
client_id = it->second->num_clients_.fetch_add(1);
MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " +
std::to_string(connection_id) + " to create cache service";
}
// Shuffle the worker threads. But we need to release the locks or we will deadlock when calling
// the following function
lck.Unlock();
@ -258,6 +260,9 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
auto numa_id = client_id % GetNumaNodeCount();
std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id);
// Send back the data
flatbuffers::FlatBufferBuilder fbb;
flatbuffers::Offset<flatbuffers::String> off_cookie;
flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
off_cookie = fbb.CreateString(cookie);
off_cpu_list = fbb.CreateVector(cpu_list);
CreateCacheReplyMsgBuilder bld(fbb);
@ -376,6 +381,57 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
return rc;
}
Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) {
// Look into the flag to see where we can find the data and call the appropriate method.
auto flag = rq->flag();
Status rc;
if (BitTest(flag, kDataIsInSharedMemory)) {
rc = FastCacheRow(rq, reply);
// This is an internal request and is not tied to rpc. But we need to post because there
// is a thread waiting on the completion of this request.
try {
int64_t addr = strtol(rq->buf_data(3).data(), nullptr, 10);
auto *bw = reinterpret_cast<BatchWait *>(addr);
// Check if the object is still around.
auto bwObj = bw->GetBatchWait();
if (bwObj.lock()) {
RETURN_IF_NOT_OK(bw->Set(rc));
}
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
} else {
rc = CacheRow(rq, reply);
}
return rc;
}
Status CacheServer::InternalFetchRow(CacheRequest *rq) {
auto connection_id = rq->connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
Status rc;
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(0).data()));
// This is an internal request and is not tied to rpc. But we need to post because there
// is a thread waiting on the completion of this request.
try {
int64_t addr = strtol(rq->buf_data(1).data(), nullptr, 10);
auto *bw = reinterpret_cast<BatchWait *>(addr);
// Check if the object is still around.
auto bwObj = bw->GetBatchWait();
if (bwObj.lock()) {
RETURN_IF_NOT_OK(bw->Set(rc));
}
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return rc;
}
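
Both internal handlers end the same way: the request is not tied to an rpc responder, so the worker posts the status to the thread blocked on the BatchWait, and only if that object is still alive. A hedged Python analogue of the rendezvous, with threading.Event standing in for the wait (class and method names are assumptions mirroring the C++):

    import threading

    class BatchWait:
        """One waiter blocks; the worker posts the completion status once."""
        def __init__(self):
            self._done = threading.Event()
            self.rc = None

        def set(self, rc):           # worker side: InternalCacheRow / InternalFetchRow
            self.rc = rc
            self._done.set()

        def wait(self, timeout=60):  # waiter side: the thread that issued the request
            if not self._done.wait(timeout):
                raise TimeoutError("no completion posted")
            return self.rc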
Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
RETURN_UNEXPECTED_IF_NULL(out);
auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
@ -741,40 +797,24 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
return Status::OK();
}
Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
bool internal_request = false;
Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request) {
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
// Except for creating a new session, we expect cs to be non-null.
switch (cache_req->type_) {
case BaseRequest::RequestType::kCacheRow:
case BaseRequest::RequestType::kInternalCacheRow: {
// Look into the flag to see where we can find the data and
// call the appropriate method.
auto flag = rq.flag();
if (BitTest(flag, kDataIsInSharedMemory)) {
case BaseRequest::RequestType::kCacheRow: {
// Look into the flag to see where we can find the data and call the appropriate method.
if (BitTest(rq.flag(), kDataIsInSharedMemory)) {
cache_req->rc_ = FastCacheRow(&rq, &reply);
internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow);
if (internal_request) {
// This is an internal request and is not tied to rpc. But we need to post because there
// is a thread waiting on the completion of this request.
try {
int64_t addr = strtol(rq.buf_data(3).data(), nullptr, 10);
auto *bw = reinterpret_cast<BatchWait *>(addr);
// Check if the object is still around.
auto bwObj = bw->GetBatchWait();
if (bwObj.lock()) {
RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
}
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
} else {
cache_req->rc_ = CacheRow(&rq, &reply);
}
break;
}
case BaseRequest::RequestType::kInternalCacheRow: {
*internal_request = true;
cache_req->rc_ = InternalCacheRow(&rq, &reply);
break;
}
case BaseRequest::RequestType::kBatchCacheRows: {
cache_req->rc_ = BatchCacheRows(&rq);
break;
@ -784,31 +824,46 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
break;
}
case BaseRequest::RequestType::kInternalFetchRow: {
internal_request = true;
auto connection_id = rq.connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
// This is an internal request and is not tied to rpc. But we need to post because there
// is a thread waiting on the completion of this request.
try {
int64_t addr = strtol(rq.buf_data(1).data(), nullptr, 10);
auto *bw = reinterpret_cast<BatchWait *>(addr);
// Check if the object is still around.
auto bwObj = bw->GetBatchWait();
if (bwObj.lock()) {
RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
}
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
*internal_request = true;
cache_req->rc_ = InternalFetchRow(&rq);
break;
}
default:
std::string errMsg("Internal error, request type is not row request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}
Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
switch (cache_req->type_) {
case BaseRequest::RequestType::kDropSession: {
cache_req->rc_ = DestroySession(&rq);
break;
}
case BaseRequest::RequestType::kGenerateSessionId: {
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
break;
}
case BaseRequest::RequestType::kListSessions: {
cache_req->rc_ = ListSessions(&reply);
break;
}
default:
std::string errMsg("Internal error, request type is not session request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}
Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
switch (cache_req->type_) {
case BaseRequest::RequestType::kCreateCache: {
cache_req->rc_ = CreateService(&rq, &reply);
break;
@ -837,14 +892,6 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
cache_req->rc_ = BuildPhaseDone(&rq);
break;
}
case BaseRequest::RequestType::kDropSession: {
cache_req->rc_ = DestroySession(&rq);
break;
}
case BaseRequest::RequestType::kGenerateSessionId: {
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
break;
}
case BaseRequest::RequestType::kAllocateSharedBlock: {
cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
break;
@ -868,40 +915,45 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
cache_req->rc_ = ToggleWriteMode(&rq);
break;
}
case BaseRequest::RequestType::kListSessions: {
cache_req->rc_ = ListSessions(&reply);
break;
}
case BaseRequest::RequestType::kConnectReset: {
cache_req->rc_ = ConnectReset(&rq);
break;
}
case BaseRequest::RequestType::kGetCacheState: {
auto connection_id = rq.connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto state = cs->GetState();
reply.set_result(std::to_string(static_cast<int8_t>(state)));
cache_req->rc_ = Status::OK();
}
cache_req->rc_ = GetCacheState(&rq, &reply);
break;
}
default:
std::string errMsg("Unknown request type : ");
std::string errMsg("Internal error, request type is not admin request: ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
return Status::OK();
}
Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
bool internal_request = false;
// Except for creating a new session, we expect cs to be non-null.
if (cache_req->IsRowRequest()) {
RETURN_IF_NOT_OK(ProcessRowRequest(cache_req, &internal_request));
} else if (cache_req->IsSessionRequest()) {
RETURN_IF_NOT_OK(ProcessSessionRequest(cache_req));
} else if (cache_req->IsAdminRequest()) {
RETURN_IF_NOT_OK(ProcessAdminRequest(cache_req));
} else {
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
// Notify it is done, and move on to the next request.
Status2CacheReply(cache_req->rc_, &reply);
Status2CacheReply(cache_req->rc_, &cache_req->reply_);
cache_req->st_ = CacheServerRequest::STATE::FINISH;
// We will re-tag the request back to the grpc queue. Once it comes back from the client,
// the CacheServerRequest, i.e. the pointer cache_req, will be freed
if (!internal_request && !global_shutdown_) {
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
cache_req->responder_.Finish(cache_req->reply_, grpc::Status::OK, cache_req);
} else {
// We can free up the request now.
RETURN_IF_NOT_OK(ReturnRequestTag(cache_req));
@ -1084,6 +1136,20 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
return Status::OK();
}
Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto state = cs->GetState();
reply->set_result(std::to_string(static_cast<int8_t>(state)));
return Status::OK();
}
}
Status CacheServer::RpcRequest(worker_id_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
@ -1213,7 +1279,7 @@ Status CacheServer::Builder::SanityCheck() {
}
CacheServer::Builder::Builder()
: top_(DefaultSpillDir()),
: top_(""),
num_workers_(std::thread::hardware_concurrency() / 2),
port_(50052),
shared_memory_sz_in_gb_(kDefaultSharedMemorySize),

View File

@ -201,6 +201,22 @@ class CacheServer : public Service {
/// \brief Return the memory cap ratio
float GetMemoryCapRatio() const { return memory_cap_ratio_; }
/// \brief Function to handle a row request
/// \param[in] cache_req A row request to handle
/// \param[out] internal_request Indicator of whether the request is an internal request
/// \return Status object
Status ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request);
/// \brief Function to handle an admin request
/// \param[in] cache_req An admin request to handle
/// \return Status object
Status ProcessAdminRequest(CacheServerRequest *cache_req);
/// \brief Function to handle a session request
/// \param[in] cache_req A session request to handle
/// \return Status object
Status ProcessSessionRequest(CacheServerRequest *cache_req);
/// \brief How a request is handled.
/// \note It can be processed immediately by a grpc thread or routed to a server thread
/// which is pinned to some numa node core.
@ -256,6 +272,12 @@ class CacheServer : public Service {
/// \return Pointer to cache service. Null if not found
CacheService *GetService(connection_id_type id) const;
/// \brief Go over the existing cache services and calculate how much memory has been consumed so far. A new
/// cache service can only be created if there is still enough memory available.
/// \param cache_mem_sz Requested memory for a new cache service
/// \return Status object
Status GlobalMemoryCheck(uint64_t cache_mem_sz);
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
/// a special unique cookie.
@ -314,6 +336,12 @@ class CacheServer : public Service {
/// \return Status object
Status GetStat(CacheRequest *rq, CacheReply *reply);
/// \brief Internal function to get cache state
/// \param rq
/// \param reply
/// \return Status object
Status GetCacheState(CacheRequest *rq, CacheReply *reply);
/// \brief Cache a schema request
/// \param rq
/// \return Status object
@ -411,6 +439,9 @@ class CacheServer : public Service {
/// \return Status object
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
Status BatchCacheRows(CacheRequest *rq);
Status InternalFetchRow(CacheRequest *rq);
Status InternalCacheRow(CacheRequest *rq, CacheReply *reply);
};
} // namespace dataset
} // namespace mindspore

View File

@ -15,7 +15,8 @@
# ============================================================================
# source the globals and functions for use with cache testing
SKIP_ADMIN_COUNTER=false
export SKIP_ADMIN_COUNTER=false
declare failed_tests
. cachetest_lib.sh
echo

View File

@ -15,7 +15,8 @@
# ============================================================================
# source the globals and functions for use with cache testing
SKIP_ADMIN_COUNTER=true
export SKIP_ADMIN_COUNTER=true
declare session_id failed_tests
. cachetest_lib.sh
echo
@ -28,8 +29,10 @@ UT_TEST_DIR="${BUILD_PATH}/mindspore/tests/ut/cpp"
DateStamp=$(date +%Y%m%d_%H%M%S);
CPP_TEST_LOG_OUTPUT="/tmp/ut_tests_cache_${DateStamp}.log"
# Start a basic cache server to be used for all tests
StartServer
# Start a cache server with a spilling path to be used for all tests
cmd="${CACHE_ADMIN} --start -s /tmp"
CacheAdminCmd "${cmd}" 0
sleep 1
HandleRcExit $? 1 1
# Set the environment variable to enable these pytests

View File

@ -15,7 +15,8 @@
# ============================================================================
# source the globals and functions for use with cache testing
SKIP_ADMIN_COUNTER=true
export SKIP_ADMIN_COUNTER=true
declare session_id failed_tests
. cachetest_lib.sh
echo
@ -84,10 +85,6 @@ export SESSION_ID=$session_id
PytestCmd "test_cache_map.py" "test_cache_map_running_twice2"
HandleRcExit $? 0 0
# Set size parameter of DatasetCache to an extra small value
PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1
HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_no_image"
HandleRcExit $? 0 0
@ -255,15 +252,6 @@ export SESSION_ID=$session_id
PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice2"
HandleRcExit $? 0 0
# Set size parameter of DatasetCache to an extra small value
GetSession
HandleRcExit $? 1 1
export SESSION_ID=$session_id
PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1
HandleRcExit $? 0 0
DestroySession $session_id
HandleRcExit $? 1 1
# Run two parallel pipelines (sharing cache)
for i in $(seq 1 2)
do
@ -366,7 +354,7 @@ HandleRcExit $? 1 1
export SESSION_ID=$session_id
PytestCmd "test_cache_nomap.py" "test_cache_nomap_session_destroy" &
pid=("$!")
pid=$!
sleep 10
DestroySession $session_id
@ -381,7 +369,7 @@ HandleRcExit $? 1 1
export SESSION_ID=$session_id
PytestCmd "test_cache_nomap.py" "test_cache_nomap_server_stop" &
pid=("$!")
pid=$!
sleep 10
StopServer
@ -417,6 +405,26 @@ HandleRcExit $? 0 0
StopServer
HandleRcExit $? 0 1
# Start a cache server with a spilling path
cmd="${CACHE_ADMIN} --start -s /tmp"
CacheAdminCmd "${cmd}" 0
sleep 1
HandleRcExit $? 0 0
GetSession
HandleRcExit $? 1 1
export SESSION_ID=$session_id
# Set size parameter of mappable DatasetCache to an extra small value
PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1
HandleRcExit $? 0 0
# Set size parameter of non-mappable DatasetCache to an extra small value
PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1
HandleRcExit $? 0 0
StopServer
HandleRcExit $? 0 1
unset RUN_CACHE_TEST
unset SESSION_ID

View File

@ -57,7 +57,7 @@ def test_cache_map_basic1():
else:
session_id = 1
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -91,7 +91,7 @@ def test_cache_map_basic2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -115,7 +115,7 @@ def test_cache_map_basic3():
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -155,7 +155,7 @@ def test_cache_map_basic4():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -189,7 +189,7 @@ def test_cache_map_basic5():
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -225,7 +225,7 @@ def test_cache_map_failure1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -269,7 +269,7 @@ def test_cache_map_failure2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -310,7 +310,7 @@ def test_cache_map_failure3():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -351,7 +351,7 @@ def test_cache_map_failure4():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -391,7 +391,7 @@ def test_cache_map_failure5():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -432,7 +432,7 @@ def test_cache_map_failure6():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
num_readers = 1
@ -478,7 +478,7 @@ def test_cache_map_failure7():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
data = ds.GeneratorDataset(generator_1d, ["data"])
data = data.map((lambda x: x), ["data"], cache=some_cache)
@ -514,7 +514,7 @@ def test_cache_map_failure8():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -554,7 +554,7 @@ def test_cache_map_failure9():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -596,7 +596,7 @@ def test_cache_map_failure10():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -616,6 +616,37 @@ def test_cache_map_failure10():
logger.info('test_cache_failure10 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure11():
"""
Test setting spilling=True when the cache server is started without spilling support (failure)
Cache(spilling=true)
|
ImageFolder
"""
logger.info("Test cache failure 11")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
with pytest.raises(RuntimeError) as e:
num_iter = 0
for _ in ds1.create_dict_iterator():
num_iter += 1
assert "Unexpected error. Server is not set up with spill support" in str(e.value)
assert num_iter == 0
logger.info('test_cache_failure11 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_split1():
"""
@ -641,7 +672,7 @@ def test_cache_map_split1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -692,7 +723,7 @@ def test_cache_map_split2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 9 records
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
@ -725,27 +756,27 @@ def test_cache_map_parameter_check():
logger.info("Test cache map parameter check")
with pytest.raises(ValueError) as info:
ds.DatasetCache(session_id=-1, size=0, spilling=True)
ds.DatasetCache(session_id=-1, size=0)
assert "Input is not within the required interval" in str(info.value)
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id="1", size=0, spilling=True)
ds.DatasetCache(session_id="1", size=0)
assert "Argument session_id with value 1 is not of type (<class 'int'>,)" in str(info.value)
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=None, size=0, spilling=True)
ds.DatasetCache(session_id=None, size=0)
assert "Argument session_id with value None is not of type (<class 'int'>,)" in str(info.value)
with pytest.raises(ValueError) as info:
ds.DatasetCache(session_id=1, size=-1, spilling=True)
ds.DatasetCache(session_id=1, size=-1)
assert "Input size must be greater than 0" in str(info.value)
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size="1", spilling=True)
ds.DatasetCache(session_id=1, size="1")
assert "Argument size with value 1 is not of type (<class 'int'>,)" in str(info.value)
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=None, spilling=True)
ds.DatasetCache(session_id=1, size=None)
assert "Argument size with value None is not of type (<class 'int'>,)" in str(info.value)
with pytest.raises(TypeError) as info:
@ -753,31 +784,31 @@ def test_cache_map_parameter_check():
assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value)
with pytest.raises(TypeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052)
ds.DatasetCache(session_id=1, size=0, hostname=50052)
assert "Argument hostname with value 50052 is not of type (<class 'str'>,)" in str(err.value)
with pytest.raises(RuntimeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
ds.DatasetCache(session_id=1, size=0, hostname="illegal")
assert "now cache client has to be on the same host with cache server" in str(err.value)
with pytest.raises(RuntimeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2")
ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2")
assert "now cache client has to be on the same host with cache server" in str(err.value)
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
ds.DatasetCache(session_id=1, size=0, port="illegal")
assert "Argument port with value illegal is not of type (<class 'int'>,)" in str(info.value)
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052")
ds.DatasetCache(session_id=1, size=0, port="50052")
assert "Argument port with value 50052 is not of type (<class 'int'>,)" in str(info.value)
with pytest.raises(ValueError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, port=0)
ds.DatasetCache(session_id=1, size=0, port=0)
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)
with pytest.raises(ValueError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536)
ds.DatasetCache(session_id=1, size=0, port=65536)
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)
with pytest.raises(TypeError) as err:
@ -807,7 +838,7 @@ def test_cache_map_running_twice1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -850,7 +881,7 @@ def test_cache_map_running_twice2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -998,7 +1029,7 @@ def test_cache_map_parallel_pipeline1(shard):
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache)
@ -1035,7 +1066,7 @@ def test_cache_map_parallel_pipeline2(shard):
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard))
@ -1072,7 +1103,7 @@ def test_cache_map_parallel_workers():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4)
@ -1109,7 +1140,7 @@ def test_cache_map_server_workers_1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -1146,7 +1177,7 @@ def test_cache_map_server_workers_100():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -1183,7 +1214,7 @@ def test_cache_map_num_connections_1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)
some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -1220,7 +1251,7 @@ def test_cache_map_num_connections_100():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)
some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -1257,7 +1288,7 @@ def test_cache_map_prefetch_size_1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -1294,7 +1325,7 @@ def test_cache_map_prefetch_size_100():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -1335,7 +1366,7 @@ def test_cache_map_to_device():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -1366,7 +1397,7 @@ def test_cache_map_epoch_ctrl1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -1406,7 +1437,7 @@ def test_cache_map_epoch_ctrl2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
@ -1452,7 +1483,7 @@ def test_cache_map_epoch_ctrl3():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
@ -1495,7 +1526,7 @@ def test_cache_map_coco1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 6 records
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
@ -1531,7 +1562,7 @@ def test_cache_map_coco2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 6 records
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
@ -1566,7 +1597,7 @@ def test_cache_map_mnist1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)
num_epoch = 4
@ -1599,7 +1630,7 @@ def test_cache_map_mnist2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
resize_op = c_vision.Resize((224, 224))
@ -1633,7 +1664,7 @@ def test_cache_map_celeba1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 4 records
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)
@ -1668,7 +1699,7 @@ def test_cache_map_celeba2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 4 records
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
@ -1703,7 +1734,7 @@ def test_cache_map_manifest1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 4 records
ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
@ -1738,7 +1769,7 @@ def test_cache_map_manifest2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 4 records
ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
@ -1773,7 +1804,7 @@ def test_cache_map_cifar1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
num_epoch = 4
@ -1806,7 +1837,7 @@ def test_cache_map_cifar2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
resize_op = c_vision.Resize((224, 224))
@ -1841,7 +1872,7 @@ def test_cache_map_cifar3():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
some_cache = ds.DatasetCache(session_id=session_id, size=1)
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
@ -1875,7 +1906,7 @@ def test_cache_map_cifar4():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
ds1 = ds1.shuffle(10)
@ -1907,7 +1938,7 @@ def test_cache_map_voc1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 9 records
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache)
@ -1942,7 +1973,7 @@ def test_cache_map_voc2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 9 records
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
@ -1987,7 +2018,7 @@ def test_cache_map_python_sampler1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache)
@ -2023,7 +2054,7 @@ def test_cache_map_python_sampler2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
@ -2061,7 +2092,7 @@ def test_cache_map_nested_repeat():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This DATA_DIR only has 2 images in it
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)

View File

@ -62,7 +62,7 @@ def test_cache_nomap_basic1():
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# create a cache. arbitrary session_id for now
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# User-created sampler here
ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
@ -96,7 +96,7 @@ def test_cache_nomap_basic2():
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# create a cache. arbitrary session_id for now
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# sampler arg not given directly; however, any of these args will auto-generate an appropriate sampler:
# num_samples, shuffle, num_shards, shard_id
@ -134,7 +134,7 @@ def test_cache_nomap_basic3():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(operations=decode_op, input_columns=["image"])
@ -183,7 +183,7 @@ def test_cache_nomap_basic4():
raise RuntimeError("Testcase requires SESSION_ID environment variable")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# With shuffle not being set, TF defaults to a "global" shuffle when there is no cache
# in the picture. This causes a shuffle-injection over the TF. For clarity, this test will
# explicitly give the global option, even though it's the default in python.
@ -231,7 +231,7 @@ def test_cache_nomap_basic5():
raise RuntimeError("Testcase requires SESSION_ID environment variable")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
decode_op = c_vision.Decode()
ds1 = ds1.map(operations=decode_op, input_columns=["image"])
@ -270,7 +270,7 @@ def test_cache_nomap_basic6():
raise RuntimeError("Testcase requires SESSION_ID environment variable")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# With only 3 records sharded into 3, we expect only 1 record returned for this shard
# However, the sharding will be done by the sampler, not by the tf record leaf node
@ -313,7 +313,7 @@ def test_cache_nomap_basic7():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
@ -344,7 +344,7 @@ def test_cache_nomap_basic8():
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
@ -371,7 +371,7 @@ def test_cache_nomap_basic9():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# Contact the server to get the statistics; this should fail because we have not used this cache in any pipeline
# so there will not be any cache to get stats on.
@ -404,7 +404,7 @@ def test_cache_nomap_allowed_share1():
ds.config.set_seed(1)
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=32)
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=32)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
ds1 = ds1.repeat(4)
@ -446,7 +446,7 @@ def test_cache_nomap_allowed_share2():
ds.config.set_seed(1)
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
decode_op = c_vision.Decode()
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
@ -488,7 +488,7 @@ def test_cache_nomap_allowed_share3():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
@ -529,7 +529,7 @@ def test_cache_nomap_allowed_share4():
raise RuntimeError("Testcase requires SESSION_ID environment variable")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
decode_op = c_vision.Decode()
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
@ -572,7 +572,7 @@ def test_cache_nomap_disallowed_share1():
raise RuntimeError("Testcase requires SESSION_ID environment variable")
# This dataset has 3 records in it only
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
decode_op = c_vision.Decode()
rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)
@ -615,7 +615,7 @@ def test_cache_nomap_running_twice1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -658,7 +658,7 @@ def test_cache_nomap_running_twice2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
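The two running_twice hunks exercise the populate-then-hit path: the first full iteration fills the cache, and the second should be served back from the server. A sketch of that pattern under the same assumed fixtures:

some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)

num_iter = 0
for _ in ds1.create_dict_iterator(num_epochs=1):  # first pass: cache miss, rows are written
    num_iter += 1
for _ in ds1.create_dict_iterator(num_epochs=1):  # second pass: rows come back from the cache
    num_iter += 1
assert num_iter == 6  # 3 records per pass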
@ -763,7 +763,7 @@ def test_cache_nomap_parallel_pipeline1(shard):
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache)
@ -799,7 +799,7 @@ def test_cache_nomap_parallel_pipeline2(shard):
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard))
@ -835,7 +835,7 @@ def test_cache_nomap_parallel_workers():
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4)
@ -872,7 +872,7 @@ def test_cache_nomap_server_workers_1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -909,7 +909,7 @@ def test_cache_nomap_server_workers_100():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
@ -946,7 +946,7 @@ def test_cache_nomap_num_connections_1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)
some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -983,7 +983,7 @@ def test_cache_nomap_num_connections_100():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)
some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
@ -1020,7 +1020,7 @@ def test_cache_nomap_prefetch_size_1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -1057,7 +1057,7 @@ def test_cache_nomap_prefetch_size_100():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
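The four hunks above are the only ones that pass extra tuning arguments, and both knobs survive the removal of spilling: num_connections sizes the client's connection pool to the cache server, and prefetch_size controls how much each fetch pulls back (descriptions inferred from the test names; exact semantics are in the DatasetCache docs). The boundary values the tests probe:

low_conn = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)
high_conn = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)
min_fetch = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)
big_fetch = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)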
@ -1098,7 +1098,7 @@ def test_cache_nomap_to_device():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -1134,7 +1134,7 @@ def test_cache_nomap_session_destroy():
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
schema.add_column('label', de_type=mstype.uint8, shape=[1])
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# User-created sampler here
ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
@ -1172,7 +1172,7 @@ def test_cache_nomap_server_stop():
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
schema.add_column('label', de_type=mstype.uint8, shape=[1])
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# User-created sampler here
ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
@ -1206,7 +1206,7 @@ def test_cache_nomap_epoch_ctrl1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
@ -1246,7 +1246,7 @@ def test_cache_nomap_epoch_ctrl2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -1292,7 +1292,7 @@ def test_cache_nomap_epoch_ctrl3():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
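The epoch_ctrl hunks combine the cache with multi-epoch iteration: only the first epoch should touch the leaf node, with later epochs replayed from the cache. Sketched with an assumed two-epoch run:

some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
iter1 = ds1.create_dict_iterator(num_epochs=2)
for _ in range(2):
    for _ in iter1:  # epoch 1 fills the cache; epoch 2 reads from it
        pass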
@ -1339,7 +1339,7 @@ def test_cache_nomap_epoch_ctrl4():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -1381,8 +1381,8 @@ def test_cache_nomap_multiple_cache1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
train_cache = ds.DatasetCache(session_id=session_id, size=0)
eval_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 12 records in it
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
@ -1425,8 +1425,8 @@ def test_cache_nomap_multiple_cache2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
text_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
image_cache = ds.DatasetCache(session_id=session_id, size=0)
text_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -1470,8 +1470,8 @@ def test_cache_nomap_multiple_cache3():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
tf_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
tf_cache = ds.DatasetCache(session_id=session_id, size=0)
image_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
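The multiple_cache hunks show that one session can host several caches at once: each DatasetCache object shares the session_id but caches a different pipeline. A reduced sketch of the train/eval variant, assuming the *_DATA_DIR fixtures from the test file:

train_cache = ds.DatasetCache(session_id=session_id, size=0)
eval_cache = ds.DatasetCache(session_id=session_id, size=0)
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR, cache=train_cache)
eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=eval_cache)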
@ -1515,7 +1515,7 @@ def test_cache_nomap_multiple_cache_train():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
train_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 12 records in it
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
@ -1553,7 +1553,7 @@ def test_cache_nomap_multiple_cache_eval():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
eval_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset only has 3 records in it
eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -1591,7 +1591,7 @@ def test_cache_nomap_clue1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# With only 3 records sharded into 3, we expect only 1 record returned for this shard
# However, the sharding will be done by the sampler, not by the clue leaf node
@ -1630,7 +1630,7 @@ def test_cache_nomap_clue2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)
@ -1666,7 +1666,7 @@ def test_cache_nomap_csv1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# With only 3 records sharded into 3, we expect only 1 record returned for this shard
# However, the sharding will be done by the sampler, not by the csv leaf node
@ -1706,7 +1706,7 @@ def test_cache_nomap_csv2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
@ -1743,7 +1743,7 @@ def test_cache_nomap_textfile1():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# With only 3 records sharded into 3, we expect only 1 record returned for this shard
# However, the sharding will be done by the sampler, not by the text file leaf node
@ -1788,7 +1788,7 @@ def test_cache_nomap_textfile2():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
tokenizer = text.PythonTokenizer(my_tokenizer)
@ -1828,7 +1828,7 @@ def test_cache_nomap_nested_repeat():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
@ -1867,7 +1867,7 @@ def test_cache_nomap_get_repeat_count():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
some_cache = ds.DatasetCache(session_id=session_id, size=0)
# This dataset has 3 records in it only
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
@ -1902,7 +1902,7 @@ def test_cache_nomap_long_file_list():
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
some_cache = ds.DatasetCache(session_id=session_id, size=1)
ds1 = ds.TFRecordDataset([DATA_DIR[0] for _ in range(0, 1000)], SCHEMA_DIR, columns_list=["image"],
cache=some_cache)
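This final hunk is the one spot that previously disabled spilling outright (spilling=False alongside size=1, the smallest cap the tests use); it now relies on the default like everything above. Taken together with the cache_admin and server changes earlier in this commit, the test-side story is: spilling is no longer a per-pipeline decision, and since the server no longer creates a default spill directory unconditionally at startup, a capped cache without a configured spill location stays memory-bound. A hedged restatement of the new call (overflow behavior past the cap is server-side and not asserted here):

tight_cache = ds.DatasetCache(session_id=session_id, size=1)  # tightest cap the tests use
ds1 = ds.TFRecordDataset([DATA_DIR[0] for _ in range(0, 1000)], SCHEMA_DIR,
                         columns_list=["image"], cache=tight_cache)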