forked from mindspore-Ecosystem/mindspore
!17657 Fix a bug in cache multi-session destroy and fix random check for python functions
From: @lixiachen
Reviewed-by: @john_tzanakakis, @robingrosman
Signed-off-by: @robingrosman
commit a0822767a1
@@ -125,13 +125,12 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::vector<uint32_t>
   uint32_t value_as_uint;
   while (arg_stream->rdbuf()->in_avail() != 0) {
+    std::stringstream::pos_type pos = arg_stream->tellg();
     *arg_stream >> value_as_uint;
     if (arg_stream->fail()) {
       arg_stream->clear();
-      std::string value_as_string;
-      *arg_stream >> value_as_string;
-      std::string err_msg = "Invalid numeric value: " + value_as_string;
-      return Status(StatusCode::kMDSyntaxError, err_msg);
+      arg_stream->seekg(pos, arg_stream->beg);
+      break;
     } else {
       out_arg->push_back(value_as_uint);
     }
   }
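Note on the hunk above: the removed branch aborted the whole command at the first non-numeric token, which appears to be the multi-session destroy bug named in the commit title; the new clear/seekg/break sequence rewinds the stream so the next handler can consume the remaining tokens. A minimal standalone sketch of that parse-and-rewind pattern (the input string and the main() wrapper are invented for illustration; this is not MindSpore code):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  // Hypothetical command tail: session ids followed by another option.
  std::stringstream arg_stream("1 2 3 --port 50052");
  std::vector<uint32_t> ids;
  uint32_t value_as_uint;
  while (arg_stream.rdbuf()->in_avail() != 0) {
    std::stringstream::pos_type pos = arg_stream.tellg();  // remember where this token starts
    arg_stream >> value_as_uint;
    if (arg_stream.fail()) {
      arg_stream.clear();                     // drop the fail bit ...
      arg_stream.seekg(pos, arg_stream.beg);  // ... and rewind to the unparsed token
      break;                                  // leave "--port 50052" for the next handler
    }
    ids.push_back(value_as_uint);
  }
  std::string rest;
  std::getline(arg_stream, rest);
  std::cout << "parsed " << ids.size() << " ids; remainder:" << rest << "\n";
  return 0;
}

Running it prints "parsed 3 ids; remainder: --port 50052".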
@@ -492,18 +491,18 @@ Status CacheAdminArgHandler::ShowServerInfo() {
   if (spill_dir.empty()) spill_dir = "None";

   int name_w = 20;
-  int value_w = 15;
+  int value_w = 50;
   std::cout << "Cache Server Configuration: " << std::endl;
-  std::cout << "----------------------------------------" << std::endl;
+  std::cout << std::string(name_w + value_w, '-') << std::endl;
   std::cout << std::setw(name_w) << "config name" << std::setw(value_w) << "value" << std::endl;
-  std::cout << "----------------------------------------" << std::endl;
+  std::cout << std::string(name_w + value_w, '-') << std::endl;
   std::cout << std::setw(name_w) << "hostname" << std::setw(value_w) << hostname_ << std::endl;
   std::cout << std::setw(name_w) << "port" << std::setw(value_w) << port_ << std::endl;
   std::cout << std::setw(name_w) << "number of workers" << std::setw(value_w) << std::to_string(num_workers)
             << std::endl;
   std::cout << std::setw(name_w) << "log level" << std::setw(value_w) << std::to_string(log_level) << std::endl;
   std::cout << std::setw(name_w) << "spill dir" << std::setw(value_w) << spill_dir << std::endl;
-  std::cout << "----------------------------------------" << std::endl;
+  std::cout << std::string(name_w + value_w, '-') << std::endl;

   std::cout << "Active sessions: " << std::endl;
   if (!session_ids.empty()) {
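This table change does two things: it widens the value column from 15 to 50 so long values such as spill paths fit, and it derives the ruler from name_w + value_w so the separator stays aligned if either width changes again. A self-contained sketch of the layout (the spill path is a made-up example value):

#include <iomanip>
#include <iostream>
#include <string>

int main() {
  const int name_w = 20;   // width of the "config name" column
  const int value_w = 50;  // widened from 15 so long values fit
  const std::string ruler(name_w + value_w, '-');  // separator tracks the total width

  std::cout << ruler << '\n'
            << std::setw(name_w) << "config name" << std::setw(value_w) << "value" << '\n'
            << ruler << '\n'
            << std::setw(name_w) << "spill dir" << std::setw(value_w) << "/tmp/mindspore/cache" << '\n'
            << ruler << '\n';
  return 0;
}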
@@ -27,36 +27,39 @@
 namespace ms = mindspore;
 namespace ds = mindspore::dataset;

+namespace {
+const int32_t kTotalArgs = 8;
+enum ArgIndex : uint8_t {
+  kProcessName = 0,
+  kRootDir = 1,
+  kNumWorkers = 2,
+  kPort = 3,
+  kSharedMemorySize = 4,
+  kLogLevel = 5,
+  kDemonize = 6,
+  kMemoryCapRatio = 7
+};
+}  // namespace
+
 /// Start the server
 /// \param argv
 /// \return Status object
 ms::Status StartServer(int argc, char **argv) {
   ms::Status rc;
   ds::CacheServer::Builder builder;
-  const int32_t kTotalArgs = 8;
-  enum {
-    kProcessNameIdx = 0,
-    kRootDirArgIdx = 1,
-    kNumWorkersArgIdx = 2,
-    kPortArgIdx = 3,
-    kSharedMemorySizeArgIdx = 4,
-    kLogLevelArgIdx = 5,
-    kDemonizeArgIdx = 6,
-    kMemoryCapRatioArgIdx = 7
-  };
   if (argc != kTotalArgs) {
     return ms::Status(ms::StatusCode::kMDSyntaxError);
   }

-  int32_t port = static_cast<int32_t>(strtol(argv[kPortArgIdx], nullptr, ds::kDecimal));
-  builder.SetRootDirectory(argv[kRootDirArgIdx])
-    .SetNumWorkers(static_cast<int32_t>(strtol(argv[kNumWorkersArgIdx], nullptr, ds::kDecimal)))
+  int32_t port = static_cast<int32_t>(strtol(argv[ArgIndex::kPort], nullptr, ds::kDecimal));
+  builder.SetRootDirectory(argv[ArgIndex::kRootDir])
+    .SetNumWorkers(static_cast<int32_t>(strtol(argv[ArgIndex::kNumWorkers], nullptr, ds::kDecimal)))
     .SetPort(port)
-    .SetSharedMemorySizeInGB(static_cast<int32_t>(strtol(argv[kSharedMemorySizeArgIdx], nullptr, ds::kDecimal)))
-    .SetLogLevel(static_cast<int8_t>((strtol(argv[kLogLevelArgIdx], nullptr, ds::kDecimal))))
-    .SetMemoryCapRatio(strtof(argv[kMemoryCapRatioArgIdx], nullptr));
+    .SetSharedMemorySizeInGB(static_cast<int32_t>(strtol(argv[ArgIndex::kSharedMemorySize], nullptr, ds::kDecimal)))
+    .SetLogLevel(static_cast<int8_t>((strtol(argv[ArgIndex::kLogLevel], nullptr, ds::kDecimal))))
+    .SetMemoryCapRatio(strtof(argv[ArgIndex::kMemoryCapRatio], nullptr));

-  auto daemonize_string = argv[kDemonizeArgIdx];
+  auto daemonize_string = argv[ArgIndex::kDemonize];
   bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 ||
                    strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0;
@@ -81,8 +84,8 @@ ms::Status StartServer(int argc, char **argv) {
     return rc;
   }
   ms::g_ms_submodule_log_levels[SUBMODULE_ID] =
-    static_cast<int>(strtol(argv[kLogLevelArgIdx], nullptr, ds::kDecimal));
-  google::InitGoogleLogging(argv[kProcessNameIdx]);
+    static_cast<int>(strtol(argv[ArgIndex::kLogLevel], nullptr, ds::kDecimal));
+  google::InitGoogleLogging(argv[ArgIndex::kProcessName]);
 #undef google
 #endif
   rc = msg.Create();
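Two refactors land in the two hunks above: the argument-count constant and the index enum move out of StartServer into an anonymous namespace, giving them internal linkage so the logging setup later in the file can reuse the same indices, and the enum gains a name and a fixed uint8_t underlying type. A trimmed sketch of the pattern, using an invented three-argument layout rather than the server's real eight:

#include <cstdint>
#include <cstdlib>
#include <iostream>

namespace {
// Internal linkage: every function in this translation unit can share one
// index scheme instead of re-declaring it locally.
const int32_t kTotalArgs = 3;
enum ArgIndex : uint8_t { kProcessName = 0, kPort = 1, kLogLevel = 2 };
const int kDecimal = 10;  // stand-in for ds::kDecimal
}  // namespace

int main(int argc, char **argv) {
  if (argc != kTotalArgs) {
    std::cerr << "usage: " << argv[ArgIndex::kProcessName] << " <port> <log_level>\n";
    return 1;
  }
  auto port = static_cast<int32_t>(strtol(argv[ArgIndex::kPort], nullptr, kDecimal));
  auto log_level = static_cast<int8_t>(strtol(argv[ArgIndex::kLogLevel], nullptr, kDecimal));
  std::cout << "port=" << port << " log_level=" << static_cast<int>(log_level) << "\n";
  return 0;
}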
@@ -334,14 +334,14 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
   constexpr int32_t kMinBufDataSize = 3;
   CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
   // First one is cookie, followed by data address and then size.
-  enum { kCookieIdx = 0, kAddrIdx = 1, kSizeIdx = 2 };
+  enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
   // First piece of data is the cookie and is required
-  auto &cookie = rq->buf_data(kCookieIdx);
+  auto &cookie = rq->buf_data(BufDataIndex::kCookie);
   // Second piece of data is the address where we can find the serialized data
-  auto addr = strtoll(rq->buf_data(kAddrIdx).data(), nullptr, kDecimal);
+  auto addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
   auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
   // Third piece of data is the size of the serialized data that we need to transfer
-  auto sz = strtoll(rq->buf_data(kSizeIdx).data(), nullptr, kDecimal);
+  auto sz = strtoll(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal);
   // Successful or not, we need to free the memory on exit.
   Status rc;
   if (cs == nullptr) {
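As in cache_main.cc, the unnamed index enum becomes a named one, and BufDataIndex now documents the wire layout of buf_data: the cookie first, then a shared-memory offset and a payload size, both serialized as decimal strings. A self-contained sketch of that addressing scheme, with a std::vector<std::string> and a local buffer standing in for the protobuf repeated field and the real shared-memory segment:

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

const int kDecimal = 10;
enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };

int main() {
  // Pretend this buffer is the shared-memory segment; requests carry offsets
  // into it rather than raw pointers, encoded as decimal strings.
  char shm[64] = "hello row payload";
  std::vector<std::string> buf_data = {"session-cookie", "6", "3"};

  auto addr = strtoll(buf_data[BufDataIndex::kAddr].data(), nullptr, kDecimal);
  auto sz = strtoll(buf_data[BufDataIndex::kSize].data(), nullptr, kDecimal);
  auto *p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(shm) + addr);
  std::cout << "cookie=" << buf_data[BufDataIndex::kCookie]
            << " payload=" << std::string(p, sz) << "\n";  // prints payload=row
  return 0;
}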
@@ -411,12 +411,12 @@ Status CacheServer::InternalFetchRow(CacheRequest *rq) {
     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
   }
   // First piece is a flatbuffer containing row fetch information, second piece is the address of the BatchWait ptr
-  enum { kFetchRowMsgIdx = 0, kBatchWaitIdx = 1 };
-  rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(kFetchRowMsgIdx).data()));
+  enum BufDataIndex : uint8_t { kFetchRowMsg = 0, kBatchWait = 1 };
+  rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(BufDataIndex::kFetchRowMsg).data()));
   // This is an internal request and is not tied to rpc. But need to post because there
   // is a thread waiting on the completion of this request.
   try {
-    int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal);
+    int64_t addr = strtol(rq->buf_data(BufDataIndex::kBatchWait).data(), nullptr, kDecimal);
     auto *bw = reinterpret_cast<BatchWait *>(addr);
     // Check if the object is still around.
     auto bwObj = bw->GetBatchWait();
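Here the second buf_data piece carries the address of a BatchWait object, round-tripped through a decimal string because the request crosses a queue rather than a call stack, and GetBatchWait() is consulted to confirm the object still exists before posting to it. A rough stand-in for that liveness check (the real BatchWait API is not shown in this diff; the enable_shared_from_this/weak_from_this mechanism is an assumption made for this sketch):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>

const int kDecimal = 10;

// Stand-in for BatchWait: hands out a weak_ptr so a peer holding only a raw
// address can test whether the object is still alive before using it.
class BatchWait : public std::enable_shared_from_this<BatchWait> {
 public:
  std::weak_ptr<BatchWait> GetBatchWait() { return weak_from_this(); }
  void Post() { std::cout << "posted\n"; }
};

int main() {
  auto owner = std::make_shared<BatchWait>();
  // Sender side: encode the object's address as a decimal string.
  std::string piece = std::to_string(reinterpret_cast<int64_t>(owner.get()));
  // Receiver side: decode the address and check liveness before posting.
  int64_t addr = strtol(piece.data(), nullptr, kDecimal);
  auto *bw = reinterpret_cast<BatchWait *>(addr);
  if (auto alive = bw->GetBatchWait().lock()) {
    alive->Post();
  }
  return 0;
}

(Compile as C++17 for weak_from_this.)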
@@ -755,19 +755,19 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {

 Status CacheServer::BatchCacheRows(CacheRequest *rq) {
   // First one is cookie, followed by address and then size.
-  enum { kCookieIdx = 0, kAddrIdx = 1, kSizeIdx = 2 };
+  enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
   constexpr int32_t kExpectedBufDataSize = 3;
   CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == kExpectedBufDataSize, "Expect three pieces of data");
   try {
-    auto &cookie = rq->buf_data(kCookieIdx);
+    auto &cookie = rq->buf_data(BufDataIndex::kCookie);
     auto connection_id = rq->connection_id();
     auto client_id = rq->client_id();
     int64_t offset_addr;
     int32_t num_elem;
     auto *base = SharedMemoryBaseAddr();
-    offset_addr = strtoll(rq->buf_data(kAddrIdx).data(), nullptr, kDecimal);
+    offset_addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
     auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
-    num_elem = static_cast<int32_t>(strtol(rq->buf_data(kSizeIdx).data(), nullptr, kDecimal));
+    num_elem = static_cast<int32_t>(strtol(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal));
     auto batch_wait = std::make_shared<BatchWait>(num_elem);
     // Get a set of free request and push into the queues.
     for (auto i = 0; i < num_elem; ++i) {
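BatchWait(num_elem) acts as a countdown latch: the batch fans out num_elem row requests to the worker queues, and each completion posts back once. A minimal stand-in built on standard primitives (not the server's actual class) to show those counting semantics:

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class BatchWait {
 public:
  explicit BatchWait(int n) : remaining_(n) {}
  void Post() {  // called once per completed row
    std::lock_guard<std::mutex> lk(mu_);
    if (--remaining_ == 0) cv_.notify_all();
  }
  void Wait() {  // blocks until every row has posted
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return remaining_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int remaining_;
};

int main() {
  const int num_elem = 4;
  BatchWait batch_wait(num_elem);
  std::vector<std::thread> workers;
  for (int i = 0; i < num_elem; ++i) {
    workers.emplace_back([&batch_wait] { batch_wait.Post(); });
  }
  batch_wait.Wait();
  for (auto &t : workers) t.join();
  std::cout << "all " << num_elem << " rows cached\n";
  return 0;
}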
@@ -103,7 +103,7 @@ class CacheServer : public Service {

   void Print(std::ostream &out) const {
     out << "Summary of the cache server configuration\n"
-        << "Spill directory: " << GetTop() << "\n"
+        << "Spill directory: " << (GetTop().empty() ? "None" : GetTop()) << "\n"
         << "Number of parallel workers: " << GetNumWorkers() << "\n"
         << "Tcp/ip port: " << GetPort() << "\n"
         << "Shared memory size (in GB): " << GetSharedMemorySzInGb() << "\n"
@@ -147,6 +147,11 @@ class FuncWrapper:
         if not callable(transform):
             raise ValueError("FuncWrapper only support warping callable python function.")
         self.transform = transform
+        try:
+            if hasattr(self.transform, "random") and not self.transform.random:
+                self.random = False
+        except Exception:
+            self.random = True

     def __call__(self, *args):
         result = None
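The new block is deliberately defensive: hasattr only swallows AttributeError, so a transform whose random attribute is a property that raises some other exception would otherwise escape the check; the except branch then falls back to treating the transform as random. A standalone sketch of the probing logic (the explicit self.random = True default is an assumption for this sketch; in the real class the default comes from the surrounding code):

class FuncWrapper:
    def __init__(self, transform):
        if not callable(transform):
            raise ValueError("FuncWrapper only supports wrapping callable python functions.")
        self.transform = transform
        self.random = True  # assumed default: treat as random unless proven otherwise
        try:
            if hasattr(self.transform, "random") and not self.transform.random:
                self.random = False
        except Exception:
            self.random = True

    def __call__(self, *args):
        return self.transform(*args)


class Deterministic:
    random = False

    def __call__(self, x):
        return x + 1


print(FuncWrapper(Deterministic()).random)  # False: the transform opted out of randomness
print(FuncWrapper(lambda x: x).random)      # True: no random attribute, safe default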