From 72a1df88726834fda7aeb2e5c5103b0b768a72a9 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Sat, 3 Jul 2021 15:46:14 +0800 Subject: [PATCH] fixed pclint --- .../core/communicator/http_message_handler.cc | 30 ++++++++++--------- .../core/communicator/http_message_handler.h | 22 +++++++------- .../ccsrc/ps/core/communicator/tcp_client.cc | 4 +-- mindspore/ccsrc/ps/core/node.cc | 4 +-- mindspore/ccsrc/ps/core/node_manager.cc | 23 +++++++------- mindspore/ccsrc/ps/core/node_recovery.cc | 1 - mindspore/ccsrc/ps/core/recovery_base.cc | 2 -- mindspore/ccsrc/ps/core/scheduler_node.cc | 23 ++++++++------ mindspore/ccsrc/ps/parameter_server.cc | 20 ++++++++----- .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 24 ++++++++------- mindspore/ccsrc/ps/scheduler.cc | 14 +++++++-- 11 files changed, 93 insertions(+), 74 deletions(-) diff --git a/mindspore/ccsrc/ps/core/communicator/http_message_handler.cc b/mindspore/ccsrc/ps/core/communicator/http_message_handler.cc index c47c4227e64..0346b9aa046 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_message_handler.cc +++ b/mindspore/ccsrc/ps/core/communicator/http_message_handler.cc @@ -45,7 +45,10 @@ void HttpMessageHandler::InitHttpMessage() { const char *query = evhttp_uri_get_query(event_uri_); if (query != nullptr) { MS_LOG(WARNING) << "The query is:" << query; - evhttp_parse_query_str(query, &path_params_); + int result = evhttp_parse_query_str(query, &path_params_); + if (result < 0) { + MS_LOG(ERROR) << "Http parse query:" << query << " failed."; + } } head_params_ = evhttp_request_get_input_headers(event_request_); @@ -58,14 +61,14 @@ void HttpMessageHandler::ParseUrl(const std::string &url) { MS_EXCEPTION_IF_NULL(event_uri_); } -std::string HttpMessageHandler::GetHeadParam(const std::string &key) { +std::string HttpMessageHandler::GetHeadParam(const std::string &key) const { MS_EXCEPTION_IF_NULL(head_params_); const char *val = evhttp_find_header(head_params_, key.c_str()); MS_EXCEPTION_IF_NULL(val); return std::string(val); } -std::string HttpMessageHandler::GetPathParam(const std::string &key) { +std::string HttpMessageHandler::GetPathParam(const std::string &key) const { const char *val = evhttp_find_header(&path_params_, key.c_str()); MS_EXCEPTION_IF_NULL(val); return std::string(val); @@ -134,7 +137,7 @@ std::string HttpMessageHandler::GetPostParam(const std::string &key) { return std::string(val); } -std::string HttpMessageHandler::GetRequestUri() { +std::string HttpMessageHandler::GetRequestUri() const { MS_EXCEPTION_IF_NULL(event_request_); const char *uri = evhttp_request_get_uri(event_request_); MS_EXCEPTION_IF_NULL(uri); @@ -148,14 +151,14 @@ std::string HttpMessageHandler::GetRequestHost() { return std::string(host); } -const char *HttpMessageHandler::GetHostByUri() { +const char *HttpMessageHandler::GetHostByUri() const { MS_EXCEPTION_IF_NULL(event_uri_); const char *host = evhttp_uri_get_host(event_uri_); MS_EXCEPTION_IF_NULL(host); return host; } -int HttpMessageHandler::GetUriPort() { +int HttpMessageHandler::GetUriPort() const { MS_EXCEPTION_IF_NULL(event_uri_); int port = evhttp_uri_get_port(event_uri_); if (port < 0) { @@ -164,7 +167,7 @@ int HttpMessageHandler::GetUriPort() { return port; } -std::string HttpMessageHandler::GetUriPath() { +std::string HttpMessageHandler::GetUriPath() const { MS_EXCEPTION_IF_NULL(event_uri_); const char *path = evhttp_uri_get_path(event_uri_); MS_EXCEPTION_IF_NULL(path); @@ -186,7 +189,7 @@ std::string HttpMessageHandler::GetRequestPath() { return path_res; } -std::string HttpMessageHandler::GetUriQuery() { +std::string HttpMessageHandler::GetUriQuery() const { MS_EXCEPTION_IF_NULL(event_uri_); const char *query = evhttp_uri_get_query(event_uri_); MS_EXCEPTION_IF_NULL(query); @@ -259,8 +262,7 @@ void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, co MS_EXCEPTION_IF_NULL(resp_buf_); AddRespHeaders(headers); AddRespString(body); - MS_EXCEPTION_IF_NULL(resp_buf_); - evhttp_send_reply(event_request_, resp_code_, nullptr, resp_buf_); + evhttp_send_reply(event_request_, code, nullptr, resp_buf_); } void HttpMessageHandler::ErrorResponse(int code, RequestProcessResult result) { @@ -293,9 +295,9 @@ void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; } -uint64_t HttpMessageHandler::content_len() { return content_len_; } +uint64_t HttpMessageHandler::content_len() const { return content_len_; } -const event_base *HttpMessageHandler::http_base() { return event_base_; } +const event_base *HttpMessageHandler::http_base() const { return event_base_; } void HttpMessageHandler::set_http_base(const struct event_base *base) { MS_EXCEPTION_IF_NULL(base); @@ -307,13 +309,13 @@ void HttpMessageHandler::set_request(const struct evhttp_request *req) { event_request_ = const_cast(req); } -const struct evhttp_request *HttpMessageHandler::request() { return event_request_; } +const struct evhttp_request *HttpMessageHandler::request() const { return event_request_; } void HttpMessageHandler::InitBodySize() { body_->resize(content_len()); } std::shared_ptr> HttpMessageHandler::body() { return body_; } -void HttpMessageHandler::set_body(std::shared_ptr> body) { body_ = body; } +void HttpMessageHandler::set_body(const VectorPtr &body) { body_ = body; } const nlohmann::json &HttpMessageHandler::request_message() const { return request_message_; } diff --git a/mindspore/ccsrc/ps/core/communicator/http_message_handler.h b/mindspore/ccsrc/ps/core/communicator/http_message_handler.h index 648d5b66d25..1d4f260ce9f 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_message_handler.h +++ b/mindspore/ccsrc/ps/core/communicator/http_message_handler.h @@ -69,19 +69,19 @@ class HttpMessageHandler { void InitHttpMessage(); void ParseUrl(const std::string &url); - std::string GetRequestUri(); + std::string GetRequestUri() const; std::string GetRequestHost(); - const char *GetHostByUri(); - std::string GetHeadParam(const std::string &key); - std::string GetPathParam(const std::string &key); + const char *GetHostByUri() const; + std::string GetHeadParam(const std::string &key) const; + std::string GetPathParam(const std::string &key) const; std::string GetPostParam(const std::string &key); uint64_t GetPostMsg(unsigned char **buffer); - std::string GetUriPath(); + std::string GetUriPath() const; std::string GetRequestPath(); - std::string GetUriQuery(); + std::string GetUriQuery() const; // It will return -1 if no port set - int GetUriPort(); + int GetUriPort() const; // Useless to get from a request url, fragment is only for browser to locate sth. std::string GetUriFragment(); @@ -104,14 +104,14 @@ class HttpMessageHandler { RequestProcessResult ParsePostMessageToJson(); void ReceiveMessage(const void *buffer, size_t num); void set_content_len(const uint64_t &len); - uint64_t content_len(); - const event_base *http_base(); + uint64_t content_len() const; + const event_base *http_base() const; void set_http_base(const struct event_base *base); void set_request(const struct evhttp_request *req); - const struct evhttp_request *request(); + const struct evhttp_request *request() const; void InitBodySize(); VectorPtr body(); - void set_body(VectorPtr body); + void set_body(const VectorPtr &body); const nlohmann::json &request_message() const; RequestProcessResult ParseValueFromKey(const std::string &key, int32_t *const value); diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_client.cc b/mindspore/ccsrc/ps/core/communicator/tcp_client.cc index d47e2dd7dfa..2c89ca2bea6 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_client.cc @@ -71,8 +71,8 @@ void TcpClient::set_connected_callback(const OnConnected &connected) { connected bool TcpClient::WaitConnected(const uint32_t &connected_timeout) { std::unique_lock lock(connection_mutex_); - bool res = - connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), [&] { return is_connected_.load(); }); + bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), + [this] { return this->is_connected_.load(); }); return res; } diff --git a/mindspore/ccsrc/ps/core/node.cc b/mindspore/ccsrc/ps/core/node.cc index d6f7ed457c2..f4c406c8b00 100644 --- a/mindspore/ccsrc/ps/core/node.cc +++ b/mindspore/ccsrc/ps/core/node.cc @@ -31,8 +31,8 @@ std::string Node::BoundIp() const { return node_info_.ip_; } bool Node::WaitForStart(const uint32_t &timeout) { std::unique_lock lock(wait_start_mutex_); - bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { - bool res = is_ready_.load(); + bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] { + bool res = this->is_ready_.load(); if (res) { MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!"; } diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc index 1e8759d72d4..5eff34a9836 100644 --- a/mindspore/ccsrc/ps/core/node_manager.cc +++ b/mindspore/ccsrc/ps/core/node_manager.cc @@ -50,7 +50,6 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message) { } return res; }); - if (rank_it == registered_nodes_info_.end()) { rank_id = ++next_server_rank_id_; } else { @@ -76,19 +75,19 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message) { const std::string &ip = register_message.ip(); uint32_t port = register_message.port(); - auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) { - bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER; - if (res) { - MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive."; - rank_id = item.second.rank_id_; - } - return res; - }); - - if (rank_it == registered_nodes_info_.end()) { + auto worker_rank_it = + std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) { + bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER; + if (res) { + MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive."; + rank_id = item.second.rank_id_; + } + return res; + }); + if (worker_rank_it == registered_nodes_info_.end()) { rank_id = ++next_worker_rank_id_; } else { - registered_nodes_info_.erase((*rank_it).first); + registered_nodes_info_.erase((*worker_rank_it).first); } if (rank_id >= meta_data_->worker_num) { diff --git a/mindspore/ccsrc/ps/core/node_recovery.cc b/mindspore/ccsrc/ps/core/node_recovery.cc index 51f35e63491..6ff0e62600d 100644 --- a/mindspore/ccsrc/ps/core/node_recovery.cc +++ b/mindspore/ccsrc/ps/core/node_recovery.cc @@ -19,7 +19,6 @@ namespace mindspore { namespace ps { namespace core { - bool NodeRecovery::Recover() { if (recovery_storage_ == nullptr) { return false; diff --git a/mindspore/ccsrc/ps/core/recovery_base.cc b/mindspore/ccsrc/ps/core/recovery_base.cc index b0e15bcdd70..8eec45352fd 100644 --- a/mindspore/ccsrc/ps/core/recovery_base.cc +++ b/mindspore/ccsrc/ps/core/recovery_base.cc @@ -19,7 +19,6 @@ namespace mindspore { namespace ps { namespace core { - void RecoveryBase::Initialize(const std::string &config_json) { nlohmann::json recovery_config; try { @@ -35,7 +34,6 @@ void RecoveryBase::Initialize(const std::string &config_json) { storage_type_ = StorageType::kFileStorage; storage_file_path = recovery_config.at(kStoreFilePath); - if (storage_file_path == "") { MS_LOG(EXCEPTION) << "If the scheduler support recovery, and if the persistent storage is a file, the path of " "the file must be configured"; diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index b934b07a101..337c197d2b8 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -21,7 +21,9 @@ namespace ps { namespace core { SchedulerNode::~SchedulerNode() { MS_LOG(INFO) << "Stop scheduler node!"; - Stop(); + if (!Stop()) { + MS_LOG(WARNING) << "Scheduler node stop failed."; + } } bool SchedulerNode::Start(const uint32_t &timeout) { @@ -66,8 +68,10 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr server, std::sha heartbeat_resp_message.set_is_worker_or_server0(node_manager_.IsWorkerOrServer0()); - server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(), - heartbeat_resp_message.ByteSizeLong()); + if (!server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(), + heartbeat_resp_message.ByteSizeLong())) { + MS_LOG(WARNING) << "Send heart beat failed."; + } } void SchedulerNode::Initialize() { @@ -98,7 +102,7 @@ void SchedulerNode::CreateTcpServer() { uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port; server_ = std::make_shared(scheduler_host, scheduler_port); server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr meta, - const Protos &protos, const void *data, size_t size) { + const Protos &, const void *data, size_t size) { if (handlers_.count(meta->cmd()) == 0) { MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; } @@ -122,7 +126,9 @@ void SchedulerNode::ProcessRegister(std::shared_ptr server, std::shar MS_EXCEPTION_IF_NULL(data); MS_LOG(INFO) << "The scheduler process a register message!"; RegisterMessage register_message; - register_message.ParseFromArray(data, size); + if (!register_message.ParseFromArray(data, SizeToInt(size))) { + MS_LOG(WARNING) << "Parse data failed."; + } // assign worker node and server node rank id uint32_t rank_id = node_manager_.NextRankId(register_message); @@ -179,7 +185,6 @@ void SchedulerNode::ProcessFinish(std::shared_ptr server, std::shared } return false; }); - if (iter != scale_in_node_ids_.end()) { return; } @@ -450,11 +455,11 @@ bool SchedulerNode::Stop() { bool SchedulerNode::Finish(const uint32_t &) { MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!"; std::unique_lock lock(wait_finish_mutex_); - wait_finish_cond_.wait(lock, [&] { - if (is_finish_.load()) { + wait_finish_cond_.wait(lock, [this] { + if (this->is_finish_.load()) { MS_LOG(INFO) << "[Scheduler finish]: 2. Successfully finish scheduler!"; } - return is_finish_.load(); + return this->is_finish_.load(); }); return true; } diff --git a/mindspore/ccsrc/ps/parameter_server.cc b/mindspore/ccsrc/ps/parameter_server.cc index f7671271901..b89e9cce10b 100644 --- a/mindspore/ccsrc/ps/parameter_server.cc +++ b/mindspore/ccsrc/ps/parameter_server.cc @@ -35,7 +35,9 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { SyncEmbeddingTables(); MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; server_node_->Finish(); - server_node_->Stop(); + if (!server_node_->Stop()) { + MS_LOG(WARNING) << "Parameter server stop failed."; + } MS_LOG(INFO) << "PServer finalized successfully."; } @@ -561,7 +563,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size std::unique_lock lock(ps_->mutex()); MS_EXCEPTION_IF_NULL(res); KVMessage input; - input.ParseFromArray(data.get(), size); + if (!input.ParseFromArray(data.get(), SizeToInt(size))) { + MS_LOG(WARNING) << "Parse data failed."; + } int key_num = input.keys_size(); const float *data_ptr = input.values().data(); size_t pos = 0; @@ -586,9 +590,11 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz std::unique_lock lock(ps_->mutex()); MS_EXCEPTION_IF_NULL(res); KVMessage input; - input.ParseFromArray(data.get(), size); - size_t key_num = input.keys_size(); - for (size_t i = 0; i < key_num; i++) { + if (!input.ParseFromArray(data.get(), SizeToInt(size))) { + MS_LOG(WARNING) << "Parse data failed."; + } + int key_num = input.keys_size(); + for (int i = 0; i < key_num; i++) { Key key = input.keys()[i]; float val = input.values()[i]; if (init_weight_to_optim_[key]) { @@ -596,7 +602,7 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz } else { init_weight_to_optim_[key] = true; } - ps_->InitWeightKeyToOptims(key, val); + ps_->InitWeightKeyToOptims(key, static_cast(val)); } } @@ -722,7 +728,7 @@ void ParameterServer::ServerHandler::HandleUpdateEmbeddings(DataPtr data, size_t ps_->UpdateEmbeddings(key, lookup_ids, update_vals); } -void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, VectorPtr res) { +void ParameterServer::ServerHandler::HandleFinalize(DataPtr, size_t, VectorPtr res) { MS_EXCEPTION_IF_NULL(res); ps_->Finalize(); } diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index e89b758df8c..8f45835edfc 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -242,17 +242,19 @@ void PsCacheManager::AllocMemForHashTable() { void PsCacheManager::SetLocalIdRank() { auto worker_num = PSContext::instance()->initial_worker_num(); - auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num)); - vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_); - emb_table_slice_bounds_.first = local_shard_size * rank_id_; - emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_)); - cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_; - cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_); - MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_ - << ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second - << ", cache indices begin: " << cache_indices_bounds_.first - << ", cache indices end: " << cache_indices_bounds_.second - << ", vocab_cache_size_diff: " << vocab_cache_size_diff_; + if (worker_num > 0) { + auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num)); + vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_); + emb_table_slice_bounds_.first = local_shard_size * rank_id_; + emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_)); + cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_; + cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_); + MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_ + << ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second + << ", cache indices begin: " << cache_indices_bounds_.first + << ", cache indices end: " << cache_indices_bounds_.second + << ", vocab_cache_size_diff: " << vocab_cache_size_diff_; + } } int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; } diff --git a/mindspore/ccsrc/ps/scheduler.cc b/mindspore/ccsrc/ps/scheduler.cc index dce9a3ce9b5..537c10b37e8 100755 --- a/mindspore/ccsrc/ps/scheduler.cc +++ b/mindspore/ccsrc/ps/scheduler.cc @@ -24,9 +24,17 @@ void Scheduler::Run() { PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port(); PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num(); PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num(); - scheduler_node_.Start(); - scheduler_node_.Finish(); - scheduler_node_.Stop(); + if (!scheduler_node_.Start()) { + MS_LOG(WARNING) << "Scheduler start failed."; + } + + if (!scheduler_node_.Finish()) { + MS_LOG(WARNING) << "Scheduler finis failed."; + } + + if (!scheduler_node_.Stop()) { + MS_LOG(WARNING) << "Scheduler stop failed."; + } exit(1); } } // namespace ps