forked from mindspore-Ecosystem/mindspore
!11559 [MD] Change the default spill location of cache server
From: @lixiachen Reviewed-by: Signed-off-by:
This commit is contained in:
commit
3708624a25
|
@ -42,14 +42,6 @@ int main(int argc, char **argv) {
|
|||
google::InitGoogleLogging(argv[0]);
|
||||
#endif
|
||||
|
||||
// Create default spilling dir
|
||||
ds::Path spill_dir = ds::Path(ds::DefaultSpillDir());
|
||||
rc = spill_dir.CreateDirectories();
|
||||
if (!rc.IsOk()) {
|
||||
std::cerr << rc.ToString() << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (argc == 1) {
|
||||
args.Help();
|
||||
return 0;
|
||||
|
|
|
@ -48,7 +48,7 @@ CacheAdminArgHandler::CacheAdminArgHandler()
|
|||
log_level_(kDefaultLogLevel),
|
||||
memory_cap_ratio_(kMemoryCapRatio),
|
||||
hostname_(kCfgDefaultCacheHost),
|
||||
spill_dir_(DefaultSpillDir()),
|
||||
spill_dir_(""),
|
||||
command_id_(CommandId::kCmdUnknown) {
|
||||
// Initialize the command mappings
|
||||
arg_map_["-h"] = ArgValue::kArgHost;
|
||||
|
@ -334,32 +334,7 @@ Status CacheAdminArgHandler::RunCommand() {
|
|||
break;
|
||||
}
|
||||
case CommandId::kCmdStop: {
|
||||
CacheClientGreeter comm(hostname_, port_, 1);
|
||||
RETURN_IF_NOT_OK(comm.ServiceStart());
|
||||
SharedMessage msg;
|
||||
RETURN_IF_NOT_OK(msg.Create());
|
||||
auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId());
|
||||
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
|
||||
Status rc = rq->Wait();
|
||||
if (rc.IsError()) {
|
||||
msg.RemoveResourcesOnExit();
|
||||
if (rc.IsNetWorkError()) {
|
||||
std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already.";
|
||||
return Status(StatusCode::kNetWorkError, errMsg);
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
// OK return code only means the server acknowledge our request but we still
|
||||
// have to wait for its complete shutdown because the server will shutdown
|
||||
// the comm layer as soon as the request is received, and we need to wait
|
||||
// on the message queue instead.
|
||||
// The server will send a message back and remove the queue and we will then wake up. But on the safe
|
||||
// side, we will also set up an alarm and kill this process if we hang on
|
||||
// the message queue.
|
||||
alarm(60);
|
||||
Status dummy_rc;
|
||||
(void)msg.ReceiveStatus(&dummy_rc);
|
||||
std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl;
|
||||
RETURN_IF_NOT_OK(StopServer(command_id_));
|
||||
break;
|
||||
}
|
||||
case CommandId::kCmdGenerateSession: {
|
||||
|
@ -430,6 +405,36 @@ Status CacheAdminArgHandler::RunCommand() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheAdminArgHandler::StopServer(CommandId command_id) {
|
||||
CacheClientGreeter comm(hostname_, port_, 1);
|
||||
RETURN_IF_NOT_OK(comm.ServiceStart());
|
||||
SharedMessage msg;
|
||||
RETURN_IF_NOT_OK(msg.Create());
|
||||
auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId());
|
||||
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
|
||||
Status rc = rq->Wait();
|
||||
if (rc.IsError()) {
|
||||
msg.RemoveResourcesOnExit();
|
||||
if (rc.IsNetWorkError()) {
|
||||
std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already.";
|
||||
return Status(StatusCode::kNetWorkError, errMsg);
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
// OK return code only means the server acknowledge our request but we still
|
||||
// have to wait for its complete shutdown because the server will shutdown
|
||||
// the comm layer as soon as the request is received, and we need to wait
|
||||
// on the message queue instead.
|
||||
// The server will send a message back and remove the queue and we will then wake up. But on the safe
|
||||
// side, we will also set up an alarm and kill this process if we hang on
|
||||
// the message queue.
|
||||
alarm(60);
|
||||
Status dummy_rc;
|
||||
(void)msg.ReceiveStatus(&dummy_rc);
|
||||
std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheAdminArgHandler::StartServer(CommandId command_id) {
|
||||
// There currently does not exist any "install path" or method to identify which path the installed binaries will
|
||||
// exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
|
||||
|
@ -462,7 +467,6 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) {
|
|||
// fork the child process to become the daemon
|
||||
pid_t pid;
|
||||
pid = fork();
|
||||
|
||||
// failed to fork
|
||||
if (pid < 0) {
|
||||
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
|
||||
|
@ -538,7 +542,7 @@ void CacheAdminArgHandler::Help() {
|
|||
std::cerr << " [[-h | --hostname] <hostname>] Default is " << kCfgDefaultCacheHost << ".\n";
|
||||
std::cerr << " [[-p | --port] <port number>] Default is " << kCfgDefaultCachePort << ".\n";
|
||||
std::cerr << " [[-w | --workers] <number of workers>] Default is " << kDefaultNumWorkers << ".\n";
|
||||
std::cerr << " [[-s | --spilldir] <spilling directory>] Default is " << DefaultSpillDir() << ".\n";
|
||||
std::cerr << " [[-s | --spilldir] <spilling directory>] Default is no spilling.\n";
|
||||
std::cerr << " [[-l | --loglevel] <log level>] Default is 1 (warning level).\n";
|
||||
std::cerr << " [--destroy_session | -d] <session id>\n";
|
||||
std::cerr << " [[-p | --port] <port number>]\n";
|
||||
|
|
|
@ -79,6 +79,8 @@ class CacheAdminArgHandler {
|
|||
|
||||
Status StartServer(CommandId command_id);
|
||||
|
||||
Status StopServer(CommandId command_id);
|
||||
|
||||
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
|
||||
CommandId command_id = CommandId::kCmdUnknown);
|
||||
|
||||
|
|
|
@ -91,9 +91,6 @@ using worker_id_t = int32_t;
|
|||
using numa_id_t = int32_t;
|
||||
using cpu_id_t = int32_t;
|
||||
|
||||
/// Return the default spill dir for cache
|
||||
inline std::string DefaultSpillDir() { return kDefaultPathPrefix; }
|
||||
|
||||
/// Return the default log dir for cache
|
||||
inline std::string DefaultLogDir() { return kDefaultPathPrefix + std::string("/log"); }
|
||||
|
||||
|
|
|
@ -125,6 +125,33 @@ class BaseRequest {
|
|||
/// \return Status object
|
||||
Status Wait();
|
||||
|
||||
/// \brief Return if the request is of row request type
|
||||
/// \return True if the request is row-related request
|
||||
bool IsRowRequest() const {
|
||||
return type_ == RequestType::kBatchCacheRows || type_ == RequestType::kBatchFetchRows ||
|
||||
type_ == RequestType::kInternalCacheRow || type_ == RequestType::kInternalFetchRow ||
|
||||
type_ == RequestType::kCacheRow;
|
||||
}
|
||||
|
||||
/// \brief Return if the request is of admin request type
|
||||
/// \return True if the request is admin-related request
|
||||
bool IsAdminRequest() const {
|
||||
return type_ == RequestType::kCreateCache || type_ == RequestType::kDestroyCache ||
|
||||
type_ == RequestType::kGetStat || type_ == RequestType::kGetCacheState ||
|
||||
type_ == RequestType::kAllocateSharedBlock || type_ == RequestType::kFreeSharedBlock ||
|
||||
type_ == RequestType::kCacheSchema || type_ == RequestType::kFetchSchema ||
|
||||
type_ == RequestType::kBuildPhaseDone || type_ == RequestType::kToggleWriteMode ||
|
||||
type_ == RequestType::kConnectReset || type_ == RequestType::kStopService ||
|
||||
type_ == RequestType::kHeartBeat || type_ == RequestType::kGetCacheMissKeys;
|
||||
}
|
||||
|
||||
/// \brief Return if the request is of session request type
|
||||
/// \return True if the request is session-related request
|
||||
bool IsSessionRequest() const {
|
||||
return type_ == RequestType::kGenerateSessionId || type_ == RequestType::kDropSession ||
|
||||
type_ == RequestType::kListSessions;
|
||||
}
|
||||
|
||||
protected:
|
||||
CacheRequest rq_; // This is what we send to the server
|
||||
CacheReply reply_; // This is what the server send back
|
||||
|
|
|
@ -155,6 +155,42 @@ CacheService *CacheServer::GetService(connection_id_type id) const {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// We would like to protect ourselves from over allocating too much. We will go over existing cache
|
||||
// and calculate how much we have consumed so far.
|
||||
Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
|
||||
auto end = all_caches_.end();
|
||||
auto it = all_caches_.begin();
|
||||
auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
|
||||
int64_t max_avail = avail_mem;
|
||||
while (it != end) {
|
||||
auto &cs = it->second;
|
||||
CacheService::ServiceStat stat;
|
||||
RETURN_IF_NOT_OK(cs->GetStat(&stat));
|
||||
int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
|
||||
max_avail -= mem_consumed;
|
||||
if (max_avail <= 0) {
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
|
||||
}
|
||||
++it;
|
||||
}
|
||||
|
||||
// If we have some cache using some memory already, make a reasonable decision if we should return
|
||||
// out of memory.
|
||||
if (max_avail < avail_mem) {
|
||||
int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit.
|
||||
if (req_mem > max_avail) {
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
|
||||
} else if (req_mem == 0) {
|
||||
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
|
||||
// 85% of our limit, fail this request.
|
||||
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
|
||||
std::string cookie;
|
||||
|
@ -186,55 +222,25 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
|
|||
if (spill && top_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
flatbuffers::Offset<flatbuffers::String> off_cookie;
|
||||
flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
|
||||
// Before creating the cache, first check if this is a request for a shared usage of an existing cache
|
||||
// If two CreateService come in with identical connection_id, we need to serialize the create.
|
||||
// The first create will be successful and be given a special cookie.
|
||||
UniqueLock lck(&rwLock_);
|
||||
bool duplicate = false;
|
||||
CacheService *curr_cs = GetService(connection_id);
|
||||
if (curr_cs != nullptr) {
|
||||
duplicate = true;
|
||||
client_id = curr_cs->num_clients_.fetch_add(1);
|
||||
MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " +
|
||||
std::to_string(connection_id) + " to create cache service";
|
||||
}
|
||||
// Early exit if we are doing global shutdown
|
||||
if (global_shutdown_) {
|
||||
return Status::OK();
|
||||
}
|
||||
// We would like to protect ourselves from over allocating too much. We will go over existing cache
|
||||
// and calculate how much we have consumed so far.
|
||||
auto end = all_caches_.end();
|
||||
auto it = all_caches_.begin();
|
||||
bool duplicate = false;
|
||||
auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
|
||||
int64_t max_avail = avail_mem;
|
||||
while (it != end) {
|
||||
if (it->first == connection_id) {
|
||||
duplicate = true;
|
||||
break;
|
||||
} else {
|
||||
auto &cs = it->second;
|
||||
CacheService::ServiceStat stat;
|
||||
RETURN_IF_NOT_OK(cs->GetStat(&stat));
|
||||
int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
|
||||
max_avail -= mem_consumed;
|
||||
if (max_avail <= 0) {
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
|
||||
}
|
||||
}
|
||||
++it;
|
||||
}
|
||||
if (it == end) {
|
||||
// If we have some cache using some memory already, make a reasonable decision if we should return
|
||||
// out of memory.
|
||||
if (max_avail < avail_mem) {
|
||||
int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit.
|
||||
if (req_mem > max_avail) {
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
|
||||
} else if (req_mem == 0) {
|
||||
// This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
|
||||
// 85% of our limit, fail this request.
|
||||
if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) {
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!duplicate) {
|
||||
RETURN_IF_NOT_OK(GlobalMemoryCheck(cache_mem_sz));
|
||||
std::unique_ptr<CacheService> cs;
|
||||
try {
|
||||
cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
|
||||
|
@ -245,12 +251,8 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
|
|||
} catch (const std::bad_alloc &e) {
|
||||
return Status(StatusCode::kOutOfMemory);
|
||||
}
|
||||
} else {
|
||||
duplicate = true;
|
||||
client_id = it->second->num_clients_.fetch_add(1);
|
||||
MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " +
|
||||
std::to_string(connection_id) + " to create cache service";
|
||||
}
|
||||
|
||||
// Shuffle the worker threads. But we need to release the locks or we will deadlock when calling
|
||||
// the following function
|
||||
lck.Unlock();
|
||||
|
@ -258,6 +260,9 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
|
|||
auto numa_id = client_id % GetNumaNodeCount();
|
||||
std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id);
|
||||
// Send back the data
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
flatbuffers::Offset<flatbuffers::String> off_cookie;
|
||||
flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
|
||||
off_cookie = fbb.CreateString(cookie);
|
||||
off_cpu_list = fbb.CreateVector(cpu_list);
|
||||
CreateCacheReplyMsgBuilder bld(fbb);
|
||||
|
@ -376,6 +381,57 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
|
|||
return rc;
|
||||
}
|
||||
|
||||
Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) {
|
||||
// Look into the flag to see where we can find the data and call the appropriate method.
|
||||
auto flag = rq->flag();
|
||||
Status rc;
|
||||
if (BitTest(flag, kDataIsInSharedMemory)) {
|
||||
rc = FastCacheRow(rq, reply);
|
||||
// This is an internal request and is not tied to rpc. But need to post because there
|
||||
// is a thread waiting on the completion of this request.
|
||||
try {
|
||||
int64_t addr = strtol(rq->buf_data(3).data(), nullptr, 10);
|
||||
auto *bw = reinterpret_cast<BatchWait *>(addr);
|
||||
// Check if the object is still around.
|
||||
auto bwObj = bw->GetBatchWait();
|
||||
if (bwObj.lock()) {
|
||||
RETURN_IF_NOT_OK(bw->Set(rc));
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
}
|
||||
} else {
|
||||
rc = CacheRow(rq, reply);
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
Status CacheServer::InternalFetchRow(CacheRequest *rq) {
|
||||
auto connection_id = rq->connection_id();
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
Status rc;
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(0).data()));
|
||||
// This is an internal request and is not tied to rpc. But need to post because there
|
||||
// is a thread waiting on the completion of this request.
|
||||
try {
|
||||
int64_t addr = strtol(rq->buf_data(1).data(), nullptr, 10);
|
||||
auto *bw = reinterpret_cast<BatchWait *>(addr);
|
||||
// Check if the object is still around.
|
||||
auto bwObj = bw->GetBatchWait();
|
||||
if (bwObj.lock()) {
|
||||
RETURN_IF_NOT_OK(bw->Set(rc));
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
|
||||
|
@ -741,40 +797,24 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
|
||||
bool internal_request = false;
|
||||
Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request) {
|
||||
auto &rq = cache_req->rq_;
|
||||
auto &reply = cache_req->reply_;
|
||||
// Except for creating a new session, we expect cs is not null.
|
||||
switch (cache_req->type_) {
|
||||
case BaseRequest::RequestType::kCacheRow:
|
||||
case BaseRequest::RequestType::kInternalCacheRow: {
|
||||
// Look into the flag to see where we can find the data and
|
||||
// call the appropriate method.
|
||||
auto flag = rq.flag();
|
||||
if (BitTest(flag, kDataIsInSharedMemory)) {
|
||||
case BaseRequest::RequestType::kCacheRow: {
|
||||
// Look into the flag to see where we can find the data and call the appropriate method.
|
||||
if (BitTest(rq.flag(), kDataIsInSharedMemory)) {
|
||||
cache_req->rc_ = FastCacheRow(&rq, &reply);
|
||||
internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow);
|
||||
if (internal_request) {
|
||||
// This is an internal request and is not tied to rpc. But need to post because there
|
||||
// is a thread waiting on the completion of this request.
|
||||
try {
|
||||
int64_t addr = strtol(rq.buf_data(3).data(), nullptr, 10);
|
||||
auto *bw = reinterpret_cast<BatchWait *>(addr);
|
||||
// Check if the object is still around.
|
||||
auto bwObj = bw->GetBatchWait();
|
||||
if (bwObj.lock()) {
|
||||
RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cache_req->rc_ = CacheRow(&rq, &reply);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kInternalCacheRow: {
|
||||
*internal_request = true;
|
||||
cache_req->rc_ = InternalCacheRow(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kBatchCacheRows: {
|
||||
cache_req->rc_ = BatchCacheRows(&rq);
|
||||
break;
|
||||
|
@ -784,31 +824,46 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
|
|||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kInternalFetchRow: {
|
||||
internal_request = true;
|
||||
auto connection_id = rq.connection_id();
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else {
|
||||
cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
|
||||
// This is an internal request and is not tied to rpc. But need to post because there
|
||||
// is a thread waiting on the completion of this request.
|
||||
try {
|
||||
int64_t addr = strtol(rq.buf_data(1).data(), nullptr, 10);
|
||||
auto *bw = reinterpret_cast<BatchWait *>(addr);
|
||||
// Check if the object is still around.
|
||||
auto bwObj = bw->GetBatchWait();
|
||||
if (bwObj.lock()) {
|
||||
RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
}
|
||||
}
|
||||
*internal_request = true;
|
||||
cache_req->rc_ = InternalFetchRow(&rq);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
std::string errMsg("Internal error, request type is not row request: ");
|
||||
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
|
||||
auto &rq = cache_req->rq_;
|
||||
auto &reply = cache_req->reply_;
|
||||
switch (cache_req->type_) {
|
||||
case BaseRequest::RequestType::kDropSession: {
|
||||
cache_req->rc_ = DestroySession(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kGenerateSessionId: {
|
||||
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kListSessions: {
|
||||
cache_req->rc_ = ListSessions(&reply);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
std::string errMsg("Internal error, request type is not session request: ");
|
||||
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
|
||||
auto &rq = cache_req->rq_;
|
||||
auto &reply = cache_req->reply_;
|
||||
switch (cache_req->type_) {
|
||||
case BaseRequest::RequestType::kCreateCache: {
|
||||
cache_req->rc_ = CreateService(&rq, &reply);
|
||||
break;
|
||||
|
@ -837,14 +892,6 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
|
|||
cache_req->rc_ = BuildPhaseDone(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kDropSession: {
|
||||
cache_req->rc_ = DestroySession(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kGenerateSessionId: {
|
||||
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kAllocateSharedBlock: {
|
||||
cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
|
||||
break;
|
||||
|
@ -868,40 +915,45 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
|
|||
cache_req->rc_ = ToggleWriteMode(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kListSessions: {
|
||||
cache_req->rc_ = ListSessions(&reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kConnectReset: {
|
||||
cache_req->rc_ = ConnectReset(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kGetCacheState: {
|
||||
auto connection_id = rq.connection_id();
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else {
|
||||
auto state = cs->GetState();
|
||||
reply.set_result(std::to_string(static_cast<int8_t>(state)));
|
||||
cache_req->rc_ = Status::OK();
|
||||
}
|
||||
cache_req->rc_ = GetCacheState(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
std::string errMsg("Unknown request type : ");
|
||||
std::string errMsg("Internal error, request type is not admin request: ");
|
||||
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
|
||||
bool internal_request = false;
|
||||
|
||||
// Except for creating a new session, we expect cs is not null.
|
||||
if (cache_req->IsRowRequest()) {
|
||||
RETURN_IF_NOT_OK(ProcessRowRequest(cache_req, &internal_request));
|
||||
} else if (cache_req->IsSessionRequest()) {
|
||||
RETURN_IF_NOT_OK(ProcessSessionRequest(cache_req));
|
||||
} else if (cache_req->IsAdminRequest()) {
|
||||
RETURN_IF_NOT_OK(ProcessAdminRequest(cache_req));
|
||||
} else {
|
||||
std::string errMsg("Unknown request type : ");
|
||||
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
|
||||
// Notify it is done, and move on to the next request.
|
||||
Status2CacheReply(cache_req->rc_, &reply);
|
||||
Status2CacheReply(cache_req->rc_, &cache_req->reply_);
|
||||
cache_req->st_ = CacheServerRequest::STATE::FINISH;
|
||||
// We will re-tag the request back to the grpc queue. Once it comes back from the client,
|
||||
// the CacheServerRequest, i.e. the pointer cache_req, will be free
|
||||
if (!internal_request && !global_shutdown_) {
|
||||
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
|
||||
cache_req->responder_.Finish(cache_req->reply_, grpc::Status::OK, cache_req);
|
||||
} else {
|
||||
// We can free up the request now.
|
||||
RETURN_IF_NOT_OK(ReturnRequestTag(cache_req));
|
||||
|
@ -1084,6 +1136,20 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
|
||||
auto connection_id = rq->connection_id();
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else {
|
||||
auto state = cs->GetState();
|
||||
reply->set_result(std::to_string(static_cast<int8_t>(state)));
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
Status CacheServer::RpcRequest(worker_id_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
|
||||
|
@ -1213,7 +1279,7 @@ Status CacheServer::Builder::SanityCheck() {
|
|||
}
|
||||
|
||||
CacheServer::Builder::Builder()
|
||||
: top_(DefaultSpillDir()),
|
||||
: top_(""),
|
||||
num_workers_(std::thread::hardware_concurrency() / 2),
|
||||
port_(50052),
|
||||
shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
|
||||
|
|
|
@ -201,6 +201,22 @@ class CacheServer : public Service {
|
|||
/// \brief Return the memory cap ratio
|
||||
float GetMemoryCapRatio() const { return memory_cap_ratio_; }
|
||||
|
||||
/// \brief Function to handle a row request
|
||||
/// \param[in] cache_req A row request to handle
|
||||
/// \param[out] internal_request Indicator if the request is an internal request
|
||||
/// \return Status object
|
||||
Status ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request);
|
||||
|
||||
/// \brief Function to handle an admin request
|
||||
/// \param[in] cache_req An admin request to handle
|
||||
/// \return Status object
|
||||
Status ProcessAdminRequest(CacheServerRequest *cache_req);
|
||||
|
||||
/// \brief Function to handle a session request
|
||||
/// \param[in] cache_req A session request to handle
|
||||
/// \return Status object
|
||||
Status ProcessSessionRequest(CacheServerRequest *cache_req);
|
||||
|
||||
/// \brief How a request is handled.
|
||||
/// \note that it can be process immediately by a grpc thread or routed to a server thread
|
||||
/// which is pinned to some numa node core.
|
||||
|
@ -256,6 +272,12 @@ class CacheServer : public Service {
|
|||
/// \return Pointer to cache service. Null if not found
|
||||
CacheService *GetService(connection_id_type id) const;
|
||||
|
||||
/// \brief Going over existing cache service and calculate how much we have consumed so far, a new cache service
|
||||
/// can only be created if there is still enough avail memory left
|
||||
/// \param cache_mem_sz Requested memory for a new cache service
|
||||
/// \return Status object
|
||||
Status GlobalMemoryCheck(uint64_t cache_mem_sz);
|
||||
|
||||
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
|
||||
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
|
||||
/// a special unique cookie.
|
||||
|
@ -314,6 +336,12 @@ class CacheServer : public Service {
|
|||
/// \return Status object
|
||||
Status GetStat(CacheRequest *rq, CacheReply *reply);
|
||||
|
||||
/// \brief Internal function to get cache state
|
||||
/// \param rq
|
||||
/// \param reply
|
||||
/// \return Status object
|
||||
Status GetCacheState(CacheRequest *rq, CacheReply *reply);
|
||||
|
||||
/// \brief Cache a schema request
|
||||
/// \param rq
|
||||
/// \return Status object
|
||||
|
@ -411,6 +439,9 @@ class CacheServer : public Service {
|
|||
/// \return Status object
|
||||
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
|
||||
Status BatchCacheRows(CacheRequest *rq);
|
||||
|
||||
Status InternalFetchRow(CacheRequest *rq);
|
||||
Status InternalCacheRow(CacheRequest *rq, CacheReply *reply);
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
# ============================================================================
|
||||
|
||||
# source the globals and functions for use with cache testing
|
||||
SKIP_ADMIN_COUNTER=false
|
||||
export SKIP_ADMIN_COUNTER=false
|
||||
declare failed_tests
|
||||
. cachetest_lib.sh
|
||||
echo
|
||||
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
# ============================================================================
|
||||
|
||||
# source the globals and functions for use with cache testing
|
||||
SKIP_ADMIN_COUNTER=true
|
||||
export SKIP_ADMIN_COUNTER=true
|
||||
declare session_id failed_tests
|
||||
. cachetest_lib.sh
|
||||
echo
|
||||
|
||||
|
@ -28,8 +29,10 @@ UT_TEST_DIR="${BUILD_PATH}/mindspore/tests/ut/cpp"
|
|||
DateStamp=$(date +%Y%m%d_%H%M%S);
|
||||
CPP_TEST_LOG_OUTPUT="/tmp/ut_tests_cache_${DateStamp}.log"
|
||||
|
||||
# Start a basic cache server to be used for all tests
|
||||
StartServer
|
||||
# start cache server with a spilling path to be used for all tests
|
||||
cmd="${CACHE_ADMIN} --start -s /tmp"
|
||||
CacheAdminCmd "${cmd}" 0
|
||||
sleep 1
|
||||
HandleRcExit $? 1 1
|
||||
|
||||
# Set the environment variable to enable these pytests
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
# ============================================================================
|
||||
|
||||
# source the globals and functions for use with cache testing
|
||||
SKIP_ADMIN_COUNTER=true
|
||||
export SKIP_ADMIN_COUNTER=true
|
||||
declare session_id failed_tests
|
||||
. cachetest_lib.sh
|
||||
echo
|
||||
|
||||
|
@ -84,10 +85,6 @@ export SESSION_ID=$session_id
|
|||
PytestCmd "test_cache_map.py" "test_cache_map_running_twice2"
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# Set size parameter of DatasetCache to a extra small value
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_no_image"
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
|
@ -255,15 +252,6 @@ export SESSION_ID=$session_id
|
|||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice2"
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
# Set size parameter of DatasetCache to a extra small value
|
||||
GetSession
|
||||
HandleRcExit $? 1 1
|
||||
export SESSION_ID=$session_id
|
||||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1
|
||||
HandleRcExit $? 0 0
|
||||
DestroySession $session_id
|
||||
HandleRcExit $? 1 1
|
||||
|
||||
# Run two parallel pipelines (sharing cache)
|
||||
for i in $(seq 1 2)
|
||||
do
|
||||
|
@ -366,7 +354,7 @@ HandleRcExit $? 1 1
|
|||
export SESSION_ID=$session_id
|
||||
|
||||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_session_destroy" &
|
||||
pid=("$!")
|
||||
pid=$!
|
||||
|
||||
sleep 10
|
||||
DestroySession $session_id
|
||||
|
@ -381,7 +369,7 @@ HandleRcExit $? 1 1
|
|||
export SESSION_ID=$session_id
|
||||
|
||||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_server_stop" &
|
||||
pid=("$!")
|
||||
pid=$!
|
||||
|
||||
sleep 10
|
||||
StopServer
|
||||
|
@ -417,6 +405,26 @@ HandleRcExit $? 0 0
|
|||
StopServer
|
||||
HandleRcExit $? 0 1
|
||||
|
||||
# start cache server with a spilling path
|
||||
cmd="${CACHE_ADMIN} --start -s /tmp"
|
||||
CacheAdminCmd "${cmd}" 0
|
||||
sleep 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
GetSession
|
||||
HandleRcExit $? 1 1
|
||||
export SESSION_ID=$session_id
|
||||
|
||||
# Set size parameter of mappable DatasetCache to a extra small value
|
||||
PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1
|
||||
HandleRcExit $? 0 0
|
||||
# Set size parameter of non-mappable DatasetCache to a extra small value
|
||||
PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1
|
||||
HandleRcExit $? 0 0
|
||||
|
||||
StopServer
|
||||
HandleRcExit $? 0 1
|
||||
|
||||
unset RUN_CACHE_TEST
|
||||
unset SESSION_ID
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_cache_map_basic1():
|
|||
else:
|
||||
session_id = 1
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -91,7 +91,7 @@ def test_cache_map_basic2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -115,7 +115,7 @@ def test_cache_map_basic3():
|
|||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -155,7 +155,7 @@ def test_cache_map_basic4():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -189,7 +189,7 @@ def test_cache_map_basic5():
|
|||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -225,7 +225,7 @@ def test_cache_map_failure1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -269,7 +269,7 @@ def test_cache_map_failure2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -310,7 +310,7 @@ def test_cache_map_failure3():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -351,7 +351,7 @@ def test_cache_map_failure4():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -391,7 +391,7 @@ def test_cache_map_failure5():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
data = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -432,7 +432,7 @@ def test_cache_map_failure6():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
|
||||
num_readers = 1
|
||||
|
@ -478,7 +478,7 @@ def test_cache_map_failure7():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
data = ds.GeneratorDataset(generator_1d, ["data"])
|
||||
data = data.map((lambda x: x), ["data"], cache=some_cache)
|
||||
|
@ -514,7 +514,7 @@ def test_cache_map_failure8():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -554,7 +554,7 @@ def test_cache_map_failure9():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -596,7 +596,7 @@ def test_cache_map_failure10():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -616,6 +616,37 @@ def test_cache_map_failure10():
|
|||
logger.info('test_cache_failure10 Ended.\n')
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_failure11():
|
||||
"""
|
||||
Test set spilling=true when cache server is started without spilling support (failure)
|
||||
|
||||
Cache(spilling=true)
|
||||
|
|
||||
ImageFolder
|
||||
|
||||
"""
|
||||
logger.info("Test cache failure 11")
|
||||
if "SESSION_ID" in os.environ:
|
||||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
num_iter = 0
|
||||
for _ in ds1.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert "Unexpected error. Server is not set up with spill support" in str(e.value)
|
||||
|
||||
assert num_iter == 0
|
||||
logger.info('test_cache_failure11 Ended.\n')
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_split1():
|
||||
"""
|
||||
|
@ -641,7 +672,7 @@ def test_cache_map_split1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -692,7 +723,7 @@ def test_cache_map_split2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 9 records
|
||||
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
@ -725,27 +756,27 @@ def test_cache_map_parameter_check():
|
|||
logger.info("Test cache map parameter check")
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
ds.DatasetCache(session_id=-1, size=0, spilling=True)
|
||||
ds.DatasetCache(session_id=-1, size=0)
|
||||
assert "Input is not within the required interval" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id="1", size=0, spilling=True)
|
||||
ds.DatasetCache(session_id="1", size=0)
|
||||
assert "Argument session_id with value 1 is not of type (<class 'int'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id=None, size=0, spilling=True)
|
||||
ds.DatasetCache(session_id=None, size=0)
|
||||
assert "Argument session_id with value None is not of type (<class 'int'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
ds.DatasetCache(session_id=1, size=-1, spilling=True)
|
||||
ds.DatasetCache(session_id=1, size=-1)
|
||||
assert "Input size must be greater than 0" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id=1, size="1", spilling=True)
|
||||
ds.DatasetCache(session_id=1, size="1")
|
||||
assert "Argument size with value 1 is not of type (<class 'int'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id=1, size=None, spilling=True)
|
||||
ds.DatasetCache(session_id=1, size=None)
|
||||
assert "Argument size with value None is not of type (<class 'int'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
|
@ -753,31 +784,31 @@ def test_cache_map_parameter_check():
|
|||
assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052)
|
||||
ds.DatasetCache(session_id=1, size=0, hostname=50052)
|
||||
assert "Argument hostname with value 50052 is not of type (<class 'str'>,)" in str(err.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
|
||||
ds.DatasetCache(session_id=1, size=0, hostname="illegal")
|
||||
assert "now cache client has to be on the same host with cache server" in str(err.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2")
|
||||
ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2")
|
||||
assert "now cache client has to be on the same host with cache server" in str(err.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
|
||||
ds.DatasetCache(session_id=1, size=0, port="illegal")
|
||||
assert "Argument port with value illegal is not of type (<class 'int'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052")
|
||||
ds.DatasetCache(session_id=1, size=0, port="50052")
|
||||
assert "Argument port with value 50052 is not of type (<class 'int'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, port=0)
|
||||
ds.DatasetCache(session_id=1, size=0, port=0)
|
||||
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)
|
||||
|
||||
with pytest.raises(ValueError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536)
|
||||
ds.DatasetCache(session_id=1, size=0, port=65536)
|
||||
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
|
@ -807,7 +838,7 @@ def test_cache_map_running_twice1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -850,7 +881,7 @@ def test_cache_map_running_twice2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -998,7 +1029,7 @@ def test_cache_map_parallel_pipeline1(shard):
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache)
|
||||
|
@ -1035,7 +1066,7 @@ def test_cache_map_parallel_pipeline2(shard):
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard))
|
||||
|
@ -1072,7 +1103,7 @@ def test_cache_map_parallel_workers():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4)
|
||||
|
@ -1109,7 +1140,7 @@ def test_cache_map_server_workers_1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -1146,7 +1177,7 @@ def test_cache_map_server_workers_100():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -1183,7 +1214,7 @@ def test_cache_map_num_connections_1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -1220,7 +1251,7 @@ def test_cache_map_num_connections_100():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -1257,7 +1288,7 @@ def test_cache_map_prefetch_size_1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -1294,7 +1325,7 @@ def test_cache_map_prefetch_size_100():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -1335,7 +1366,7 @@ def test_cache_map_to_device():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -1366,7 +1397,7 @@ def test_cache_map_epoch_ctrl1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -1406,7 +1437,7 @@ def test_cache_map_epoch_ctrl2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
|
||||
|
@ -1452,7 +1483,7 @@ def test_cache_map_epoch_ctrl3():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
@ -1495,7 +1526,7 @@ def test_cache_map_coco1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 6 records
|
||||
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
|
||||
|
@ -1531,7 +1562,7 @@ def test_cache_map_coco2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 6 records
|
||||
ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
|
||||
|
@ -1566,7 +1597,7 @@ def test_cache_map_mnist1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
|
@ -1599,7 +1630,7 @@ def test_cache_map_mnist2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
|
||||
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
|
@ -1633,7 +1664,7 @@ def test_cache_map_celeba1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 4 records
|
||||
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)
|
||||
|
@ -1668,7 +1699,7 @@ def test_cache_map_celeba2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 4 records
|
||||
ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
|
||||
|
@ -1703,7 +1734,7 @@ def test_cache_map_manifest1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 4 records
|
||||
ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
|
||||
|
@ -1738,7 +1769,7 @@ def test_cache_map_manifest2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 4 records
|
||||
ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
|
||||
|
@ -1773,7 +1804,7 @@ def test_cache_map_cifar1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
|
||||
|
||||
num_epoch = 4
|
||||
|
@ -1806,7 +1837,7 @@ def test_cache_map_cifar2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
|
||||
resize_op = c_vision.Resize((224, 224))
|
||||
|
@ -1841,7 +1872,7 @@ def test_cache_map_cifar3():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=1)
|
||||
|
||||
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
|
||||
|
||||
|
@ -1875,7 +1906,7 @@ def test_cache_map_cifar4():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
|
||||
ds1 = ds1.shuffle(10)
|
||||
|
||||
|
@ -1907,7 +1938,7 @@ def test_cache_map_voc1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 9 records
|
||||
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache)
|
||||
|
@ -1942,7 +1973,7 @@ def test_cache_map_voc2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 9 records
|
||||
ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
|
||||
|
@ -1987,7 +2018,7 @@ def test_cache_map_python_sampler1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache)
|
||||
|
@ -2023,7 +2054,7 @@ def test_cache_map_python_sampler2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
|
||||
|
@ -2061,7 +2092,7 @@ def test_cache_map_nested_repeat():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This DATA_DIR only has 2 images in it
|
||||
ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
|
||||
|
|
|
@ -62,7 +62,7 @@ def test_cache_nomap_basic1():
|
|||
schema.add_column('label', de_type=mstype.uint8, shape=[1])
|
||||
|
||||
# create a cache. arbitrary session_id for now
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# User-created sampler here
|
||||
ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
|
||||
|
@ -96,7 +96,7 @@ def test_cache_nomap_basic2():
|
|||
schema.add_column('label', de_type=mstype.uint8, shape=[1])
|
||||
|
||||
# create a cache. arbitrary session_id for now
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# sampler arg not given directly, however any of these args will auto-generate an appropriate sampler:
|
||||
# num_samples, shuffle, num_shards, shard_id
|
||||
|
@ -134,7 +134,7 @@ def test_cache_nomap_basic3():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
|
||||
decode_op = c_vision.Decode()
|
||||
ds1 = ds1.map(operations=decode_op, input_columns=["image"])
|
||||
|
@ -183,7 +183,7 @@ def test_cache_nomap_basic4():
|
|||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
# With shuffle not being set, TF defaults to a "global" shuffle when there is no cache
|
||||
# in the picture. This causes a shuffle-injection over the TF. For clarify, this test will
|
||||
# explicitly give the global option, even though it's the default in python.
|
||||
|
@ -231,7 +231,7 @@ def test_cache_nomap_basic5():
|
|||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
|
||||
decode_op = c_vision.Decode()
|
||||
ds1 = ds1.map(operations=decode_op, input_columns=["image"])
|
||||
|
@ -270,7 +270,7 @@ def test_cache_nomap_basic6():
|
|||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# With only 3 records shard into 3, we expect only 1 record returned for this shard
|
||||
# However, the sharding will be done by the sampler, not by the tf record leaf node
|
||||
|
@ -313,7 +313,7 @@ def test_cache_nomap_basic7():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
|
||||
|
@ -344,7 +344,7 @@ def test_cache_nomap_basic8():
|
|||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
|
||||
|
@ -371,7 +371,7 @@ def test_cache_nomap_basic9():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# Contact the server to get the statistics, this should fail because we have not used this cache in any pipeline
|
||||
# so there will not be any cache to get stats on.
|
||||
|
@ -404,7 +404,7 @@ def test_cache_nomap_allowed_share1():
|
|||
|
||||
ds.config.set_seed(1)
|
||||
# This dataset has 3 records in it only
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=32)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=32)
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
|
||||
ds1 = ds1.repeat(4)
|
||||
|
||||
|
@ -446,7 +446,7 @@ def test_cache_nomap_allowed_share2():
|
|||
|
||||
ds.config.set_seed(1)
|
||||
# This dataset has 3 records in it only
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
decode_op = c_vision.Decode()
|
||||
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
|
@ -488,7 +488,7 @@ def test_cache_nomap_allowed_share3():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
|
||||
ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
|
||||
|
@ -529,7 +529,7 @@ def test_cache_nomap_allowed_share4():
|
|||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
decode_op = c_vision.Decode()
|
||||
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
|
@ -572,7 +572,7 @@ def test_cache_nomap_disallowed_share1():
|
|||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
decode_op = c_vision.Decode()
|
||||
rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)
|
||||
|
||||
|
@ -615,7 +615,7 @@ def test_cache_nomap_running_twice1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -658,7 +658,7 @@ def test_cache_nomap_running_twice2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
|
||||
|
@ -763,7 +763,7 @@ def test_cache_nomap_parallel_pipeline1(shard):
|
|||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache)
|
||||
|
@ -799,7 +799,7 @@ def test_cache_nomap_parallel_pipeline2(shard):
|
|||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard))
|
||||
|
@ -835,7 +835,7 @@ def test_cache_nomap_parallel_workers():
|
|||
session_id = int(os.environ['SESSION_ID'])
|
||||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4)
|
||||
|
@ -872,7 +872,7 @@ def test_cache_nomap_server_workers_1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -909,7 +909,7 @@ def test_cache_nomap_server_workers_100():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
|
||||
|
@ -946,7 +946,7 @@ def test_cache_nomap_num_connections_1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -983,7 +983,7 @@ def test_cache_nomap_num_connections_100():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
|
||||
|
@ -1020,7 +1020,7 @@ def test_cache_nomap_prefetch_size_1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -1057,7 +1057,7 @@ def test_cache_nomap_prefetch_size_100():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
|
||||
|
@ -1098,7 +1098,7 @@ def test_cache_nomap_to_device():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -1134,7 +1134,7 @@ def test_cache_nomap_session_destroy():
|
|||
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
|
||||
schema.add_column('label', de_type=mstype.uint8, shape=[1])
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# User-created sampler here
|
||||
ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
|
||||
|
@ -1172,7 +1172,7 @@ def test_cache_nomap_server_stop():
|
|||
shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image)
|
||||
schema.add_column('label', de_type=mstype.uint8, shape=[1])
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# User-created sampler here
|
||||
ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
|
||||
|
@ -1206,7 +1206,7 @@ def test_cache_nomap_epoch_ctrl1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
|
||||
|
@ -1246,7 +1246,7 @@ def test_cache_nomap_epoch_ctrl2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -1292,7 +1292,7 @@ def test_cache_nomap_epoch_ctrl3():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
|
||||
|
@ -1339,7 +1339,7 @@ def test_cache_nomap_epoch_ctrl4():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -1381,8 +1381,8 @@ def test_cache_nomap_multiple_cache1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
train_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
eval_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 12 records in it
|
||||
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
|
||||
|
@ -1425,8 +1425,8 @@ def test_cache_nomap_multiple_cache2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
text_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
image_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
text_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -1470,8 +1470,8 @@ def test_cache_nomap_multiple_cache3():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
tf_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
tf_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
image_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -1515,7 +1515,7 @@ def test_cache_nomap_multiple_cache_train():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
train_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 12 records in it
|
||||
train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
|
||||
|
@ -1553,7 +1553,7 @@ def test_cache_nomap_multiple_cache_eval():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
eval_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset only has 3 records in it
|
||||
eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -1591,7 +1591,7 @@ def test_cache_nomap_clue1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# With only 3 records shard into 3, we expect only 1 record returned for this shard
|
||||
# However, the sharding will be done by the sampler, not by the clue leaf node
|
||||
|
@ -1630,7 +1630,7 @@ def test_cache_nomap_clue2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
|
||||
ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)
|
||||
|
@ -1666,7 +1666,7 @@ def test_cache_nomap_csv1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# With only 3 records shard into 3, we expect only 1 record returned for this shard
|
||||
# However, the sharding will be done by the sampler, not by the clue leaf node
|
||||
|
@ -1706,7 +1706,7 @@ def test_cache_nomap_csv2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
|
||||
column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
|
||||
|
@ -1743,7 +1743,7 @@ def test_cache_nomap_textfile1():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# With only 3 records shard into 3, we expect only 1 record returned for this shard
|
||||
# However, the sharding will be done by the sampler, not by the clue leaf node
|
||||
|
@ -1788,7 +1788,7 @@ def test_cache_nomap_textfile2():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
|
||||
tokenizer = text.PythonTokenizer(my_tokenizer)
|
||||
|
@ -1828,7 +1828,7 @@ def test_cache_nomap_nested_repeat():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
|
@ -1867,7 +1867,7 @@ def test_cache_nomap_get_repeat_count():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
|
||||
# This dataset has 3 records in it only
|
||||
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
|
@ -1902,7 +1902,7 @@ def test_cache_nomap_long_file_list():
|
|||
else:
|
||||
raise RuntimeError("Testcase requires SESSION_ID environment variable")
|
||||
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
|
||||
some_cache = ds.DatasetCache(session_id=session_id, size=1)
|
||||
|
||||
ds1 = ds.TFRecordDataset([DATA_DIR[0] for _ in range(0, 1000)], SCHEMA_DIR, columns_list=["image"],
|
||||
cache=some_cache)
|
||||
|
|
Loading…
Reference in New Issue