Merge pull request !19332 from anancds/pclint
This commit is contained in:
i-robot 2021-07-05 06:34:49 +00:00 committed by Gitee
commit 5a5b709cc5
11 changed files with 93 additions and 74 deletions

View File

@ -45,7 +45,10 @@ void HttpMessageHandler::InitHttpMessage() {
const char *query = evhttp_uri_get_query(event_uri_);
if (query != nullptr) {
MS_LOG(WARNING) << "The query is:" << query;
evhttp_parse_query_str(query, &path_params_);
int result = evhttp_parse_query_str(query, &path_params_);
if (result < 0) {
MS_LOG(ERROR) << "Http parse query:" << query << " failed.";
}
}
head_params_ = evhttp_request_get_input_headers(event_request_);
@ -58,14 +61,14 @@ void HttpMessageHandler::ParseUrl(const std::string &url) {
MS_EXCEPTION_IF_NULL(event_uri_);
}
std::string HttpMessageHandler::GetHeadParam(const std::string &key) {
std::string HttpMessageHandler::GetHeadParam(const std::string &key) const {
MS_EXCEPTION_IF_NULL(head_params_);
const char *val = evhttp_find_header(head_params_, key.c_str());
MS_EXCEPTION_IF_NULL(val);
return std::string(val);
}
std::string HttpMessageHandler::GetPathParam(const std::string &key) {
std::string HttpMessageHandler::GetPathParam(const std::string &key) const {
const char *val = evhttp_find_header(&path_params_, key.c_str());
MS_EXCEPTION_IF_NULL(val);
return std::string(val);
@ -134,7 +137,7 @@ std::string HttpMessageHandler::GetPostParam(const std::string &key) {
return std::string(val);
}
std::string HttpMessageHandler::GetRequestUri() {
std::string HttpMessageHandler::GetRequestUri() const {
MS_EXCEPTION_IF_NULL(event_request_);
const char *uri = evhttp_request_get_uri(event_request_);
MS_EXCEPTION_IF_NULL(uri);
@ -148,14 +151,14 @@ std::string HttpMessageHandler::GetRequestHost() {
return std::string(host);
}
const char *HttpMessageHandler::GetHostByUri() {
const char *HttpMessageHandler::GetHostByUri() const {
MS_EXCEPTION_IF_NULL(event_uri_);
const char *host = evhttp_uri_get_host(event_uri_);
MS_EXCEPTION_IF_NULL(host);
return host;
}
int HttpMessageHandler::GetUriPort() {
int HttpMessageHandler::GetUriPort() const {
MS_EXCEPTION_IF_NULL(event_uri_);
int port = evhttp_uri_get_port(event_uri_);
if (port < 0) {
@ -164,7 +167,7 @@ int HttpMessageHandler::GetUriPort() {
return port;
}
std::string HttpMessageHandler::GetUriPath() {
std::string HttpMessageHandler::GetUriPath() const {
MS_EXCEPTION_IF_NULL(event_uri_);
const char *path = evhttp_uri_get_path(event_uri_);
MS_EXCEPTION_IF_NULL(path);
@ -186,7 +189,7 @@ std::string HttpMessageHandler::GetRequestPath() {
return path_res;
}
std::string HttpMessageHandler::GetUriQuery() {
std::string HttpMessageHandler::GetUriQuery() const {
MS_EXCEPTION_IF_NULL(event_uri_);
const char *query = evhttp_uri_get_query(event_uri_);
MS_EXCEPTION_IF_NULL(query);
@ -259,8 +262,7 @@ void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, co
MS_EXCEPTION_IF_NULL(resp_buf_);
AddRespHeaders(headers);
AddRespString(body);
MS_EXCEPTION_IF_NULL(resp_buf_);
evhttp_send_reply(event_request_, resp_code_, nullptr, resp_buf_);
evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
}
void HttpMessageHandler::ErrorResponse(int code, RequestProcessResult result) {
@ -293,9 +295,9 @@ void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; }
uint64_t HttpMessageHandler::content_len() { return content_len_; }
uint64_t HttpMessageHandler::content_len() const { return content_len_; }
const event_base *HttpMessageHandler::http_base() { return event_base_; }
const event_base *HttpMessageHandler::http_base() const { return event_base_; }
void HttpMessageHandler::set_http_base(const struct event_base *base) {
MS_EXCEPTION_IF_NULL(base);
@ -307,13 +309,13 @@ void HttpMessageHandler::set_request(const struct evhttp_request *req) {
event_request_ = const_cast<evhttp_request *>(req);
}
const struct evhttp_request *HttpMessageHandler::request() { return event_request_; }
const struct evhttp_request *HttpMessageHandler::request() const { return event_request_; }
void HttpMessageHandler::InitBodySize() { body_->resize(content_len()); }
std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; }
void HttpMessageHandler::set_body(std::shared_ptr<std::vector<char>> body) { body_ = body; }
void HttpMessageHandler::set_body(const VectorPtr &body) { body_ = body; }
// Exposes request_message_ by const reference.
const nlohmann::json &HttpMessageHandler::request_message() const {
  return this->request_message_;
}

View File

@ -69,19 +69,19 @@ class HttpMessageHandler {
void InitHttpMessage();
void ParseUrl(const std::string &url);
std::string GetRequestUri();
std::string GetRequestUri() const;
std::string GetRequestHost();
const char *GetHostByUri();
std::string GetHeadParam(const std::string &key);
std::string GetPathParam(const std::string &key);
const char *GetHostByUri() const;
std::string GetHeadParam(const std::string &key) const;
std::string GetPathParam(const std::string &key) const;
std::string GetPostParam(const std::string &key);
uint64_t GetPostMsg(unsigned char **buffer);
std::string GetUriPath();
std::string GetUriPath() const;
std::string GetRequestPath();
std::string GetUriQuery();
std::string GetUriQuery() const;
// Returns -1 if no port is set.
int GetUriPort();
int GetUriPort() const;
// Not useful to read from a request URL; the fragment is only used by the browser to locate something.
std::string GetUriFragment();
@ -104,14 +104,14 @@ class HttpMessageHandler {
RequestProcessResult ParsePostMessageToJson();
void ReceiveMessage(const void *buffer, size_t num);
void set_content_len(const uint64_t &len);
uint64_t content_len();
const event_base *http_base();
uint64_t content_len() const;
const event_base *http_base() const;
void set_http_base(const struct event_base *base);
void set_request(const struct evhttp_request *req);
const struct evhttp_request *request();
const struct evhttp_request *request() const;
void InitBodySize();
VectorPtr body();
void set_body(VectorPtr body);
void set_body(const VectorPtr &body);
const nlohmann::json &request_message() const;
RequestProcessResult ParseValueFromKey(const std::string &key, int32_t *const value);

View File

@ -71,8 +71,8 @@ void TcpClient::set_connected_callback(const OnConnected &connected) { connected
bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
std::unique_lock<std::mutex> lock(connection_mutex_);
bool res =
connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), [&] { return is_connected_.load(); });
bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout),
[this] { return this->is_connected_.load(); });
return res;
}

View File

@ -31,8 +31,8 @@ std::string Node::BoundIp() const { return node_info_.ip_; }
bool Node::WaitForStart(const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(wait_start_mutex_);
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
bool res = is_ready_.load();
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] {
bool res = this->is_ready_.load();
if (res) {
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
}

View File

@ -50,7 +50,6 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
}
return res;
});
if (rank_it == registered_nodes_info_.end()) {
rank_id = ++next_server_rank_id_;
} else {
@ -76,19 +75,19 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
const std::string &ip = register_message.ip();
uint32_t port = register_message.port();
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
if (res) {
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
rank_id = item.second.rank_id_;
}
return res;
});
if (rank_it == registered_nodes_info_.end()) {
auto worker_rank_it =
std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
if (res) {
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
rank_id = item.second.rank_id_;
}
return res;
});
if (worker_rank_it == registered_nodes_info_.end()) {
rank_id = ++next_worker_rank_id_;
} else {
registered_nodes_info_.erase((*rank_it).first);
registered_nodes_info_.erase((*worker_rank_it).first);
}
if (rank_id >= meta_data_->worker_num) {

View File

@ -19,7 +19,6 @@
namespace mindspore {
namespace ps {
namespace core {
bool NodeRecovery::Recover() {
if (recovery_storage_ == nullptr) {
return false;

View File

@ -19,7 +19,6 @@
namespace mindspore {
namespace ps {
namespace core {
void RecoveryBase::Initialize(const std::string &config_json) {
nlohmann::json recovery_config;
try {
@ -35,7 +34,6 @@ void RecoveryBase::Initialize(const std::string &config_json) {
storage_type_ = StorageType::kFileStorage;
storage_file_path = recovery_config.at(kStoreFilePath);
if (storage_file_path == "") {
MS_LOG(EXCEPTION) << "If the scheduler support recovery, and if the persistent storage is a file, the path of "
"the file must be configured";

View File

@ -21,7 +21,9 @@ namespace ps {
namespace core {
SchedulerNode::~SchedulerNode() {
MS_LOG(INFO) << "Stop scheduler node!";
Stop();
if (!Stop()) {
MS_LOG(WARNING) << "Scheduler node stop failed.";
}
}
bool SchedulerNode::Start(const uint32_t &timeout) {
@ -66,8 +68,10 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
heartbeat_resp_message.set_is_worker_or_server0(node_manager_.IsWorkerOrServer0());
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
heartbeat_resp_message.ByteSizeLong());
if (!server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
heartbeat_resp_message.ByteSizeLong())) {
MS_LOG(WARNING) << "Send heart beat failed.";
}
}
void SchedulerNode::Initialize() {
@ -98,7 +102,7 @@ void SchedulerNode::CreateTcpServer() {
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
const Protos &, const void *data, size_t size) {
if (handlers_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
}
@ -122,7 +126,9 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
MS_EXCEPTION_IF_NULL(data);
MS_LOG(INFO) << "The scheduler process a register message!";
RegisterMessage register_message;
register_message.ParseFromArray(data, size);
if (!register_message.ParseFromArray(data, SizeToInt(size))) {
MS_LOG(WARNING) << "Parse data failed.";
}
// assign worker node and server node rank id
uint32_t rank_id = node_manager_.NextRankId(register_message);
@ -179,7 +185,6 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
}
return false;
});
if (iter != scale_in_node_ids_.end()) {
return;
}
@ -450,11 +455,11 @@ bool SchedulerNode::Stop() {
bool SchedulerNode::Finish(const uint32_t &) {
MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!";
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
wait_finish_cond_.wait(lock, [&] {
if (is_finish_.load()) {
wait_finish_cond_.wait(lock, [this] {
if (this->is_finish_.load()) {
MS_LOG(INFO) << "[Scheduler finish]: 2. Successfully finish scheduler!";
}
return is_finish_.load();
return this->is_finish_.load();
});
return true;
}

View File

@ -35,7 +35,9 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
SyncEmbeddingTables();
MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
server_node_->Finish();
server_node_->Stop();
if (!server_node_->Stop()) {
MS_LOG(WARNING) << "Parameter server stop failed.";
}
MS_LOG(INFO) << "PServer finalized successfully.";
}
@ -561,7 +563,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size
std::unique_lock<std::mutex> lock(ps_->mutex());
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
input.ParseFromArray(data.get(), size);
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
MS_LOG(WARNING) << "Parse data failed.";
}
int key_num = input.keys_size();
const float *data_ptr = input.values().data();
size_t pos = 0;
@ -586,9 +590,11 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz
std::unique_lock<std::mutex> lock(ps_->mutex());
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
input.ParseFromArray(data.get(), size);
size_t key_num = input.keys_size();
for (size_t i = 0; i < key_num; i++) {
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
MS_LOG(WARNING) << "Parse data failed.";
}
int key_num = input.keys_size();
for (int i = 0; i < key_num; i++) {
Key key = input.keys()[i];
float val = input.values()[i];
if (init_weight_to_optim_[key]) {
@ -596,7 +602,7 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz
} else {
init_weight_to_optim_[key] = true;
}
ps_->InitWeightKeyToOptims(key, val);
ps_->InitWeightKeyToOptims(key, static_cast<int64_t>(val));
}
}
@ -722,7 +728,7 @@ void ParameterServer::ServerHandler::HandleUpdateEmbeddings(DataPtr data, size_t
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
}
void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleFinalize(DataPtr, size_t, VectorPtr res) {
MS_EXCEPTION_IF_NULL(res);
ps_->Finalize();
}

View File

@ -242,17 +242,19 @@ void PsCacheManager::AllocMemForHashTable() {
void PsCacheManager::SetLocalIdRank() {
auto worker_num = PSContext::instance()->initial_worker_num();
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
<< ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
<< ", cache indices begin: " << cache_indices_bounds_.first
<< ", cache indices end: " << cache_indices_bounds_.second
<< ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
if (worker_num > 0) {
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
<< ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
<< ", cache indices begin: " << cache_indices_bounds_.first
<< ", cache indices end: " << cache_indices_bounds_.second
<< ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
}
}
int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; }

View File

@ -24,9 +24,17 @@ void Scheduler::Run() {
PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port();
PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num();
PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num();
scheduler_node_.Start();
scheduler_node_.Finish();
scheduler_node_.Stop();
if (!scheduler_node_.Start()) {
MS_LOG(WARNING) << "Scheduler start failed.";
}
// Log a warning when Finish() returns false instead of ignoring the result.
if (!scheduler_node_.Finish()) {
  MS_LOG(WARNING) << "Scheduler finish failed.";
}
if (!scheduler_node_.Stop()) {
MS_LOG(WARNING) << "Scheduler stop failed.";
}
exit(1);
}
} // namespace ps