forked from mindspore-Ecosystem/mindspore
commit
5a5b709cc5
|
@ -45,7 +45,10 @@ void HttpMessageHandler::InitHttpMessage() {
|
||||||
const char *query = evhttp_uri_get_query(event_uri_);
|
const char *query = evhttp_uri_get_query(event_uri_);
|
||||||
if (query != nullptr) {
|
if (query != nullptr) {
|
||||||
MS_LOG(WARNING) << "The query is:" << query;
|
MS_LOG(WARNING) << "The query is:" << query;
|
||||||
evhttp_parse_query_str(query, &path_params_);
|
int result = evhttp_parse_query_str(query, &path_params_);
|
||||||
|
if (result < 0) {
|
||||||
|
MS_LOG(ERROR) << "Http parse query:" << query << " failed.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
head_params_ = evhttp_request_get_input_headers(event_request_);
|
head_params_ = evhttp_request_get_input_headers(event_request_);
|
||||||
|
@ -58,14 +61,14 @@ void HttpMessageHandler::ParseUrl(const std::string &url) {
|
||||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string HttpMessageHandler::GetHeadParam(const std::string &key) {
|
std::string HttpMessageHandler::GetHeadParam(const std::string &key) const {
|
||||||
MS_EXCEPTION_IF_NULL(head_params_);
|
MS_EXCEPTION_IF_NULL(head_params_);
|
||||||
const char *val = evhttp_find_header(head_params_, key.c_str());
|
const char *val = evhttp_find_header(head_params_, key.c_str());
|
||||||
MS_EXCEPTION_IF_NULL(val);
|
MS_EXCEPTION_IF_NULL(val);
|
||||||
return std::string(val);
|
return std::string(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string HttpMessageHandler::GetPathParam(const std::string &key) {
|
std::string HttpMessageHandler::GetPathParam(const std::string &key) const {
|
||||||
const char *val = evhttp_find_header(&path_params_, key.c_str());
|
const char *val = evhttp_find_header(&path_params_, key.c_str());
|
||||||
MS_EXCEPTION_IF_NULL(val);
|
MS_EXCEPTION_IF_NULL(val);
|
||||||
return std::string(val);
|
return std::string(val);
|
||||||
|
@ -134,7 +137,7 @@ std::string HttpMessageHandler::GetPostParam(const std::string &key) {
|
||||||
return std::string(val);
|
return std::string(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string HttpMessageHandler::GetRequestUri() {
|
std::string HttpMessageHandler::GetRequestUri() const {
|
||||||
MS_EXCEPTION_IF_NULL(event_request_);
|
MS_EXCEPTION_IF_NULL(event_request_);
|
||||||
const char *uri = evhttp_request_get_uri(event_request_);
|
const char *uri = evhttp_request_get_uri(event_request_);
|
||||||
MS_EXCEPTION_IF_NULL(uri);
|
MS_EXCEPTION_IF_NULL(uri);
|
||||||
|
@ -148,14 +151,14 @@ std::string HttpMessageHandler::GetRequestHost() {
|
||||||
return std::string(host);
|
return std::string(host);
|
||||||
}
|
}
|
||||||
|
|
||||||
const char *HttpMessageHandler::GetHostByUri() {
|
const char *HttpMessageHandler::GetHostByUri() const {
|
||||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||||
const char *host = evhttp_uri_get_host(event_uri_);
|
const char *host = evhttp_uri_get_host(event_uri_);
|
||||||
MS_EXCEPTION_IF_NULL(host);
|
MS_EXCEPTION_IF_NULL(host);
|
||||||
return host;
|
return host;
|
||||||
}
|
}
|
||||||
|
|
||||||
int HttpMessageHandler::GetUriPort() {
|
int HttpMessageHandler::GetUriPort() const {
|
||||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||||
int port = evhttp_uri_get_port(event_uri_);
|
int port = evhttp_uri_get_port(event_uri_);
|
||||||
if (port < 0) {
|
if (port < 0) {
|
||||||
|
@ -164,7 +167,7 @@ int HttpMessageHandler::GetUriPort() {
|
||||||
return port;
|
return port;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string HttpMessageHandler::GetUriPath() {
|
std::string HttpMessageHandler::GetUriPath() const {
|
||||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||||
const char *path = evhttp_uri_get_path(event_uri_);
|
const char *path = evhttp_uri_get_path(event_uri_);
|
||||||
MS_EXCEPTION_IF_NULL(path);
|
MS_EXCEPTION_IF_NULL(path);
|
||||||
|
@ -186,7 +189,7 @@ std::string HttpMessageHandler::GetRequestPath() {
|
||||||
return path_res;
|
return path_res;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string HttpMessageHandler::GetUriQuery() {
|
std::string HttpMessageHandler::GetUriQuery() const {
|
||||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||||
const char *query = evhttp_uri_get_query(event_uri_);
|
const char *query = evhttp_uri_get_query(event_uri_);
|
||||||
MS_EXCEPTION_IF_NULL(query);
|
MS_EXCEPTION_IF_NULL(query);
|
||||||
|
@ -259,8 +262,7 @@ void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, co
|
||||||
MS_EXCEPTION_IF_NULL(resp_buf_);
|
MS_EXCEPTION_IF_NULL(resp_buf_);
|
||||||
AddRespHeaders(headers);
|
AddRespHeaders(headers);
|
||||||
AddRespString(body);
|
AddRespString(body);
|
||||||
MS_EXCEPTION_IF_NULL(resp_buf_);
|
evhttp_send_reply(event_request_, code, nullptr, resp_buf_);
|
||||||
evhttp_send_reply(event_request_, resp_code_, nullptr, resp_buf_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void HttpMessageHandler::ErrorResponse(int code, RequestProcessResult result) {
|
void HttpMessageHandler::ErrorResponse(int code, RequestProcessResult result) {
|
||||||
|
@ -293,9 +295,9 @@ void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
||||||
|
|
||||||
void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; }
|
void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; }
|
||||||
|
|
||||||
uint64_t HttpMessageHandler::content_len() { return content_len_; }
|
uint64_t HttpMessageHandler::content_len() const { return content_len_; }
|
||||||
|
|
||||||
const event_base *HttpMessageHandler::http_base() { return event_base_; }
|
const event_base *HttpMessageHandler::http_base() const { return event_base_; }
|
||||||
|
|
||||||
void HttpMessageHandler::set_http_base(const struct event_base *base) {
|
void HttpMessageHandler::set_http_base(const struct event_base *base) {
|
||||||
MS_EXCEPTION_IF_NULL(base);
|
MS_EXCEPTION_IF_NULL(base);
|
||||||
|
@ -307,13 +309,13 @@ void HttpMessageHandler::set_request(const struct evhttp_request *req) {
|
||||||
event_request_ = const_cast<evhttp_request *>(req);
|
event_request_ = const_cast<evhttp_request *>(req);
|
||||||
}
|
}
|
||||||
|
|
||||||
const struct evhttp_request *HttpMessageHandler::request() { return event_request_; }
|
const struct evhttp_request *HttpMessageHandler::request() const { return event_request_; }
|
||||||
|
|
||||||
void HttpMessageHandler::InitBodySize() { body_->resize(content_len()); }
|
void HttpMessageHandler::InitBodySize() { body_->resize(content_len()); }
|
||||||
|
|
||||||
std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; }
|
std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; }
|
||||||
|
|
||||||
void HttpMessageHandler::set_body(std::shared_ptr<std::vector<char>> body) { body_ = body; }
|
void HttpMessageHandler::set_body(const VectorPtr &body) { body_ = body; }
|
||||||
|
|
||||||
const nlohmann::json &HttpMessageHandler::request_message() const { return request_message_; }
|
const nlohmann::json &HttpMessageHandler::request_message() const { return request_message_; }
|
||||||
|
|
||||||
|
|
|
@ -69,19 +69,19 @@ class HttpMessageHandler {
|
||||||
void InitHttpMessage();
|
void InitHttpMessage();
|
||||||
void ParseUrl(const std::string &url);
|
void ParseUrl(const std::string &url);
|
||||||
|
|
||||||
std::string GetRequestUri();
|
std::string GetRequestUri() const;
|
||||||
std::string GetRequestHost();
|
std::string GetRequestHost();
|
||||||
const char *GetHostByUri();
|
const char *GetHostByUri() const;
|
||||||
std::string GetHeadParam(const std::string &key);
|
std::string GetHeadParam(const std::string &key) const;
|
||||||
std::string GetPathParam(const std::string &key);
|
std::string GetPathParam(const std::string &key) const;
|
||||||
std::string GetPostParam(const std::string &key);
|
std::string GetPostParam(const std::string &key);
|
||||||
uint64_t GetPostMsg(unsigned char **buffer);
|
uint64_t GetPostMsg(unsigned char **buffer);
|
||||||
std::string GetUriPath();
|
std::string GetUriPath() const;
|
||||||
std::string GetRequestPath();
|
std::string GetRequestPath();
|
||||||
std::string GetUriQuery();
|
std::string GetUriQuery() const;
|
||||||
|
|
||||||
// It will return -1 if no port set
|
// It will return -1 if no port set
|
||||||
int GetUriPort();
|
int GetUriPort() const;
|
||||||
|
|
||||||
// Useless to get from a request url, fragment is only for browser to locate sth.
|
// Useless to get from a request url, fragment is only for browser to locate sth.
|
||||||
std::string GetUriFragment();
|
std::string GetUriFragment();
|
||||||
|
@ -104,14 +104,14 @@ class HttpMessageHandler {
|
||||||
RequestProcessResult ParsePostMessageToJson();
|
RequestProcessResult ParsePostMessageToJson();
|
||||||
void ReceiveMessage(const void *buffer, size_t num);
|
void ReceiveMessage(const void *buffer, size_t num);
|
||||||
void set_content_len(const uint64_t &len);
|
void set_content_len(const uint64_t &len);
|
||||||
uint64_t content_len();
|
uint64_t content_len() const;
|
||||||
const event_base *http_base();
|
const event_base *http_base() const;
|
||||||
void set_http_base(const struct event_base *base);
|
void set_http_base(const struct event_base *base);
|
||||||
void set_request(const struct evhttp_request *req);
|
void set_request(const struct evhttp_request *req);
|
||||||
const struct evhttp_request *request();
|
const struct evhttp_request *request() const;
|
||||||
void InitBodySize();
|
void InitBodySize();
|
||||||
VectorPtr body();
|
VectorPtr body();
|
||||||
void set_body(VectorPtr body);
|
void set_body(const VectorPtr &body);
|
||||||
const nlohmann::json &request_message() const;
|
const nlohmann::json &request_message() const;
|
||||||
RequestProcessResult ParseValueFromKey(const std::string &key, int32_t *const value);
|
RequestProcessResult ParseValueFromKey(const std::string &key, int32_t *const value);
|
||||||
|
|
||||||
|
|
|
@ -71,8 +71,8 @@ void TcpClient::set_connected_callback(const OnConnected &connected) { connected
|
||||||
|
|
||||||
bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
|
bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
|
||||||
std::unique_lock<std::mutex> lock(connection_mutex_);
|
std::unique_lock<std::mutex> lock(connection_mutex_);
|
||||||
bool res =
|
bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout),
|
||||||
connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), [&] { return is_connected_.load(); });
|
[this] { return this->is_connected_.load(); });
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,8 +31,8 @@ std::string Node::BoundIp() const { return node_info_.ip_; }
|
||||||
|
|
||||||
bool Node::WaitForStart(const uint32_t &timeout) {
|
bool Node::WaitForStart(const uint32_t &timeout) {
|
||||||
std::unique_lock<std::mutex> lock(wait_start_mutex_);
|
std::unique_lock<std::mutex> lock(wait_start_mutex_);
|
||||||
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] {
|
||||||
bool res = is_ready_.load();
|
bool res = this->is_ready_.load();
|
||||||
if (res) {
|
if (res) {
|
||||||
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
|
MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,7 +50,6 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (rank_it == registered_nodes_info_.end()) {
|
if (rank_it == registered_nodes_info_.end()) {
|
||||||
rank_id = ++next_server_rank_id_;
|
rank_id = ++next_server_rank_id_;
|
||||||
} else {
|
} else {
|
||||||
|
@ -76,19 +75,19 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
||||||
const std::string &ip = register_message.ip();
|
const std::string &ip = register_message.ip();
|
||||||
uint32_t port = register_message.port();
|
uint32_t port = register_message.port();
|
||||||
|
|
||||||
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
|
auto worker_rank_it =
|
||||||
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
|
std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
|
||||||
if (res) {
|
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
|
||||||
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
|
if (res) {
|
||||||
rank_id = item.second.rank_id_;
|
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
|
||||||
}
|
rank_id = item.second.rank_id_;
|
||||||
return res;
|
}
|
||||||
});
|
return res;
|
||||||
|
});
|
||||||
if (rank_it == registered_nodes_info_.end()) {
|
if (worker_rank_it == registered_nodes_info_.end()) {
|
||||||
rank_id = ++next_worker_rank_id_;
|
rank_id = ++next_worker_rank_id_;
|
||||||
} else {
|
} else {
|
||||||
registered_nodes_info_.erase((*rank_it).first);
|
registered_nodes_info_.erase((*worker_rank_it).first);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (rank_id >= meta_data_->worker_num) {
|
if (rank_id >= meta_data_->worker_num) {
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace core {
|
namespace core {
|
||||||
|
|
||||||
bool NodeRecovery::Recover() {
|
bool NodeRecovery::Recover() {
|
||||||
if (recovery_storage_ == nullptr) {
|
if (recovery_storage_ == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace core {
|
namespace core {
|
||||||
|
|
||||||
void RecoveryBase::Initialize(const std::string &config_json) {
|
void RecoveryBase::Initialize(const std::string &config_json) {
|
||||||
nlohmann::json recovery_config;
|
nlohmann::json recovery_config;
|
||||||
try {
|
try {
|
||||||
|
@ -35,7 +34,6 @@ void RecoveryBase::Initialize(const std::string &config_json) {
|
||||||
storage_type_ = StorageType::kFileStorage;
|
storage_type_ = StorageType::kFileStorage;
|
||||||
|
|
||||||
storage_file_path = recovery_config.at(kStoreFilePath);
|
storage_file_path = recovery_config.at(kStoreFilePath);
|
||||||
|
|
||||||
if (storage_file_path == "") {
|
if (storage_file_path == "") {
|
||||||
MS_LOG(EXCEPTION) << "If the scheduler support recovery, and if the persistent storage is a file, the path of "
|
MS_LOG(EXCEPTION) << "If the scheduler support recovery, and if the persistent storage is a file, the path of "
|
||||||
"the file must be configured";
|
"the file must be configured";
|
||||||
|
|
|
@ -21,7 +21,9 @@ namespace ps {
|
||||||
namespace core {
|
namespace core {
|
||||||
SchedulerNode::~SchedulerNode() {
|
SchedulerNode::~SchedulerNode() {
|
||||||
MS_LOG(INFO) << "Stop scheduler node!";
|
MS_LOG(INFO) << "Stop scheduler node!";
|
||||||
Stop();
|
if (!Stop()) {
|
||||||
|
MS_LOG(WARNING) << "Scheduler node stop failed.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SchedulerNode::Start(const uint32_t &timeout) {
|
bool SchedulerNode::Start(const uint32_t &timeout) {
|
||||||
|
@ -66,8 +68,10 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
|
||||||
|
|
||||||
heartbeat_resp_message.set_is_worker_or_server0(node_manager_.IsWorkerOrServer0());
|
heartbeat_resp_message.set_is_worker_or_server0(node_manager_.IsWorkerOrServer0());
|
||||||
|
|
||||||
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
|
if (!server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
|
||||||
heartbeat_resp_message.ByteSizeLong());
|
heartbeat_resp_message.ByteSizeLong())) {
|
||||||
|
MS_LOG(WARNING) << "Send heart beat failed.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SchedulerNode::Initialize() {
|
void SchedulerNode::Initialize() {
|
||||||
|
@ -98,7 +102,7 @@ void SchedulerNode::CreateTcpServer() {
|
||||||
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
|
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
|
||||||
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
|
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
|
||||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||||
const Protos &protos, const void *data, size_t size) {
|
const Protos &, const void *data, size_t size) {
|
||||||
if (handlers_.count(meta->cmd()) == 0) {
|
if (handlers_.count(meta->cmd()) == 0) {
|
||||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||||
}
|
}
|
||||||
|
@ -122,7 +126,9 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
MS_LOG(INFO) << "The scheduler process a register message!";
|
MS_LOG(INFO) << "The scheduler process a register message!";
|
||||||
RegisterMessage register_message;
|
RegisterMessage register_message;
|
||||||
register_message.ParseFromArray(data, size);
|
if (!register_message.ParseFromArray(data, SizeToInt(size))) {
|
||||||
|
MS_LOG(WARNING) << "Parse data failed.";
|
||||||
|
}
|
||||||
|
|
||||||
// assign worker node and server node rank id
|
// assign worker node and server node rank id
|
||||||
uint32_t rank_id = node_manager_.NextRankId(register_message);
|
uint32_t rank_id = node_manager_.NextRankId(register_message);
|
||||||
|
@ -179,7 +185,6 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (iter != scale_in_node_ids_.end()) {
|
if (iter != scale_in_node_ids_.end()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -450,11 +455,11 @@ bool SchedulerNode::Stop() {
|
||||||
bool SchedulerNode::Finish(const uint32_t &) {
|
bool SchedulerNode::Finish(const uint32_t &) {
|
||||||
MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!";
|
MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!";
|
||||||
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
|
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
|
||||||
wait_finish_cond_.wait(lock, [&] {
|
wait_finish_cond_.wait(lock, [this] {
|
||||||
if (is_finish_.load()) {
|
if (this->is_finish_.load()) {
|
||||||
MS_LOG(INFO) << "[Scheduler finish]: 2. Successfully finish scheduler!";
|
MS_LOG(INFO) << "[Scheduler finish]: 2. Successfully finish scheduler!";
|
||||||
}
|
}
|
||||||
return is_finish_.load();
|
return this->is_finish_.load();
|
||||||
});
|
});
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,9 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
|
||||||
SyncEmbeddingTables();
|
SyncEmbeddingTables();
|
||||||
MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
|
MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
|
||||||
server_node_->Finish();
|
server_node_->Finish();
|
||||||
server_node_->Stop();
|
if (!server_node_->Stop()) {
|
||||||
|
MS_LOG(WARNING) << "Parameter server stop failed.";
|
||||||
|
}
|
||||||
MS_LOG(INFO) << "PServer finalized successfully.";
|
MS_LOG(INFO) << "PServer finalized successfully.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -561,7 +563,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size
|
||||||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||||
MS_EXCEPTION_IF_NULL(res);
|
MS_EXCEPTION_IF_NULL(res);
|
||||||
KVMessage input;
|
KVMessage input;
|
||||||
input.ParseFromArray(data.get(), size);
|
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
|
||||||
|
MS_LOG(WARNING) << "Parse data failed.";
|
||||||
|
}
|
||||||
int key_num = input.keys_size();
|
int key_num = input.keys_size();
|
||||||
const float *data_ptr = input.values().data();
|
const float *data_ptr = input.values().data();
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
|
@ -586,9 +590,11 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz
|
||||||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||||
MS_EXCEPTION_IF_NULL(res);
|
MS_EXCEPTION_IF_NULL(res);
|
||||||
KVMessage input;
|
KVMessage input;
|
||||||
input.ParseFromArray(data.get(), size);
|
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
|
||||||
size_t key_num = input.keys_size();
|
MS_LOG(WARNING) << "Parse data failed.";
|
||||||
for (size_t i = 0; i < key_num; i++) {
|
}
|
||||||
|
int key_num = input.keys_size();
|
||||||
|
for (int i = 0; i < key_num; i++) {
|
||||||
Key key = input.keys()[i];
|
Key key = input.keys()[i];
|
||||||
float val = input.values()[i];
|
float val = input.values()[i];
|
||||||
if (init_weight_to_optim_[key]) {
|
if (init_weight_to_optim_[key]) {
|
||||||
|
@ -596,7 +602,7 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz
|
||||||
} else {
|
} else {
|
||||||
init_weight_to_optim_[key] = true;
|
init_weight_to_optim_[key] = true;
|
||||||
}
|
}
|
||||||
ps_->InitWeightKeyToOptims(key, val);
|
ps_->InitWeightKeyToOptims(key, static_cast<int64_t>(val));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -722,7 +728,7 @@ void ParameterServer::ServerHandler::HandleUpdateEmbeddings(DataPtr data, size_t
|
||||||
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
|
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, VectorPtr res) {
|
void ParameterServer::ServerHandler::HandleFinalize(DataPtr, size_t, VectorPtr res) {
|
||||||
MS_EXCEPTION_IF_NULL(res);
|
MS_EXCEPTION_IF_NULL(res);
|
||||||
ps_->Finalize();
|
ps_->Finalize();
|
||||||
}
|
}
|
||||||
|
|
|
@ -242,17 +242,19 @@ void PsCacheManager::AllocMemForHashTable() {
|
||||||
|
|
||||||
void PsCacheManager::SetLocalIdRank() {
|
void PsCacheManager::SetLocalIdRank() {
|
||||||
auto worker_num = PSContext::instance()->initial_worker_num();
|
auto worker_num = PSContext::instance()->initial_worker_num();
|
||||||
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
|
if (worker_num > 0) {
|
||||||
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
|
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
|
||||||
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
|
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
|
||||||
emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
|
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
|
||||||
cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
|
emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
|
||||||
cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
|
cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
|
||||||
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
|
cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
|
||||||
<< ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
|
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
|
||||||
<< ", cache indices begin: " << cache_indices_bounds_.first
|
<< ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
|
||||||
<< ", cache indices end: " << cache_indices_bounds_.second
|
<< ", cache indices begin: " << cache_indices_bounds_.first
|
||||||
<< ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
|
<< ", cache indices end: " << cache_indices_bounds_.second
|
||||||
|
<< ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; }
|
int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; }
|
||||||
|
|
|
@ -24,9 +24,17 @@ void Scheduler::Run() {
|
||||||
PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port();
|
PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port();
|
||||||
PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num();
|
PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num();
|
||||||
PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num();
|
PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num();
|
||||||
scheduler_node_.Start();
|
if (!scheduler_node_.Start()) {
|
||||||
scheduler_node_.Finish();
|
MS_LOG(WARNING) << "Scheduler start failed.";
|
||||||
scheduler_node_.Stop();
|
}
|
||||||
|
|
||||||
|
if (!scheduler_node_.Finish()) {
|
||||||
|
MS_LOG(WARNING) << "Scheduler finis failed.";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!scheduler_node_.Stop()) {
|
||||||
|
MS_LOG(WARNING) << "Scheduler stop failed.";
|
||||||
|
}
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
|
|
Loading…
Reference in New Issue