forked from mindspore-Ecosystem/mindspore
fixed pclint
This commit is contained in:
parent
45c76beda7
commit
18e8a6e7c9
|
@ -187,6 +187,13 @@ enum class CustomEvent { kIterationRunning = 0, kIterationCompleted };
|
|||
#define ERROR_STATUS(result, code, message) \
|
||||
MS_LOG(ERROR) << message; \
|
||||
result = RequestProcessResult(code, message)
|
||||
|
||||
#define CHECK_RETURN_TYPE(_condition) \
|
||||
do { \
|
||||
if (!(_condition)) { \
|
||||
MS_LOG(ERROR) << "Parse protobuf message failed."; \
|
||||
} \
|
||||
} while (false)
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CONSTANTS_H_
|
||||
|
|
|
@ -42,11 +42,11 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
|
|||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
RegisterRespMessage register_resp_message;
|
||||
register_resp_message.ParseFromArray(data, size);
|
||||
CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
|
||||
if (register_resp_message.node_id() != node_info_.node_id_) {
|
||||
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
|
||||
<< " is not match the current node id:" << node_info_.node_id_;
|
||||
|
@ -59,7 +59,7 @@ void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const
|
|||
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << " registered scheduler success!";
|
||||
}
|
||||
|
||||
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
|
||||
bool AbstractNode::Broadcast(const NodeRole &node_role, const DataPtr &message, size_t size, int command,
|
||||
const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
if (node_role != NodeRole::SERVER) {
|
||||
|
@ -165,7 +165,7 @@ void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const Even
|
|||
custom_event_to_callback_.try_emplace(event, event_cb);
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
|
||||
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
|
||||
int command, const uint32_t &timeout) {
|
||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
||||
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
|
||||
|
@ -223,7 +223,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
|
||||
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
|
||||
int command, VectorPtr *output, const uint32_t &timeout) {
|
||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
||||
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
|
||||
|
@ -309,7 +309,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
|
||||
uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
||||
|
@ -326,8 +326,8 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
|
|||
return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
|
||||
}
|
||||
|
||||
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
|
||||
const uint32_t &rank_id, VectorPtr *output) {
|
||||
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
|
||||
VectorPtr *output) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
|
||||
|
@ -359,7 +359,7 @@ std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum No
|
|||
return std::make_pair(rank_id, rank_request_id);
|
||||
}
|
||||
|
||||
bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) {
|
||||
bool AbstractNode::CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
|
||||
bool res =
|
||||
receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
|
||||
|
@ -473,11 +473,11 @@ bool AbstractNode::CheckSchedulerTimeout() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
heartbeat_resp_message.ParseFromArray(data, size);
|
||||
CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));
|
||||
|
||||
current_cluster_state_ = heartbeat_resp_message.cluster_state();
|
||||
MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
|
||||
|
@ -529,11 +529,11 @@ void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
|
|||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
FetchServersRespMessage fetch_servers_resp_message;
|
||||
fetch_servers_resp_message.ParseFromArray(data, size);
|
||||
CHECK_RETURN_TYPE(fetch_servers_resp_message.ParseFromArray(data, SizeToInt(size)));
|
||||
|
||||
nodes_address_.clear();
|
||||
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
|
||||
|
@ -542,8 +542,9 @@ void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, co
|
|||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
@ -589,7 +590,7 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
|
|||
connected_nodes_.clear();
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -599,8 +600,9 @@ void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::share
|
|||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
@ -609,8 +611,9 @@ void AbstractNode::ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std:
|
|||
current_cluster_state_ = ClusterState::CLUSTER_READY;
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
@ -619,7 +622,7 @@ void AbstractNode::ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::
|
|||
current_cluster_state_ = ClusterState::CLUSTER_READY;
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessEvent(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -631,7 +634,7 @@ void AbstractNode::ProcessEvent(std::shared_ptr<TcpConnection> conn, std::shared
|
|||
OnCustomEventCallback(event);
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -649,7 +652,7 @@ void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::sha
|
|||
is_ready_ = false;
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -704,7 +707,7 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
|
|||
bool AbstractNode::InitClientToScheduler() {
|
||||
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_);
|
||||
client_to_scheduler_->SetMessageCallback(
|
||||
[&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
[&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
|
||||
try {
|
||||
if (handlers_.count(meta->cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
|
@ -750,7 +753,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
|
|||
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
|
||||
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
|
||||
auto client = std::make_shared<TcpClient>(ip, port);
|
||||
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
switch (meta->cmd()) {
|
||||
case NodeCommand::SEND_DATA:
|
||||
|
@ -771,7 +774,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
|
|||
}
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
void AbstractNode::ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
@ -827,7 +830,7 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag
|
|||
message_callbacks_[request_id] = callback;
|
||||
}
|
||||
|
||||
void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
void AbstractNode::RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
|
|
@ -53,9 +53,11 @@ class AbstractNode : public Node {
|
|||
scheduler_port_(0) {}
|
||||
~AbstractNode() override { is_finish_ = true; }
|
||||
|
||||
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
typedef void (AbstractNode::*ServerHandler)(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
typedef void (AbstractNode::*ResponseHandler)(const std::shared_ptr<MessageMeta> &meta, const void *data,
|
||||
size_t size);
|
||||
typedef void (AbstractNode::*ServerHandler)(const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
|
||||
using DataPtr = std::shared_ptr<unsigned char[]>;
|
||||
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
|
||||
|
@ -95,7 +97,7 @@ class AbstractNode : public Node {
|
|||
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
|
||||
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
|
||||
VectorPtr *output);
|
||||
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
// Initialize the scaler for server to process before/after scaling operations.
|
||||
bool InitFollowerScaler();
|
||||
|
@ -127,31 +129,31 @@ class AbstractNode : public Node {
|
|||
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
|
||||
void FetchServers(const std::shared_ptr<TcpClient> &client);
|
||||
|
||||
void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
void ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
||||
void ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
void ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
void ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
void ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
// The worker/server processes the scale_out_done message from the scheduler.
|
||||
void ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
// The worker/server processes the scale_in_done message from the scheduler.
|
||||
void ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
// The worker/server processes the SEND_EVENT message from the scheduler.
|
||||
void ProcessEvent(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
|
||||
void UpdateSchedulerTime();
|
||||
|
@ -161,10 +163,12 @@ class AbstractNode : public Node {
|
|||
bool InitClientToScheduler();
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id);
|
||||
|
||||
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
void ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
|
||||
size_t size);
|
||||
void RunMessageCallback(const uint64_t &request_id);
|
||||
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
|
||||
void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
void RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
|
||||
size_t size);
|
||||
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
|
||||
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
|
||||
void InitCommandHandler();
|
||||
|
|
|
@ -34,7 +34,7 @@ void HttpClient::Init() {
|
|||
}
|
||||
|
||||
ResponseCode HttpClient::Post(const std::string &url, const void *body, size_t len,
|
||||
std::shared_ptr<std::vector<char>> output,
|
||||
const std::shared_ptr<std::vector<char>> &output,
|
||||
const std::map<std::string, std::string> &headers) {
|
||||
MS_EXCEPTION_IF_NULL(body);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
|
@ -65,7 +65,7 @@ ResponseCode HttpClient::Post(const std::string &url, const void *body, size_t l
|
|||
return CreateRequest(handler, connection, request, HttpMethod::HM_POST);
|
||||
}
|
||||
|
||||
ResponseCode HttpClient::Get(const std::string &url, std::shared_ptr<std::vector<char>> output,
|
||||
ResponseCode HttpClient::Get(const std::string &url, const std::shared_ptr<std::vector<char>> &output,
|
||||
const std::map<std::string, std::string> &headers) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
auto handler = std::make_shared<HttpMessageHandler>();
|
||||
|
@ -128,7 +128,7 @@ void HttpClient::ReadChunkDataCallback(struct evhttp_request *request, void *arg
|
|||
MS_EXCEPTION_IF_NULL(evbuf);
|
||||
int n = 0;
|
||||
while ((n = evbuffer_remove(evbuf, &buf, sizeof(buf))) > 0) {
|
||||
handler->ReceiveMessage(buf, n);
|
||||
handler->ReceiveMessage(buf, IntToSize(n));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,7 +151,7 @@ void HttpClient::ConnectionCloseCallback(struct evhttp_connection *connection, v
|
|||
}
|
||||
|
||||
void HttpClient::AddHeaders(const std::map<std::string, std::string> &headers, const struct evhttp_request *request,
|
||||
std::shared_ptr<HttpMessageHandler> handler) {
|
||||
const std::shared_ptr<HttpMessageHandler> &handler) {
|
||||
MS_EXCEPTION_IF_NULL(request);
|
||||
if (evhttp_add_header(evhttp_request_get_output_headers(const_cast<evhttp_request *>(request)), "Host",
|
||||
handler->GetHostByUri()) != 0) {
|
||||
|
@ -165,7 +165,7 @@ void HttpClient::AddHeaders(const std::map<std::string, std::string> &headers, c
|
|||
}
|
||||
}
|
||||
|
||||
void HttpClient::InitRequest(std::shared_ptr<HttpMessageHandler> handler, const std::string &url,
|
||||
void HttpClient::InitRequest(const std::shared_ptr<HttpMessageHandler> &handler, const std::string &url,
|
||||
const struct evhttp_request *request) {
|
||||
MS_EXCEPTION_IF_NULL(request);
|
||||
MS_EXCEPTION_IF_NULL(handler);
|
||||
|
@ -179,7 +179,7 @@ void HttpClient::InitRequest(std::shared_ptr<HttpMessageHandler> handler, const
|
|||
<< ", The port is:" << handler->GetUriPort() << ", The request_url is:" << handler->GetRequestPath();
|
||||
}
|
||||
|
||||
ResponseCode HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler,
|
||||
ResponseCode HttpClient::CreateRequest(const std::shared_ptr<HttpMessageHandler> &handler,
|
||||
struct evhttp_connection *connection, struct evhttp_request *request,
|
||||
HttpMethod method) {
|
||||
MS_EXCEPTION_IF_NULL(handler);
|
||||
|
|
|
@ -41,6 +41,7 @@
|
|||
|
||||
#include "ps/core/communicator/http_message_handler.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -62,9 +63,10 @@ class HttpClient {
|
|||
|
||||
virtual ~HttpClient();
|
||||
|
||||
ResponseCode Post(const std::string &url, const void *body, size_t len, std::shared_ptr<std::vector<char>> output,
|
||||
ResponseCode Post(const std::string &url, const void *body, size_t len,
|
||||
const std::shared_ptr<std::vector<char>> &output,
|
||||
const std::map<std::string, std::string> &headers = {});
|
||||
ResponseCode Get(const std::string &url, std::shared_ptr<std::vector<char>> output,
|
||||
ResponseCode Get(const std::string &url, const std::shared_ptr<std::vector<char>> &output,
|
||||
const std::map<std::string, std::string> &headers = {});
|
||||
|
||||
void set_connection_timeout(const int &timeout);
|
||||
|
@ -77,10 +79,10 @@ class HttpClient {
|
|||
static void ConnectionCloseCallback(struct evhttp_connection *connection, void *arg);
|
||||
|
||||
void AddHeaders(const std::map<std::string, std::string> &headers, const struct evhttp_request *request,
|
||||
std::shared_ptr<HttpMessageHandler> handler);
|
||||
void InitRequest(std::shared_ptr<HttpMessageHandler> handler, const std::string &url,
|
||||
const std::shared_ptr<HttpMessageHandler> &handler);
|
||||
void InitRequest(const std::shared_ptr<HttpMessageHandler> &handler, const std::string &url,
|
||||
const struct evhttp_request *request);
|
||||
ResponseCode CreateRequest(std::shared_ptr<HttpMessageHandler> handler, struct evhttp_connection *connection,
|
||||
ResponseCode CreateRequest(const std::shared_ptr<HttpMessageHandler> &handler, struct evhttp_connection *connection,
|
||||
struct evhttp_request *request, HttpMethod method);
|
||||
|
||||
bool Start();
|
||||
|
|
|
@ -196,7 +196,7 @@ std::string HttpMessageHandler::GetUriQuery() const {
|
|||
return std::string(query);
|
||||
}
|
||||
|
||||
std::string HttpMessageHandler::GetUriFragment() {
|
||||
std::string HttpMessageHandler::GetUriFragment() const {
|
||||
MS_EXCEPTION_IF_NULL(event_uri_);
|
||||
const char *fragment = evhttp_uri_get_fragment(event_uri_);
|
||||
MS_EXCEPTION_IF_NULL(fragment);
|
||||
|
|
|
@ -84,7 +84,7 @@ class HttpMessageHandler {
|
|||
int GetUriPort() const;
|
||||
|
||||
// Rarely useful for a request URL: the fragment is only used by the browser to locate something within the page.
|
||||
std::string GetUriFragment();
|
||||
std::string GetUriFragment() const;
|
||||
|
||||
void AddRespHeadParam(const std::string &key, const std::string &val);
|
||||
void AddRespHeaders(const HttpHeaders &headers);
|
||||
|
|
|
@ -45,7 +45,7 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
|
|||
is_stop_(true),
|
||||
is_connected_(false) {
|
||||
message_handler_.SetCallback(
|
||||
[this](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
[this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
|
||||
if (message_callback_) {
|
||||
message_callback_(meta, protos, data, size);
|
||||
}
|
||||
|
@ -308,7 +308,8 @@ bool TcpClient::SendMessage(const CommMessage &message) const {
|
|||
return res;
|
||||
}
|
||||
|
||||
bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
bool TcpClient::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
@ -357,7 +358,7 @@ void TcpClient::StartTimer(const uint32_t &time) {
|
|||
|
||||
void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
|
||||
|
||||
const event_base &TcpClient::eventbase() { return *event_base_; }
|
||||
const event_base &TcpClient::eventbase() const { return *event_base_; }
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,7 +48,8 @@ class TcpClient {
|
|||
using OnDisconnected = std::function<void()>;
|
||||
using OnRead = std::function<void(const void *, size_t)>;
|
||||
using OnTimeout = std::function<void()>;
|
||||
using OnMessage = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
|
||||
using OnMessage =
|
||||
std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>;
|
||||
using OnTimer = std::function<void()>;
|
||||
|
||||
explicit TcpClient(const std::string &address, std::uint16_t port);
|
||||
|
@ -66,10 +67,10 @@ class TcpClient {
|
|||
void StartWithNoBlock();
|
||||
void SetMessageCallback(const OnMessage &cb);
|
||||
bool SendMessage(const CommMessage &message) const;
|
||||
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
bool SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size);
|
||||
void StartTimer(const uint32_t &time);
|
||||
void set_timer_callback(const OnTimer &timer);
|
||||
const event_base &eventbase();
|
||||
const event_base &eventbase() const;
|
||||
|
||||
protected:
|
||||
static void SetTcpNoDelay(const evutil_socket_t &fd);
|
||||
|
|
|
@ -75,7 +75,8 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
|
|||
if (remaining_length_ == 0) {
|
||||
if (message_callback_) {
|
||||
std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>();
|
||||
pb_message->ParseFromArray(message_buffer_.get(), message_header_.message_meta_length_);
|
||||
CHECK_RETURN_TYPE(
|
||||
pb_message->ParseFromArray(message_buffer_.get(), UintToInt(message_header_.message_meta_length_)));
|
||||
message_callback_(pb_message, message_header_.message_proto_,
|
||||
message_buffer_.get() + message_header_.message_meta_length_,
|
||||
message_header_.message_length_ - message_header_.message_meta_length_);
|
||||
|
|
|
@ -27,11 +27,14 @@
|
|||
#include "ps/core/communicator/message.h"
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "ps/constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
using messageReceive = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
|
||||
using messageReceive =
|
||||
std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>;
|
||||
constexpr int kHeaderLen = 16;
|
||||
|
||||
class TcpMessageHandler {
|
||||
|
@ -48,7 +51,7 @@ class TcpMessageHandler {
|
|||
bool is_parsed_;
|
||||
std::unique_ptr<unsigned char[]> message_buffer_;
|
||||
size_t remaining_length_;
|
||||
char header_[16]{0};
|
||||
unsigned char header_[16]{0};
|
||||
int header_index_;
|
||||
size_t last_copy_len_;
|
||||
MessageHeader message_header_;
|
||||
|
|
|
@ -48,7 +48,7 @@ const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
|
|||
|
||||
void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
|
||||
|
||||
bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
|
||||
bool TcpConnection::SendMessage(const std::shared_ptr<CommMessage> &message) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
bufferevent_lock(buffer_event_);
|
||||
|
@ -66,7 +66,7 @@ bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
|
|||
return res;
|
||||
}
|
||||
|
||||
bool TcpConnection::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
|
||||
size_t size) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_event_);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -325,12 +325,13 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
|
|||
MS_EXCEPTION_IF_NULL(conn);
|
||||
SetTcpNoDelay(fd);
|
||||
server->AddConnection(fd, conn);
|
||||
conn->InitConnection([=](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
OnServerReceiveMessage on_server_receive = server->GetServerReceive();
|
||||
if (on_server_receive) {
|
||||
on_server_receive(conn, meta, protos, data, size);
|
||||
}
|
||||
});
|
||||
conn->InitConnection(
|
||||
[=](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
|
||||
OnServerReceiveMessage on_server_receive = server->GetServerReceive();
|
||||
if (on_server_receive) {
|
||||
on_server_receive(conn, meta, protos, data, size);
|
||||
}
|
||||
});
|
||||
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
|
||||
reinterpret_cast<void *>(conn.get()));
|
||||
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
|
||||
|
@ -440,13 +441,13 @@ void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
|
|||
}
|
||||
}
|
||||
|
||||
bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
|
||||
bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<CommMessage> &message) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
return conn->SendMessage(message);
|
||||
}
|
||||
|
||||
bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -454,7 +455,7 @@ bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr
|
|||
return conn->SendMessage(meta, protos, data, size);
|
||||
}
|
||||
|
||||
void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
|
||||
void TcpServer::SendMessage(const std::shared_ptr<CommMessage> &message) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
std::lock_guard<std::mutex> lock(connection_mutex_);
|
||||
|
||||
|
@ -467,7 +468,7 @@ uint16_t TcpServer::BoundPort() const { return server_port_; }
|
|||
|
||||
std::string TcpServer::BoundIp() const { return server_address_; }
|
||||
|
||||
int TcpServer::ConnectionNum() const { return connections_.size(); }
|
||||
int TcpServer::ConnectionNum() const { return SizeToInt(connections_.size()); }
|
||||
|
||||
const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }
|
||||
|
||||
|
|
|
@ -58,8 +58,8 @@ class TcpConnection {
|
|||
|
||||
virtual void InitConnection(const messageReceive &callback);
|
||||
virtual void SendMessage(const void *buffer, size_t num) const;
|
||||
bool SendMessage(std::shared_ptr<CommMessage> message) const;
|
||||
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) const;
|
||||
bool SendMessage(const std::shared_ptr<CommMessage> &message) const;
|
||||
bool SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) const;
|
||||
virtual void OnReadHandler(const void *buffer, size_t numBytes);
|
||||
const TcpServer *GetServer() const;
|
||||
const evutil_socket_t &GetFd() const;
|
||||
|
@ -74,8 +74,8 @@ class TcpConnection {
|
|||
};
|
||||
|
||||
using OnServerReceiveMessage =
|
||||
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size)>;
|
||||
std::function<void(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size)>;
|
||||
|
||||
class TcpServer {
|
||||
public:
|
||||
|
@ -105,10 +105,10 @@ class TcpServer {
|
|||
std::shared_ptr<TcpConnection> GetConnectionByFd(const evutil_socket_t &fd);
|
||||
OnServerReceiveMessage GetServerReceive() const;
|
||||
void SetMessageCallback(const OnServerReceiveMessage &cb);
|
||||
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
|
||||
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t sizee);
|
||||
void SendMessage(std::shared_ptr<CommMessage> message);
|
||||
bool SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<CommMessage> &message);
|
||||
bool SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t sizee);
|
||||
void SendMessage(const std::shared_ptr<CommMessage> &message);
|
||||
uint16_t BoundPort() const;
|
||||
std::string BoundIp() const;
|
||||
int ConnectionNum() const;
|
||||
|
|
|
@ -51,7 +51,7 @@ bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommM
|
|||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
uint64_t Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
uint64_t Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -99,7 +99,7 @@ bool Node::CheckMessageTrack(const uint64_t &request_id) {
|
|||
return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
|
||||
}
|
||||
|
||||
void Node::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
|
||||
void Node::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
uint64_t request_id = meta->request_id();
|
||||
|
||||
|
|
|
@ -79,12 +79,12 @@ class Node {
|
|||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
// Send data asynchronously
|
||||
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
uint64_t AddMessageTrack(const uint32_t &expected_response);
|
||||
bool CheckMessageTrack(const uint64_t &request_id);
|
||||
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
|
||||
void NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta);
|
||||
|
||||
NodeInfo node_info_;
|
||||
std::atomic<bool> is_ready_;
|
||||
|
|
|
@ -201,17 +201,19 @@ void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_do
|
|||
|
||||
void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
|
||||
|
||||
bool NodeManager::IsAllNodesRegistered() {
|
||||
bool NodeManager::IsAllNodesRegistered() const {
|
||||
int32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
|
||||
[](auto item) { return item.second.is_alive == true; });
|
||||
return num == total_node_num_;
|
||||
}
|
||||
|
||||
bool NodeManager::IsAllNodesFinished() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
|
||||
bool NodeManager::IsAllNodesFinished() const { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
|
||||
|
||||
bool NodeManager::IsAllNodesScaleOutDone() { return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_; }
|
||||
bool NodeManager::IsAllNodesScaleOutDone() const {
|
||||
return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_;
|
||||
}
|
||||
|
||||
bool NodeManager::IsAllNodesScaleInDone() { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }
|
||||
bool NodeManager::IsAllNodesScaleInDone() const { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }
|
||||
|
||||
std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() { return nodes_info_; }
|
||||
|
||||
|
|
|
@ -75,17 +75,17 @@ class NodeManager {
|
|||
|
||||
// When workers and servers registered to scheduler, the scheduler will collect the number of registered
|
||||
// nodes and Determine whether the registered number of worker and server is equal to total_node_num_.
|
||||
bool IsAllNodesRegistered();
|
||||
bool IsAllNodesRegistered() const;
|
||||
// When workers and servers send a finish message to the scheduler, the scheduler will collect the number of
|
||||
// finish nodes and Determine whether the finished nodes are equal to total_node_num_.
|
||||
bool IsAllNodesFinished();
|
||||
bool IsAllNodesFinished() const;
|
||||
|
||||
// When workers and servers send a scale_out_done message to the scheduler, the scheduler will collect the number of
|
||||
// nodes and Determine whether the nodes are equal to total_node_num_.
|
||||
bool IsAllNodesScaleOutDone();
|
||||
bool IsAllNodesScaleOutDone() const;
|
||||
// When workers and servers send a scale_in_done message to the scheduler, the scheduler will collect the number of
|
||||
// nodes and Determine whether the nodes are equal to total_node_num_.
|
||||
bool IsAllNodesScaleInDone();
|
||||
bool IsAllNodesScaleInDone() const;
|
||||
|
||||
std::unordered_map<std::string, NodeInfo> &nodes_info();
|
||||
std::unordered_map<std::string, NodeInfo> &registered_nodes_info();
|
||||
|
|
|
@ -43,9 +43,9 @@ message ParamInitInfoMessage {
|
|||
}
|
||||
|
||||
message KVMessage {
|
||||
repeated int32 keys = 2;
|
||||
repeated uint64 keys = 2;
|
||||
repeated float values = 3;
|
||||
repeated int32 len = 4;
|
||||
repeated uint64 len = 4;
|
||||
}
|
||||
|
||||
message EmbeddingTableMeta {
|
||||
|
|
|
@ -46,14 +46,15 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.ParseFromArray(data, size);
|
||||
CHECK_RETURN_TYPE(heartbeat_message.ParseFromArray(data, SizeToInt(size)));
|
||||
|
||||
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
|
||||
|
||||
|
@ -101,7 +102,7 @@ void SchedulerNode::CreateTcpServer() {
|
|||
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
|
||||
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
|
||||
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &, const void *data, size_t size) {
|
||||
if (handlers_.count(meta->cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
|
@ -112,23 +113,22 @@ void SchedulerNode::CreateTcpServer() {
|
|||
|
||||
server_->Init();
|
||||
|
||||
scheduler_thread_ = std::make_unique<std::thread>([&]() {
|
||||
scheduler_thread_ = std::make_unique<std::thread>([this]() {
|
||||
MS_LOG(INFO) << "The scheduler node start a tcp server!";
|
||||
server_->Start();
|
||||
this->server_->Start();
|
||||
});
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_LOG(INFO) << "The scheduler process a register message!";
|
||||
RegisterMessage register_message;
|
||||
if (!register_message.ParseFromArray(data, SizeToInt(size))) {
|
||||
MS_LOG(WARNING) << "Parse data failed.";
|
||||
}
|
||||
CHECK_RETURN_TYPE(register_message.ParseFromArray(data, SizeToInt(size)));
|
||||
|
||||
// assign worker node and server node rank id
|
||||
uint32_t rank_id = node_manager_.NextRankId(register_message, meta);
|
||||
|
@ -167,8 +167,8 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
|
|||
}
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void SchedulerNode::ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -202,8 +202,9 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
|
|||
}
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void SchedulerNode::ProcessFetchMetadata(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -217,8 +218,9 @@ void SchedulerNode::ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std:
|
|||
fetch_servers_message.ByteSizeLong());
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void SchedulerNode::ProcessScaleOutDone(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -242,8 +244,9 @@ void SchedulerNode::ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::
|
|||
}
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void SchedulerNode::ProcessScaleInDone(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -267,8 +270,9 @@ void SchedulerNode::ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::s
|
|||
}
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessSendEvent(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void SchedulerNode::ProcessSendEvent(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
|
|
@ -58,8 +58,10 @@ class SchedulerNode : public Node {
|
|||
scheduler_recovery_(nullptr) {}
|
||||
~SchedulerNode() override;
|
||||
|
||||
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
typedef void (SchedulerNode::*ResponseHandler)(const std::shared_ptr<TcpServer> &server,
|
||||
const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data,
|
||||
size_t size);
|
||||
|
||||
bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) override;
|
||||
bool Stop() override;
|
||||
|
@ -73,24 +75,24 @@ class SchedulerNode : public Node {
|
|||
void StartUpdateClusterStateTimer();
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info);
|
||||
|
||||
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
void ProcessRegister(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
void ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
void ProcessFetchMetadata(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
||||
// Process scale_out_done messages from workers/servers
|
||||
void ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessScaleOutDone(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
// Process scale_in_done messages from workers/servers
|
||||
void ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessScaleInDone(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
// Process scale_in_done messages from workers/servers
|
||||
void ProcessSendEvent(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessSendEvent(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
|
||||
|
||||
// After scheduler collects all registered message, it actively sends finish to the node connected by the client.
|
||||
void SendMetadata(const std::shared_ptr<TcpClient> &client, uint32_t rank_id);
|
||||
|
|
|
@ -42,8 +42,8 @@ bool ServerNode::Start(const uint32_t &timeout) {
|
|||
|
||||
void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
|
||||
|
||||
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data,
|
||||
size_t size) {
|
||||
void ServerNode::Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
@ -59,7 +59,7 @@ void ServerNode::CreateTcpServer() {
|
|||
std::string server_ip;
|
||||
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
|
||||
server_ = std::make_shared<TcpServer>(server_ip, 0);
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
if (server_handler_.count(meta->cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
|
@ -107,7 +107,7 @@ void ServerNode::Initialize() {
|
|||
MS_LOG(INFO) << "[Server start]: 3. Server node crete tcp client to scheduler successful!";
|
||||
}
|
||||
|
||||
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
void ServerNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -128,8 +128,8 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share
|
|||
request_handler_(conn, meta, res, size);
|
||||
}
|
||||
|
||||
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const void *data, size_t size) {
|
||||
void ServerNode::ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn,
|
||||
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
server_->SendMessage(conn, meta, Protos::RAW, data, size);
|
||||
|
|
|
@ -51,11 +51,13 @@ class ServerNode : public AbstractNode {
|
|||
bool Stop() override;
|
||||
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
|
||||
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
DataPtr data, size_t size)>;
|
||||
using RequestHandler =
|
||||
std::function<void(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const DataPtr &data, size_t size)>;
|
||||
|
||||
void set_handler(const RequestHandler &handler);
|
||||
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, const void *data,
|
||||
size_t size);
|
||||
|
||||
std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port,
|
||||
const std::shared_ptr<TaskExecutor> &task_executor);
|
||||
|
@ -66,9 +68,9 @@ class ServerNode : public AbstractNode {
|
|||
private:
|
||||
void CreateTcpServer();
|
||||
void Initialize();
|
||||
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
void ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
void ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||
const void *data, size_t size);
|
||||
|
||||
RequestHandler request_handler_;
|
||||
|
|
|
@ -237,10 +237,11 @@ void ParameterServer::UpdateWeights() {
|
|||
shapes.push_back(indices_shape);
|
||||
|
||||
if (original_optim_inputs_shape_.count(key) != 0) {
|
||||
std::transform(
|
||||
(*(original_optim_inputs_shape_[key])).begin(), (*(original_optim_inputs_shape_[key])).end(),
|
||||
std::back_inserter(shapes),
|
||||
[](std::shared_ptr<std::vector<size_t>> input_shapes) -> std::vector<size_t> { return *input_shapes; });
|
||||
std::transform((*(original_optim_inputs_shape_[key])).begin(), (*(original_optim_inputs_shape_[key])).end(),
|
||||
std::back_inserter(shapes),
|
||||
[](const std::shared_ptr<std::vector<size_t>> &input_shapes) -> std::vector<size_t> {
|
||||
return *input_shapes;
|
||||
});
|
||||
}
|
||||
optimizer->ReInit(shapes);
|
||||
optim_info->ComputeMean(shapes, worker_num_, pserver_num_, server_node_->rank_id());
|
||||
|
@ -377,7 +378,7 @@ void ParameterServer::UpdateEmbeddings(const Key &key, const LookupIds &lookup_i
|
|||
table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
|
||||
}
|
||||
|
||||
inline bool ParameterServer::ReadyForUpdateWeights() {
|
||||
inline bool ParameterServer::ReadyForUpdateWeights() const {
|
||||
return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
|
||||
}
|
||||
|
||||
|
@ -387,9 +388,7 @@ inline bool ParameterServer::ReadyForPush(const Key &key) {
|
|||
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
|
||||
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
|
||||
}
|
||||
MS_LOG(INFO) << "The grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size()
|
||||
<< " the token:" << (tokens_[key] <= 0);
|
||||
return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
|
||||
return grad_accum_count_ < weights_.size() && tokens_[key] == 0;
|
||||
}
|
||||
|
||||
inline bool ParameterServer::ReadyForPull(const Key &key) {
|
||||
|
@ -504,8 +503,9 @@ void ParameterServer::ServerHandler::Init() {
|
|||
commands_[kPullCmd] = "kPullCmd";
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnection> conn,
|
||||
std::shared_ptr<core::MessageMeta> meta, DataPtr data, size_t size) {
|
||||
void ParameterServer::ServerHandler::operator()(const std::shared_ptr<core::TcpConnection> &conn,
|
||||
const std::shared_ptr<core::MessageMeta> &meta, const DataPtr &data,
|
||||
size_t size) {
|
||||
auto output = std::make_shared<std::vector<unsigned char>>();
|
||||
if (commands_.count(meta->user_cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The command:" << meta->user_cmd() << " is not supported!";
|
||||
|
@ -530,10 +530,10 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect
|
|||
.count();
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandlePushReq(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
Keys keys = {input.keys().begin(), input.keys().end()};
|
||||
Values values = {input.values().begin(), input.values().end()};
|
||||
Lengths lens = {input.len().begin(), input.len().end()};
|
||||
|
@ -541,10 +541,10 @@ void ParameterServer::ServerHandler::HandlePushReq(DataPtr data, size_t size, Ve
|
|||
ps_->AccumGrad(keys, values, lens);
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
KVMessage res_data;
|
||||
*res_data.mutable_keys() = input.keys();
|
||||
Key key = input.keys()[0];
|
||||
|
@ -559,13 +559,11 @@ void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, Ve
|
|||
}
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
|
||||
MS_LOG(WARNING) << "Parse data failed.";
|
||||
}
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
int key_num = input.keys_size();
|
||||
const float *data_ptr = input.values().data();
|
||||
size_t pos = 0;
|
||||
|
@ -586,13 +584,11 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size
|
|||
}
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
|
||||
MS_LOG(WARNING) << "Parse data failed.";
|
||||
}
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
int key_num = input.keys_size();
|
||||
for (int i = 0; i < key_num; i++) {
|
||||
Key key = input.keys()[i];
|
||||
|
@ -606,11 +602,11 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz
|
|||
}
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
const Key &key = input.keys()[0];
|
||||
if (init_optim_info_[key]) {
|
||||
return;
|
||||
|
@ -623,10 +619,10 @@ void ParameterServer::ServerHandler::HandleInitInputsShape(DataPtr data, size_t
|
|||
ps_->InitOptimInputsShape(keys, values, lens);
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &) {
|
||||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||
EmbeddingTableMeta embedding_table_meta;
|
||||
embedding_table_meta.ParseFromArray(data.get(), size);
|
||||
CHECK_RETURN_TYPE(embedding_table_meta.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
const Key &key = embedding_table_meta.key();
|
||||
MS_LOG(INFO) << "Initializing embedding table for key:" << key;
|
||||
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
|
||||
|
@ -659,10 +655,10 @@ void ParameterServer::ServerHandler::HandleInitEmbeddings(DataPtr data, size_t s
|
|||
ps_->InitEmbeddingTable(key, shapes, param_init_info);
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
const Key &key = input.keys()[0];
|
||||
bool ready = ps_->ReadyForPush(key);
|
||||
MS_LOG(INFO) << "The ready is:" << ready;
|
||||
|
@ -678,10 +674,10 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_
|
|||
}
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
const Key &key = input.keys()[0];
|
||||
bool ready = ps_->ReadyForPull(key);
|
||||
KVMessage res_data;
|
||||
|
@ -696,10 +692,10 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_
|
|||
}
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
EmbeddingTableLookup input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
const Key &key = input.key();
|
||||
|
||||
KVMessage res_data;
|
||||
|
@ -717,18 +713,18 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t
|
|||
}
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res) {
|
||||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
KVMessage input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
|
||||
const Key &key = input.keys()[0];
|
||||
const LookupIds &lookup_ids = {input.keys().begin() + 1, input.keys().end()};
|
||||
const Values &update_vals = {input.values().begin(), input.values().end()};
|
||||
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
|
||||
}
|
||||
|
||||
void ParameterServer::ServerHandler::HandleFinalize(DataPtr, size_t, VectorPtr res) {
|
||||
void ParameterServer::ServerHandler::HandleFinalize(const DataPtr &, size_t, const VectorPtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
ps_->Finalize();
|
||||
}
|
||||
|
|
|
@ -32,6 +32,8 @@
|
|||
#include <list>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
@ -92,23 +94,23 @@ class ParameterServer {
|
|||
explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
|
||||
~ServerHandler() = default;
|
||||
void Init();
|
||||
void operator()(std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
|
||||
size_t size);
|
||||
void HandlePushReq(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandlePullReq(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleInitWeights(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleFinalize(DataPtr data, size_t size, VectorPtr res);
|
||||
void operator()(const std::shared_ptr<core::TcpConnection> &conn, const std::shared_ptr<core::MessageMeta> &meta,
|
||||
const DataPtr &data, size_t size);
|
||||
void HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
void HandleFinalize(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
|
||||
private:
|
||||
ParameterServer *ps_;
|
||||
typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res);
|
||||
typedef void (ServerHandler::*RequestHandler)(const DataPtr &data, size_t size, const VectorPtr &res);
|
||||
std::unordered_map<int, RequestHandler> handlers_;
|
||||
std::unordered_map<int, std::string> commands_;
|
||||
std::unordered_map<Key, bool> init_weights_;
|
||||
|
@ -132,12 +134,12 @@ class ParameterServer {
|
|||
WeightPtr weight(const Key &key);
|
||||
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
|
||||
void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
|
||||
bool ReadyForUpdateWeights();
|
||||
bool ReadyForPush(const Key &key);
|
||||
bool ReadyForPull(const Key &key);
|
||||
void ResetGradAccumCount();
|
||||
inline bool ReadyForUpdateWeights() const;
|
||||
inline bool ReadyForPush(const Key &key);
|
||||
inline bool ReadyForPull(const Key &key);
|
||||
inline void ResetGradAccumCount();
|
||||
const CNodePtr GetCNode(const std::string &name) const;
|
||||
std::mutex &mutex();
|
||||
inline std::mutex &mutex();
|
||||
void GetEmbeddingTableParamPtr();
|
||||
void SyncEmbeddingTables();
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }
|
|||
|
||||
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }
|
||||
|
||||
int64_t Util::optimizer_id(std::string name) {
|
||||
int64_t Util::optimizer_id(const std::string &name) {
|
||||
if (optimizer_to_ids.count(name) > 0) {
|
||||
return optimizer_to_ids[name];
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ std::string Util::optimizer_node_name(int64_t id) {
|
|||
return "";
|
||||
}
|
||||
|
||||
bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }
|
||||
bool Util::is_optimizer(const std::string &name) { return optimizer_to_ids.count(name) > 0; }
|
||||
|
||||
int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
|
||||
std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
|
||||
|
|
|
@ -44,10 +44,10 @@ class Util {
|
|||
public:
|
||||
static bool IsRoleOfPServer();
|
||||
static bool IsRoleOfScheduler();
|
||||
static int64_t optimizer_id(std::string name);
|
||||
static int64_t optimizer_id(const std::string &name);
|
||||
static std::string optimizer_name(int64_t id);
|
||||
static std::string optimizer_node_name(int64_t id);
|
||||
static bool is_optimizer(std::string name);
|
||||
static bool is_optimizer(const std::string &name);
|
||||
static int64_t LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num);
|
||||
static std::map<int64_t, int64_t> AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num);
|
||||
static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
|
||||
|
|
|
@ -307,7 +307,9 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
|
|||
}
|
||||
|
||||
std::vector<VectorPtr> resp;
|
||||
worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
|
||||
if (!worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp)) {
|
||||
MS_LOG(ERROR) << "Worker send failed!";
|
||||
}
|
||||
int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
|
||||
std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
|
||||
std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
|
||||
|
@ -315,7 +317,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
|
|||
int64_t value_offset = 0;
|
||||
for (size_t i = 0; i < resp.size(); ++i) {
|
||||
KVMessage message;
|
||||
message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
|
||||
CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size()));
|
||||
for (auto j = 0; j < message.values_size(); j++) {
|
||||
values->push_back(message.values(j));
|
||||
}
|
||||
|
@ -630,7 +632,7 @@ void Worker::BuildSparseValue(const std::vector<int> &lengths, const size_t grad
|
|||
}
|
||||
|
||||
void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
|
||||
int cmd, int64_t priority) {
|
||||
int cmd, int64_t) {
|
||||
KVMessage kvs;
|
||||
*kvs.mutable_keys() = {keys.begin(), keys.end()};
|
||||
*kvs.mutable_values() = {vals.begin(), vals.end()};
|
||||
|
@ -682,7 +684,7 @@ void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *const va
|
|||
}
|
||||
|
||||
void Worker::LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs) {
|
||||
const std::map<int64_t, int64_t> &) {
|
||||
MS_EXCEPTION_IF_NULL(partition);
|
||||
|
||||
const Key &key = send.key();
|
||||
|
@ -829,7 +831,7 @@ void Worker::SparsePartitioner(const KVMessage &send, PartitionKVMessages *parti
|
|||
}
|
||||
|
||||
void Worker::RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs) {
|
||||
const std::map<int64_t, int64_t> &) {
|
||||
MS_EXCEPTION_IF_NULL(partition);
|
||||
partition->resize(server_num_);
|
||||
auto keys = send.keys();
|
||||
|
@ -888,7 +890,7 @@ void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessag
|
|||
const std::map<int64_t, int64_t> &attrs) {
|
||||
MS_EXCEPTION_IF_NULL(partition);
|
||||
const float *embedding_vals = send.values().data();
|
||||
const int *lookup_ids = send.len().data();
|
||||
const uint64_t *lookup_ids = send.len().data();
|
||||
size_t val_size = send.values_size();
|
||||
size_t id_size = send.len_size();
|
||||
size_t embedding_dim = val_size / id_size;
|
||||
|
@ -904,7 +906,7 @@ void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessag
|
|||
auto &kvs = partition->at(i).second;
|
||||
kvs.add_keys(key);
|
||||
for (size_t j = 0; j < id_size; j++) {
|
||||
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
|
||||
auto lookup_id = lookup_ids[j];
|
||||
if (lookup_id >= begin && lookup_id <= end) {
|
||||
kvs.add_keys(lookup_id);
|
||||
for (size_t k = 0; k < embedding_dim; k++) {
|
||||
|
@ -922,7 +924,7 @@ void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessag
|
|||
}
|
||||
|
||||
void Worker::BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs) {
|
||||
const std::map<int64_t, int64_t> &) {
|
||||
MS_EXCEPTION_IF_NULL(partition);
|
||||
partition->resize(server_num_);
|
||||
for (int64_t i = 0; i < server_num_; i++) {
|
||||
|
@ -958,7 +960,7 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa
|
|||
}
|
||||
|
||||
void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
|
||||
const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens) {
|
||||
const std::map<int64_t, int64_t> &, std::vector<float> *vals, std::vector<int> *lens) {
|
||||
MS_EXCEPTION_IF_NULL(vals);
|
||||
PartitionKVMessages messages;
|
||||
partitioner(send, &messages, {});
|
||||
|
@ -986,7 +988,7 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
|
|||
vals->clear();
|
||||
for (size_t i = 0; i < resp.size(); ++i) {
|
||||
KVMessage message;
|
||||
message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
|
||||
CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), SizeToInt(resp.at(i)->size())));
|
||||
std::copy(message.values().begin(), message.values().end(), std::back_inserter(*vals));
|
||||
|
||||
if (lens) {
|
||||
|
|
Loading…
Reference in New Issue