forked from mindspore-Ecosystem/mindspore
commit
5a5b709cc5
|
@ -45,7 +45,10 @@ void HttpMessageHandler::InitHttpMessage() {
|
|||
const char *query = evhttp_uri_get_query(event_uri_);
|
||||
if (query != nullptr) {
|
||||
MS_LOG(WARNING) << "The query is:" << query;
|
||||
evhttp_parse_query_str(query, &path_params_);
|
||||
int result = evhttp_parse_query_str(query, &path_params_);
|
||||
if (result < 0) {
|
||||
MS_LOG(ERROR) << "Http parse query:" << query << " failed.";
|
||||
}
|
||||
}
|
||||
|
||||
head_params_ = evhttp_request_get_input_headers(event_request_);
|
||||
|
@ -58,14 +61,14 @@ void HttpMessageHandler::ParseUrl(const std::string &url) {
|
|||
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||
}
|
||||
|
||||
std::string HttpMessageHandler::GetHeadParam(const std::string &key) {
|
||||
std::string HttpMessageHandler::GetHeadParam(const std::string &key) const {
|
||||
MS_EXCEPTION_IF_NULL(head_params_);
|
||||
const char *val = evhttp_find_header(head_params_, key.c_str());
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
return std::string(val);
|
||||
}
|
||||
|
||||
std::string HttpMessageHandler::GetPathParam(const std::string &key) {
|
||||
std::string HttpMessageHandler::GetPathParam(const std::string &key) const {
|
||||
const char *val = evhttp_find_header(&path_params_, key.c_str());
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
return std::string(val);
|
||||
|
@ -134,7 +137,7 @@ std::string HttpMessageHandler::GetPostParam(const std::string &key) {
|
|||
return std::string(val);
|
||||
}
|
||||
|
||||
std::string HttpMessageHandler::GetRequestUri() {
|
||||
std::string HttpMessageHandler::GetRequestUri() const {
|
||||
MS_EXCEPTION_IF_NULL(event_request_);
|
||||
const char *uri = evhttp_request_get_uri(event_request_);
|
||||
MS_EXCEPTION_IF_NULL(uri);
|
||||
|
@ -148,14 +151,14 @@ std::string HttpMessageHandler::GetRequestHost() {
|
|||
return std::string(host);
|
||||
}
|
||||
|
||||
const char *HttpMessageHandler::GetHostByUri() {
|
||||
const char *HttpMessageHandler::GetHostByUri() const {
|
||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||
const char *host = evhttp_uri_get_host(event_uri_);
|
||||
MS_EXCEPTION_IF_NULL(host);
|
||||
return host;
|
||||
}
|
||||
|
||||
int HttpMessageHandler::GetUriPort() {
|
||||
int HttpMessageHandler::GetUriPort() const {
|
||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||
int port = evhttp_uri_get_port(event_uri_);
|
||||
if (port < 0) {
|
||||
|
@ -164,7 +167,7 @@ int HttpMessageHandler::GetUriPort() {
|
|||
return port;
|
||||
}
|
||||
|
||||
std::string HttpMessageHandler::GetUriPath() {
|
||||
std::string HttpMessageHandler::GetUriPath() const {
|
||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||
const char *path = evhttp_uri_get_path(event_uri_);
|
||||
MS_EXCEPTION_IF_NULL(path);
|
||||
|
@ -186,7 +189,7 @@ std::string HttpMessageHandler::GetRequestPath() {
|
|||
return path_res;
|
||||
}
|
||||
|
||||
std::string HttpMessageHandler::GetUriQuery() {
|
||||
std::string HttpMessageHandler::GetUriQuery() const {
|
||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||
const char *query = evhttp_uri_get_query(event_uri_);
|
||||
MS_EXCEPTION_IF_NULL(query);
|
||||
|
@ -259,8 +262,7 @@ void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, co
|
|||
MS_EXCEPTION_IF_NULL(resp_buf_);
|
||||
AddRespHeaders(headers);
|
||||
AddRespString(body);
|
||||
MS_EXCEPTION_IF_NULL(resp_buf_);
|
||||
evhttp_send_reply(event_request_, resp_code_, nullptr, resp_buf_);
|
||||
evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
|
||||
}
|
||||
|
||||
void HttpMessageHandler::ErrorResponse(int code, RequestProcessResult result) {
|
||||
|
@ -293,9 +295,9 @@ void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
|||
|
||||
void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; }
|
||||
|
||||
uint64_t HttpMessageHandler::content_len() { return content_len_; }
|
||||
uint64_t HttpMessageHandler::content_len() const { return content_len_; }
|
||||
|
||||
const event_base *HttpMessageHandler::http_base() { return event_base_; }
|
||||
const event_base *HttpMessageHandler::http_base() const { return event_base_; }
|
||||
|
||||
void HttpMessageHandler::set_http_base(const struct event_base *base) {
|
||||
MS_EXCEPTION_IF_NULL(base);
|
||||
|
@ -307,13 +309,13 @@ void HttpMessageHandler::set_request(const struct evhttp_request *req) {
|
|||
event_request_ = const_cast<evhttp_request *>(req);
|
||||
}
|
||||
|
||||
const struct evhttp_request *HttpMessageHandler::request() { return event_request_; }
|
||||
const struct evhttp_request *HttpMessageHandler::request() const { return event_request_; }
|
||||
|
||||
void HttpMessageHandler::InitBodySize() { body_->resize(content_len()); }
|
||||
|
||||
std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; }
|
||||
|
||||
void HttpMessageHandler::set_body(std::shared_ptr<std::vector<char>> body) { body_ = body; }
|
||||
void HttpMessageHandler::set_body(const VectorPtr &body) { body_ = body; }
|
||||
|
||||
const nlohmann::json &HttpMessageHandler::request_message() const { return request_message_; }
|
||||
|
||||
|
|
|
@ -69,19 +69,19 @@ class HttpMessageHandler {
|
|||
void InitHttpMessage();
|
||||
void ParseUrl(const std::string &url);
|
||||
|
||||
std::string GetRequestUri();
|
||||
std::string GetRequestUri() const;
|
||||
std::string GetRequestHost();
|
||||
const char *GetHostByUri();
|
||||
std::string GetHeadParam(const std::string &key);
|
||||
std::string GetPathParam(const std::string &key);
|
||||
const char *GetHostByUri() const;
|
||||
std::string GetHeadParam(const std::string &key) const;
|
||||
std::string GetPathParam(const std::string &key) const;
|
||||
std::string GetPostParam(const std::string &key);
|
||||
uint64_t GetPostMsg(unsigned char **buffer);
|
||||
std::string GetUriPath();
|
||||
std::string GetUriPath() const;
|
||||
std::string GetRequestPath();
|
||||
std::string GetUriQuery();
|
||||
std::string GetUriQuery() const;
|
||||
|
||||
// It will return -1 if no port set
|
||||
int GetUriPort();
|
||||
int GetUriPort() const;
|
||||
|
||||
// Useless to get from a request url, fragment is only for browser to locate sth.
|
||||
std::string GetUriFragment();
|
||||
|
@ -104,14 +104,14 @@ class HttpMessageHandler {
|
|||
RequestProcessResult ParsePostMessageToJson();
|
||||
void ReceiveMessage(const void *buffer, size_t num);
|
||||
void set_content_len(const uint64_t &len);
|
||||
uint64_t content_len();
|
||||
const event_base *http_base();
|
||||
uint64_t content_len() const;
|
||||
const event_base *http_base() const;
|
||||
void set_http_base(const struct event_base *base);
|
||||
void set_request(const struct evhttp_request *req);
|
||||
const struct evhttp_request *request();
|
||||
const struct evhttp_request *request() const;
|
||||
void InitBodySize();
|
||||
VectorPtr body();
|
||||
void set_body(VectorPtr body);
|
||||
void set_body(const VectorPtr &body);
|
||||
const nlohmann::json &request_message() const;
|
||||
RequestProcessResult ParseValueFromKey(const std::string &key, int32_t *const value);
|
||||
|
||||
|
|
|
@ -71,8 +71,8 @@ void TcpClient::set_connected_callback(const OnConnected &connected) { connected
|
|||
|
||||
bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
|
||||
std::unique_lock<std::mutex> lock(connection_mutex_);
|
||||
bool res =
|
||||
connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), [&] { return is_connected_.load(); });
|
||||
bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout),
|
||||
[this] { return this->is_connected_.load(); });
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
@ -31,8 +31,8 @@ std::string Node::BoundIp() const { return node_info_.ip_; }
|
|||
|
||||
bool Node::WaitForStart(const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(wait_start_mutex_);
|
||||
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
bool res = is_ready_.load();
|
||||
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] {
|
||||
bool res = this->is_ready_.load();
|
||||
if (res) {
|
||||
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
|
||||
}
|
||||
|
|
|
@ -50,7 +50,6 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
|
|||
}
|
||||
return res;
|
||||
});
|
||||
|
||||
if (rank_it == registered_nodes_info_.end()) {
|
||||
rank_id = ++next_server_rank_id_;
|
||||
} else {
|
||||
|
@ -76,19 +75,19 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
|
|||
const std::string &ip = register_message.ip();
|
||||
uint32_t port = register_message.port();
|
||||
|
||||
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
|
||||
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
|
||||
if (res) {
|
||||
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
|
||||
rank_id = item.second.rank_id_;
|
||||
}
|
||||
return res;
|
||||
});
|
||||
|
||||
if (rank_it == registered_nodes_info_.end()) {
|
||||
auto worker_rank_it =
|
||||
std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
|
||||
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
|
||||
if (res) {
|
||||
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
|
||||
rank_id = item.second.rank_id_;
|
||||
}
|
||||
return res;
|
||||
});
|
||||
if (worker_rank_it == registered_nodes_info_.end()) {
|
||||
rank_id = ++next_worker_rank_id_;
|
||||
} else {
|
||||
registered_nodes_info_.erase((*rank_it).first);
|
||||
registered_nodes_info_.erase((*worker_rank_it).first);
|
||||
}
|
||||
|
||||
if (rank_id >= meta_data_->worker_num) {
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
bool NodeRecovery::Recover() {
|
||||
if (recovery_storage_ == nullptr) {
|
||||
return false;
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
void RecoveryBase::Initialize(const std::string &config_json) {
|
||||
nlohmann::json recovery_config;
|
||||
try {
|
||||
|
@ -35,7 +34,6 @@ void RecoveryBase::Initialize(const std::string &config_json) {
|
|||
storage_type_ = StorageType::kFileStorage;
|
||||
|
||||
storage_file_path = recovery_config.at(kStoreFilePath);
|
||||
|
||||
if (storage_file_path == "") {
|
||||
MS_LOG(EXCEPTION) << "If the scheduler support recovery, and if the persistent storage is a file, the path of "
|
||||
"the file must be configured";
|
||||
|
|
|
@ -21,7 +21,9 @@ namespace ps {
|
|||
namespace core {
|
||||
SchedulerNode::~SchedulerNode() {
|
||||
MS_LOG(INFO) << "Stop scheduler node!";
|
||||
Stop();
|
||||
if (!Stop()) {
|
||||
MS_LOG(WARNING) << "Scheduler node stop failed.";
|
||||
}
|
||||
}
|
||||
|
||||
bool SchedulerNode::Start(const uint32_t &timeout) {
|
||||
|
@ -66,8 +68,10 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
|
|||
|
||||
heartbeat_resp_message.set_is_worker_or_server0(node_manager_.IsWorkerOrServer0());
|
||||
|
||||
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
|
||||
heartbeat_resp_message.ByteSizeLong());
|
||||
if (!server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
|
||||
heartbeat_resp_message.ByteSizeLong())) {
|
||||
MS_LOG(WARNING) << "Send heart beat failed.";
|
||||
}
|
||||
}
|
||||
|
||||
void SchedulerNode::Initialize() {
|
||||
|
@ -98,7 +102,7 @@ void SchedulerNode::CreateTcpServer() {
|
|||
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
|
||||
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
const Protos &, const void *data, size_t size) {
|
||||
if (handlers_.count(meta->cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
}
|
||||
|
@ -122,7 +126,9 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
|
|||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_LOG(INFO) << "The scheduler process a register message!";
|
||||
RegisterMessage register_message;
|
||||
register_message.ParseFromArray(data, size);
|
||||
if (!register_message.ParseFromArray(data, SizeToInt(size))) {
|
||||
MS_LOG(WARNING) << "Parse data failed.";
|
||||
}
|
||||
|
||||
// assign worker node and server node rank id
|
||||
uint32_t rank_id = node_manager_.NextRankId(register_message);
|
||||
|
@ -179,7 +185,6 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
|
|||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
if (iter != scale_in_node_ids_.end()) {
|
||||
return;
|
||||
}
|
||||
|
@ -450,11 +455,11 @@ bool SchedulerNode::Stop() {
|
|||
bool SchedulerNode::Finish(const uint32_t &) {
|
||||
MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!";
|
||||
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
|
||||
wait_finish_cond_.wait(lock, [&] {
|
||||
if (is_finish_.load()) {
|
||||
wait_finish_cond_.wait(lock, [this] {
|
||||
if (this->is_finish_.load()) {
|
||||
MS_LOG(INFO) << "[Scheduler finish]: 2. Successfully finish scheduler!";
|
||||
}
|
||||
return is_finish_.load();
|
||||
return this->is_finish_.load();
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -35,7 +35,9 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
|
|||
SyncEmbeddingTables();
|
||||
MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
|
||||
server_node_->Finish();
|
||||
server_node_->Stop();
|
||||
if (!server_node_->Stop()) {
|
||||
MS_LOG(WARNING) << "Parameter server stop failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "PServer finalized successfully.";
|
||||
}
|
||||
|
||||
|
@ -561,7 +563,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size
|
|||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
|
||||
MS_LOG(WARNING) << "Parse data failed.";
|
||||
}
|
||||
int key_num = input.keys_size();
|
||||
const float *data_ptr = input.values().data();
|
||||
size_t pos = 0;
|
||||
|
@ -586,9 +590,11 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz
|
|||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
size_t key_num = input.keys_size();
|
||||
for (size_t i = 0; i < key_num; i++) {
|
||||
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
|
||||
MS_LOG(WARNING) << "Parse data failed.";
|
||||
}
|
||||
int key_num = input.keys_size();
|
||||
for (int i = 0; i < key_num; i++) {
|
||||
Key key = input.keys()[i];
|
||||
float val = input.values()[i];
|
||||
if (init_weight_to_optim_[key]) {
|
||||
|
@ -596,7 +602,7 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz
|
|||
} else {
|
||||
init_weight_to_optim_[key] = true;
|
||||
}
|
||||
ps_->InitWeightKeyToOptims(key, val);
|
||||
ps_->InitWeightKeyToOptims(key, static_cast<int64_t>(val));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -722,7 +728,7 @@ void ParameterServer::ServerHandler::HandleUpdateEmbeddings(DataPtr data, size_t
|
|||
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleFinalize(DataPtr, size_t, VectorPtr res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
ps_->Finalize();
|
||||
}
|
||||
|
|
|
@ -242,17 +242,19 @@ void PsCacheManager::AllocMemForHashTable() {
|
|||
|
||||
void PsCacheManager::SetLocalIdRank() {
|
||||
auto worker_num = PSContext::instance()->initial_worker_num();
|
||||
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
|
||||
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
|
||||
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
|
||||
emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
|
||||
cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
|
||||
cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
|
||||
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
|
||||
<< ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
|
||||
<< ", cache indices begin: " << cache_indices_bounds_.first
|
||||
<< ", cache indices end: " << cache_indices_bounds_.second
|
||||
<< ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
|
||||
if (worker_num > 0) {
|
||||
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
|
||||
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
|
||||
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
|
||||
emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
|
||||
cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
|
||||
cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
|
||||
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
|
||||
<< ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
|
||||
<< ", cache indices begin: " << cache_indices_bounds_.first
|
||||
<< ", cache indices end: " << cache_indices_bounds_.second
|
||||
<< ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
|
||||
}
|
||||
}
|
||||
|
||||
int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; }
|
||||
|
|
|
@ -24,9 +24,17 @@ void Scheduler::Run() {
|
|||
PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port();
|
||||
PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num();
|
||||
PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num();
|
||||
scheduler_node_.Start();
|
||||
scheduler_node_.Finish();
|
||||
scheduler_node_.Stop();
|
||||
if (!scheduler_node_.Start()) {
|
||||
MS_LOG(WARNING) << "Scheduler start failed.";
|
||||
}
|
||||
|
||||
if (!scheduler_node_.Finish()) {
|
||||
MS_LOG(WARNING) << "Scheduler finis failed.";
|
||||
}
|
||||
|
||||
if (!scheduler_node_.Stop()) {
|
||||
MS_LOG(WARNING) << "Scheduler stop failed.";
|
||||
}
|
||||
exit(1);
|
||||
}
|
||||
} // namespace ps
|
||||
|
|
Loading…
Reference in New Issue