!17657 Fix a bug in cache multi-session destroy and fix random check for python functions

From: @lixiachen
Reviewed-by: @john_tzanakakis,@robingrosman
Signed-off-by: @robingrosman
This commit is contained in:
mindspore-ci-bot 2021-06-05 03:16:23 +08:00 committed by Gitee
commit a0822767a1
5 changed files with 47 additions and 40 deletions

View File

@ -125,13 +125,12 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::vector<uint32_t>
uint32_t value_as_uint;
while (arg_stream->rdbuf()->in_avail() != 0) {
std::stringstream::pos_type pos = arg_stream->tellg();
*arg_stream >> value_as_uint;
if (arg_stream->fail()) {
arg_stream->clear();
std::string value_as_string;
*arg_stream >> value_as_string;
std::string err_msg = "Invalid numeric value: " + value_as_string;
return Status(StatusCode::kMDSyntaxError, err_msg);
arg_stream->seekg(pos, arg_stream->beg);
break;
} else {
out_arg->push_back(value_as_uint);
}
@ -492,18 +491,18 @@ Status CacheAdminArgHandler::ShowServerInfo() {
if (spill_dir.empty()) spill_dir = "None";
int name_w = 20;
int value_w = 15;
int value_w = 50;
std::cout << "Cache Server Configuration: " << std::endl;
std::cout << "----------------------------------------" << std::endl;
std::cout << std::string(name_w + value_w, '-') << std::endl;
std::cout << std::setw(name_w) << "config name" << std::setw(value_w) << "value" << std::endl;
std::cout << "----------------------------------------" << std::endl;
std::cout << std::string(name_w + value_w, '-') << std::endl;
std::cout << std::setw(name_w) << "hostname" << std::setw(value_w) << hostname_ << std::endl;
std::cout << std::setw(name_w) << "port" << std::setw(value_w) << port_ << std::endl;
std::cout << std::setw(name_w) << "number of workers" << std::setw(value_w) << std::to_string(num_workers)
<< std::endl;
std::cout << std::setw(name_w) << "log level" << std::setw(value_w) << std::to_string(log_level) << std::endl;
std::cout << std::setw(name_w) << "spill dir" << std::setw(value_w) << spill_dir << std::endl;
std::cout << "----------------------------------------" << std::endl;
std::cout << std::string(name_w + value_w, '-') << std::endl;
std::cout << "Active sessions: " << std::endl;
if (!session_ids.empty()) {

View File

@ -27,36 +27,39 @@
namespace ms = mindspore;
namespace ds = mindspore::dataset;
namespace {
const int32_t kTotalArgs = 8;
enum ArgIndex : uint8_t {
kProcessName = 0,
kRootDir = 1,
kNumWorkers = 2,
kPort = 3,
kSharedMemorySize = 4,
kLogLevel = 5,
kDemonize = 6,
kMemoryCapRatio = 7
};
} // namespace
/// Start the server
/// \param argv
/// \return Status object
ms::Status StartServer(int argc, char **argv) {
ms::Status rc;
ds::CacheServer::Builder builder;
const int32_t kTotalArgs = 8;
enum {
kProcessNameIdx = 0,
kRootDirArgIdx = 1,
kNumWorkersArgIdx = 2,
kPortArgIdx = 3,
kSharedMemorySizeArgIdx = 4,
kLogLevelArgIdx = 5,
kDemonizeArgIdx = 6,
kMemoryCapRatioArgIdx = 7
};
if (argc != kTotalArgs) {
return ms::Status(ms::StatusCode::kMDSyntaxError);
}
int32_t port = static_cast<int32_t>(strtol(argv[kPortArgIdx], nullptr, ds::kDecimal));
builder.SetRootDirectory(argv[kRootDirArgIdx])
.SetNumWorkers(static_cast<int32_t>(strtol(argv[kNumWorkersArgIdx], nullptr, ds::kDecimal)))
int32_t port = static_cast<int32_t>(strtol(argv[ArgIndex::kPort], nullptr, ds::kDecimal));
builder.SetRootDirectory(argv[ArgIndex::kRootDir])
.SetNumWorkers(static_cast<int32_t>(strtol(argv[ArgIndex::kNumWorkers], nullptr, ds::kDecimal)))
.SetPort(port)
.SetSharedMemorySizeInGB(static_cast<int32_t>(strtol(argv[kSharedMemorySizeArgIdx], nullptr, ds::kDecimal)))
.SetLogLevel(static_cast<int8_t>((strtol(argv[kLogLevelArgIdx], nullptr, ds::kDecimal))))
.SetMemoryCapRatio(strtof(argv[kMemoryCapRatioArgIdx], nullptr));
.SetSharedMemorySizeInGB(static_cast<int32_t>(strtol(argv[ArgIndex::kSharedMemorySize], nullptr, ds::kDecimal)))
.SetLogLevel(static_cast<int8_t>((strtol(argv[ArgIndex::kLogLevel], nullptr, ds::kDecimal))))
.SetMemoryCapRatio(strtof(argv[ArgIndex::kMemoryCapRatio], nullptr));
auto daemonize_string = argv[kDemonizeArgIdx];
auto daemonize_string = argv[ArgIndex::kDemonize];
bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 ||
strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0;
@ -81,8 +84,8 @@ ms::Status StartServer(int argc, char **argv) {
return rc;
}
ms::g_ms_submodule_log_levels[SUBMODULE_ID] =
static_cast<int>(strtol(argv[kLogLevelArgIdx], nullptr, ds::kDecimal));
google::InitGoogleLogging(argv[kProcessNameIdx]);
static_cast<int>(strtol(argv[ArgIndex::kLogLevel], nullptr, ds::kDecimal));
google::InitGoogleLogging(argv[ArgIndex::kProcessName]);
#undef google
#endif
rc = msg.Create();

View File

@ -334,14 +334,14 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
constexpr int32_t kMinBufDataSize = 3;
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
// First one is cookie, followed by data address and then size.
enum { kCookieIdx = 0, kAddrIdx = 1, kSizeIdx = 2 };
enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
// First piece of data is the cookie and is required
auto &cookie = rq->buf_data(kCookieIdx);
auto &cookie = rq->buf_data(BufDataIndex::kCookie);
// Second piece of data is the address where we can find the serialized data
auto addr = strtoll(rq->buf_data(kAddrIdx).data(), nullptr, kDecimal);
auto addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
// Third piece of data is the size of the serialized data that we need to transfer
auto sz = strtoll(rq->buf_data(kSizeIdx).data(), nullptr, kDecimal);
auto sz = strtoll(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal);
// Successful or not, we need to free the memory on exit.
Status rc;
if (cs == nullptr) {
@ -411,12 +411,12 @@ Status CacheServer::InternalFetchRow(CacheRequest *rq) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
}
// First piece is a flatbuffer containing row fetch information, second piece is the address of the BatchWait ptr
enum { kFetchRowMsgIdx = 0, kBatchWaitIdx = 1 };
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(kFetchRowMsgIdx).data()));
enum BufDataIndex : uint8_t { kFetchRowMsg = 0, kBatchWait = 1 };
rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(BufDataIndex::kFetchRowMsg).data()));
// This is an internal request and is not tied to rpc. But need to post because there
// is a thread waiting on the completion of this request.
try {
int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal);
int64_t addr = strtol(rq->buf_data(BufDataIndex::kBatchWait).data(), nullptr, kDecimal);
auto *bw = reinterpret_cast<BatchWait *>(addr);
// Check if the object is still around.
auto bwObj = bw->GetBatchWait();
@ -755,19 +755,19 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
Status CacheServer::BatchCacheRows(CacheRequest *rq) {
// First one is cookie, followed by address and then size.
enum { kCookieIdx = 0, kAddrIdx = 1, kSizeIdx = 2 };
enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
constexpr int32_t kExpectedBufDataSize = 3;
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == kExpectedBufDataSize, "Expect three pieces of data");
try {
auto &cookie = rq->buf_data(kCookieIdx);
auto &cookie = rq->buf_data(BufDataIndex::kCookie);
auto connection_id = rq->connection_id();
auto client_id = rq->client_id();
int64_t offset_addr;
int32_t num_elem;
auto *base = SharedMemoryBaseAddr();
offset_addr = strtoll(rq->buf_data(kAddrIdx).data(), nullptr, kDecimal);
offset_addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
num_elem = static_cast<int32_t>(strtol(rq->buf_data(kSizeIdx).data(), nullptr, kDecimal));
num_elem = static_cast<int32_t>(strtol(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal));
auto batch_wait = std::make_shared<BatchWait>(num_elem);
// Get a set of free request and push into the queues.
for (auto i = 0; i < num_elem; ++i) {

View File

@ -103,7 +103,7 @@ class CacheServer : public Service {
void Print(std::ostream &out) const {
out << "Summary of the cache server configuration\n"
<< "Spill directory: " << GetTop() << "\n"
<< "Spill directory: " << (GetTop().empty() ? "None" : GetTop()) << "\n"
<< "Number of parallel workers: " << GetNumWorkers() << "\n"
<< "Tcp/ip port: " << GetPort() << "\n"
<< "Shared memory size (in GB): " << GetSharedMemorySzInGb() << "\n"

View File

@ -147,6 +147,11 @@ class FuncWrapper:
if not callable(transform):
raise ValueError("FuncWrapper only support warping callable python function.")
self.transform = transform
try:
if hasattr(self.transform, "random") and not self.transform.random:
self.random = False
except Exception:
self.random = True
def __call__(self, *args):
result = None