fixed pclint

This commit is contained in:
chendongsheng 2021-07-13 16:57:22 +08:00
parent 45c76beda7
commit 18e8a6e7c9
27 changed files with 256 additions and 223 deletions

View File

@ -187,6 +187,13 @@ enum class CustomEvent { kIterationRunning = 0, kIterationCompleted };
#define ERROR_STATUS(result, code, message) \
MS_LOG(ERROR) << message; \
result = RequestProcessResult(code, message)
#define CHECK_RETURN_TYPE(_condition) \
do { \
if (!(_condition)) { \
MS_LOG(ERROR) << "Parse protobuf message failed."; \
} \
} while (false)
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CONSTANTS_H_

View File

@ -42,11 +42,11 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
}
}
void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
RegisterRespMessage register_resp_message;
register_resp_message.ParseFromArray(data, size);
CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
if (register_resp_message.node_id() != node_info_.node_id_) {
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
<< " is not match the current node id:" << node_info_.node_id_;
@ -59,7 +59,7 @@ void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << " registered scheduler success!";
}
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
bool AbstractNode::Broadcast(const NodeRole &node_role, const DataPtr &message, size_t size, int command,
const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(message);
if (node_role != NodeRole::SERVER) {
@ -165,7 +165,7 @@ void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const Even
custom_event_to_callback_.try_emplace(event, event_cb);
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
int command, const uint32_t &timeout) {
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
@ -223,7 +223,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
return Wait(request_id, timeout);
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
int command, VectorPtr *output, const uint32_t &timeout) {
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
@ -309,7 +309,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
return Wait(request_id, timeout);
}
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(data);
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
@ -326,8 +326,8 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
}
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
const uint32_t &rank_id, VectorPtr *output) {
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
VectorPtr *output) {
MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
@ -359,7 +359,7 @@ std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum No
return std::make_pair(rank_id, rank_request_id);
}
bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) {
bool AbstractNode::CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
bool res =
receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
@ -473,11 +473,11 @@ bool AbstractNode::CheckSchedulerTimeout() const {
return false;
}
void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromArray(data, size);
CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));
current_cluster_state_ = heartbeat_resp_message.cluster_state();
MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
@ -529,11 +529,11 @@ void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
}
}
void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
FetchServersRespMessage fetch_servers_resp_message;
fetch_servers_resp_message.ParseFromArray(data, size);
CHECK_RETURN_TYPE(fetch_servers_resp_message.ParseFromArray(data, SizeToInt(size)));
nodes_address_.clear();
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
@ -542,8 +542,9 @@ void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, co
}
}
void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
@ -589,7 +590,7 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
connected_nodes_.clear();
}
void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -599,8 +600,9 @@ void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::share
wait_finish_cond_.notify_all();
}
void AbstractNode::ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
@ -609,8 +611,9 @@ void AbstractNode::ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std:
current_cluster_state_ = ClusterState::CLUSTER_READY;
}
void AbstractNode::ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
@ -619,7 +622,7 @@ void AbstractNode::ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::
current_cluster_state_ = ClusterState::CLUSTER_READY;
}
void AbstractNode::ProcessEvent(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -631,7 +634,7 @@ void AbstractNode::ProcessEvent(std::shared_ptr<TcpConnection> conn, std::shared
OnCustomEventCallback(event);
}
void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -649,7 +652,7 @@ void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::sha
is_ready_ = false;
}
void AbstractNode::ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -704,7 +707,7 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
bool AbstractNode::InitClientToScheduler() {
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_);
client_to_scheduler_->SetMessageCallback(
[&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
[&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
try {
if (handlers_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
@ -750,7 +753,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
auto client = std::make_shared<TcpClient>(ip, port);
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) {
switch (meta->cmd()) {
case NodeCommand::SEND_DATA:
@ -771,7 +774,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
}
}
void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
void AbstractNode::ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
@ -827,7 +830,7 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag
message_callbacks_[request_id] = callback;
}
void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
void AbstractNode::RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);

View File

@ -53,9 +53,11 @@ class AbstractNode : public Node {
scheduler_port_(0) {}
~AbstractNode() override { is_finish_ = true; }
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
typedef void (AbstractNode::*ServerHandler)(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size);
typedef void (AbstractNode::*ResponseHandler)(const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size);
typedef void (AbstractNode::*ServerHandler)(const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
const void *data, size_t size);
using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
@ -95,7 +97,7 @@ class AbstractNode : public Node {
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
VectorPtr *output);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
bool CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
// Initialize the scaler for server to process before/after scaling operations.
bool InitFollowerScaler();
@ -127,31 +129,31 @@ class AbstractNode : public Node {
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
void ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
void ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
void ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
// The worker/server processes the scale_out_done message from scheduelr
void ProcessScaleOutDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
// The worker/server processes the scale_in_done message from scheduelr
void ProcessScaleInDone(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
// The worker/server processes the SEND_EVENT message from scheduelr
void ProcessEvent(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void UpdateSchedulerTime();
@ -161,10 +163,12 @@ class AbstractNode : public Node {
bool InitClientToScheduler();
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id);
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size);
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
void InitCommandHandler();

View File

@ -34,7 +34,7 @@ void HttpClient::Init() {
}
ResponseCode HttpClient::Post(const std::string &url, const void *body, size_t len,
std::shared_ptr<std::vector<char>> output,
const std::shared_ptr<std::vector<char>> &output,
const std::map<std::string, std::string> &headers) {
MS_EXCEPTION_IF_NULL(body);
MS_EXCEPTION_IF_NULL(output);
@ -65,7 +65,7 @@ ResponseCode HttpClient::Post(const std::string &url, const void *body, size_t l
return CreateRequest(handler, connection, request, HttpMethod::HM_POST);
}
ResponseCode HttpClient::Get(const std::string &url, std::shared_ptr<std::vector<char>> output,
ResponseCode HttpClient::Get(const std::string &url, const std::shared_ptr<std::vector<char>> &output,
const std::map<std::string, std::string> &headers) {
MS_EXCEPTION_IF_NULL(output);
auto handler = std::make_shared<HttpMessageHandler>();
@ -128,7 +128,7 @@ void HttpClient::ReadChunkDataCallback(struct evhttp_request *request, void *arg
MS_EXCEPTION_IF_NULL(evbuf);
int n = 0;
while ((n = evbuffer_remove(evbuf, &buf, sizeof(buf))) > 0) {
handler->ReceiveMessage(buf, n);
handler->ReceiveMessage(buf, IntToSize(n));
}
}
@ -151,7 +151,7 @@ void HttpClient::ConnectionCloseCallback(struct evhttp_connection *connection, v
}
void HttpClient::AddHeaders(const std::map<std::string, std::string> &headers, const struct evhttp_request *request,
std::shared_ptr<HttpMessageHandler> handler) {
const std::shared_ptr<HttpMessageHandler> &handler) {
MS_EXCEPTION_IF_NULL(request);
if (evhttp_add_header(evhttp_request_get_output_headers(const_cast<evhttp_request *>(request)), "Host",
handler->GetHostByUri()) != 0) {
@ -165,7 +165,7 @@ void HttpClient::AddHeaders(const std::map<std::string, std::string> &headers, c
}
}
void HttpClient::InitRequest(std::shared_ptr<HttpMessageHandler> handler, const std::string &url,
void HttpClient::InitRequest(const std::shared_ptr<HttpMessageHandler> &handler, const std::string &url,
const struct evhttp_request *request) {
MS_EXCEPTION_IF_NULL(request);
MS_EXCEPTION_IF_NULL(handler);
@ -179,7 +179,7 @@ void HttpClient::InitRequest(std::shared_ptr<HttpMessageHandler> handler, const
<< ", The port is:" << handler->GetUriPort() << ", The request_url is:" << handler->GetRequestPath();
}
ResponseCode HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler,
ResponseCode HttpClient::CreateRequest(const std::shared_ptr<HttpMessageHandler> &handler,
struct evhttp_connection *connection, struct evhttp_request *request,
HttpMethod method) {
MS_EXCEPTION_IF_NULL(handler);

View File

@ -41,6 +41,7 @@
#include "ps/core/communicator/http_message_handler.h"
#include "ps/core/comm_util.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
namespace ps {
@ -62,9 +63,10 @@ class HttpClient {
virtual ~HttpClient();
ResponseCode Post(const std::string &url, const void *body, size_t len, std::shared_ptr<std::vector<char>> output,
ResponseCode Post(const std::string &url, const void *body, size_t len,
const std::shared_ptr<std::vector<char>> &output,
const std::map<std::string, std::string> &headers = {});
ResponseCode Get(const std::string &url, std::shared_ptr<std::vector<char>> output,
ResponseCode Get(const std::string &url, const std::shared_ptr<std::vector<char>> &output,
const std::map<std::string, std::string> &headers = {});
void set_connection_timeout(const int &timeout);
@ -77,10 +79,10 @@ class HttpClient {
static void ConnectionCloseCallback(struct evhttp_connection *connection, void *arg);
void AddHeaders(const std::map<std::string, std::string> &headers, const struct evhttp_request *request,
std::shared_ptr<HttpMessageHandler> handler);
void InitRequest(std::shared_ptr<HttpMessageHandler> handler, const std::string &url,
const std::shared_ptr<HttpMessageHandler> &handler);
void InitRequest(const std::shared_ptr<HttpMessageHandler> &handler, const std::string &url,
const struct evhttp_request *request);
ResponseCode CreateRequest(std::shared_ptr<HttpMessageHandler> handler, struct evhttp_connection *connection,
ResponseCode CreateRequest(const std::shared_ptr<HttpMessageHandler> &handler, struct evhttp_connection *connection,
struct evhttp_request *request, HttpMethod method);
bool Start();

View File

@ -196,7 +196,7 @@ std::string HttpMessageHandler::GetUriQuery() const {
return std::string(query);
}
std::string HttpMessageHandler::GetUriFragment() {
std::string HttpMessageHandler::GetUriFragment() const {
MS_EXCEPTION_IF_NULL(event_uri_);
const char *fragment = evhttp_uri_get_fragment(event_uri_);
MS_EXCEPTION_IF_NULL(fragment);

View File

@ -84,7 +84,7 @@ class HttpMessageHandler {
int GetUriPort() const;
// Useless to get from a request url, fragment is only for browser to locate sth.
std::string GetUriFragment();
std::string GetUriFragment() const;
void AddRespHeadParam(const std::string &key, const std::string &val);
void AddRespHeaders(const HttpHeaders &headers);

View File

@ -45,7 +45,7 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
is_stop_(true),
is_connected_(false) {
message_handler_.SetCallback(
[this](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
[this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
if (message_callback_) {
message_callback_(meta, protos, data, size);
}
@ -308,7 +308,8 @@ bool TcpClient::SendMessage(const CommMessage &message) const {
return res;
}
bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
bool TcpClient::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
@ -357,7 +358,7 @@ void TcpClient::StartTimer(const uint32_t &time) {
void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
const event_base &TcpClient::eventbase() { return *event_base_; }
const event_base &TcpClient::eventbase() const { return *event_base_; }
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -48,7 +48,8 @@ class TcpClient {
using OnDisconnected = std::function<void()>;
using OnRead = std::function<void(const void *, size_t)>;
using OnTimeout = std::function<void()>;
using OnMessage = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
using OnMessage =
std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>;
using OnTimer = std::function<void()>;
explicit TcpClient(const std::string &address, std::uint16_t port);
@ -66,10 +67,10 @@ class TcpClient {
void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb);
bool SendMessage(const CommMessage &message) const;
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
bool SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size);
void StartTimer(const uint32_t &time);
void set_timer_callback(const OnTimer &timer);
const event_base &eventbase();
const event_base &eventbase() const;
protected:
static void SetTcpNoDelay(const evutil_socket_t &fd);

View File

@ -75,7 +75,8 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
if (remaining_length_ == 0) {
if (message_callback_) {
std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>();
pb_message->ParseFromArray(message_buffer_.get(), message_header_.message_meta_length_);
CHECK_RETURN_TYPE(
pb_message->ParseFromArray(message_buffer_.get(), UintToInt(message_header_.message_meta_length_)));
message_callback_(pb_message, message_header_.message_proto_,
message_buffer_.get() + message_header_.message_meta_length_,
message_header_.message_length_ - message_header_.message_meta_length_);

View File

@ -27,11 +27,14 @@
#include "ps/core/communicator/message.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "utils/convert_utils_base.h"
#include "ps/constants.h"
namespace mindspore {
namespace ps {
namespace core {
using messageReceive = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
using messageReceive =
std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>;
constexpr int kHeaderLen = 16;
class TcpMessageHandler {
@ -48,7 +51,7 @@ class TcpMessageHandler {
bool is_parsed_;
std::unique_ptr<unsigned char[]> message_buffer_;
size_t remaining_length_;
char header_[16]{0};
unsigned char header_[16]{0};
int header_index_;
size_t last_copy_len_;
MessageHeader message_header_;

View File

@ -48,7 +48,7 @@ const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
bool TcpConnection::SendMessage(const std::shared_ptr<CommMessage> &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(message);
bufferevent_lock(buffer_event_);
@ -66,7 +66,7 @@ bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
return res;
}
bool TcpConnection::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(meta);
@ -325,12 +325,13 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
MS_EXCEPTION_IF_NULL(conn);
SetTcpNoDelay(fd);
server->AddConnection(fd, conn);
conn->InitConnection([=](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
OnServerReceiveMessage on_server_receive = server->GetServerReceive();
if (on_server_receive) {
on_server_receive(conn, meta, protos, data, size);
}
});
conn->InitConnection(
[=](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
OnServerReceiveMessage on_server_receive = server->GetServerReceive();
if (on_server_receive) {
on_server_receive(conn, meta, protos, data, size);
}
});
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
reinterpret_cast<void *>(conn.get()));
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
@ -440,13 +441,13 @@ void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
}
}
bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<CommMessage> &message) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
return conn->SendMessage(message);
}
bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -454,7 +455,7 @@ bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr
return conn->SendMessage(meta, protos, data, size);
}
void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
void TcpServer::SendMessage(const std::shared_ptr<CommMessage> &message) {
MS_EXCEPTION_IF_NULL(message);
std::lock_guard<std::mutex> lock(connection_mutex_);
@ -467,7 +468,7 @@ uint16_t TcpServer::BoundPort() const { return server_port_; }
std::string TcpServer::BoundIp() const { return server_address_; }
int TcpServer::ConnectionNum() const { return connections_.size(); }
int TcpServer::ConnectionNum() const { return SizeToInt(connections_.size()); }
const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }

View File

@ -58,8 +58,8 @@ class TcpConnection {
virtual void InitConnection(const messageReceive &callback);
virtual void SendMessage(const void *buffer, size_t num) const;
bool SendMessage(std::shared_ptr<CommMessage> message) const;
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) const;
bool SendMessage(const std::shared_ptr<CommMessage> &message) const;
bool SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes);
const TcpServer *GetServer() const;
const evutil_socket_t &GetFd() const;
@ -74,8 +74,8 @@ class TcpConnection {
};
using OnServerReceiveMessage =
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size)>;
std::function<void(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size)>;
class TcpServer {
public:
@ -105,10 +105,10 @@ class TcpServer {
std::shared_ptr<TcpConnection> GetConnectionByFd(const evutil_socket_t &fd);
OnServerReceiveMessage GetServerReceive() const;
void SetMessageCallback(const OnServerReceiveMessage &cb);
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t sizee);
void SendMessage(std::shared_ptr<CommMessage> message);
bool SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<CommMessage> &message);
bool SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t sizee);
void SendMessage(const std::shared_ptr<CommMessage> &message);
uint16_t BoundPort() const;
std::string BoundIp() const;
int ConnectionNum() const;

View File

@ -51,7 +51,7 @@ bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommM
return Wait(request_id, timeout);
}
uint64_t Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
uint64_t Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(meta);
@ -99,7 +99,7 @@ bool Node::CheckMessageTrack(const uint64_t &request_id) {
return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
}
void Node::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
void Node::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
uint64_t request_id = meta->request_id();

View File

@ -79,12 +79,12 @@ class Node {
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
// Send data asynchronously
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
uint64_t AddMessageTrack(const uint32_t &expected_response);
bool CheckMessageTrack(const uint64_t &request_id);
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
void NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta);
NodeInfo node_info_;
std::atomic<bool> is_ready_;

View File

@ -201,17 +201,19 @@ void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_do
void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
bool NodeManager::IsAllNodesRegistered() {
bool NodeManager::IsAllNodesRegistered() const {
int32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
[](auto item) { return item.second.is_alive == true; });
return num == total_node_num_;
}
bool NodeManager::IsAllNodesFinished() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesFinished() const { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesScaleOutDone() { return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesScaleOutDone() const {
return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_;
}
bool NodeManager::IsAllNodesScaleInDone() { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesScaleInDone() const { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }
std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() { return nodes_info_; }

View File

@ -75,17 +75,17 @@ class NodeManager {
// When workers and servers registered to scheduler, the scheduler will collect the number of registered
// nodes and Determine whether the registered number of worker and server is equal to total_node_num_.
bool IsAllNodesRegistered();
bool IsAllNodesRegistered() const;
// When workers and servers send a finish message to the scheduler, the scheduler will collect the number of
// finish nodes and Determine whether the finished nodes are equal to total_node_num_.
bool IsAllNodesFinished();
bool IsAllNodesFinished() const;
// When workers and servers send a scale_out_done message to the scheduler, the scheduler will collect the number of
// nodes and Determine whether the nodes are equal to total_node_num_.
bool IsAllNodesScaleOutDone();
bool IsAllNodesScaleOutDone() const;
// When workers and servers send a scale_in_done message to the scheduler, the scheduler will collect the number of
// nodes and Determine whether the nodes are equal to total_node_num_.
bool IsAllNodesScaleInDone();
bool IsAllNodesScaleInDone() const;
std::unordered_map<std::string, NodeInfo> &nodes_info();
std::unordered_map<std::string, NodeInfo> &registered_nodes_info();

View File

@ -43,9 +43,9 @@ message ParamInitInfoMessage {
}
message KVMessage {
repeated int32 keys = 2;
repeated uint64 keys = 2;
repeated float values = 3;
repeated int32 len = 4;
repeated uint64 len = 4;
}
message EmbeddingTableMeta {

View File

@ -46,14 +46,15 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
return true;
}
void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
HeartbeatMessage heartbeat_message;
heartbeat_message.ParseFromArray(data, size);
CHECK_RETURN_TYPE(heartbeat_message.ParseFromArray(data, SizeToInt(size)));
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
@ -101,7 +102,7 @@ void SchedulerNode::CreateTcpServer() {
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &, const void *data, size_t size) {
if (handlers_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
@ -112,23 +113,22 @@ void SchedulerNode::CreateTcpServer() {
server_->Init();
scheduler_thread_ = std::make_unique<std::thread>([&]() {
scheduler_thread_ = std::make_unique<std::thread>([this]() {
MS_LOG(INFO) << "The scheduler node start a tcp server!";
server_->Start();
this->server_->Start();
});
}
void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
MS_LOG(INFO) << "The scheduler process a register message!";
RegisterMessage register_message;
if (!register_message.ParseFromArray(data, SizeToInt(size))) {
MS_LOG(WARNING) << "Parse data failed.";
}
CHECK_RETURN_TYPE(register_message.ParseFromArray(data, SizeToInt(size)));
// assign worker node and server node rank id
uint32_t rank_id = node_manager_.NextRankId(register_message, meta);
@ -167,8 +167,8 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
}
}
void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void SchedulerNode::ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -202,8 +202,9 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
}
}
void SchedulerNode::ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void SchedulerNode::ProcessFetchMetadata(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -217,8 +218,9 @@ void SchedulerNode::ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std:
fetch_servers_message.ByteSizeLong());
}
void SchedulerNode::ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void SchedulerNode::ProcessScaleOutDone(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -242,8 +244,9 @@ void SchedulerNode::ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::
}
}
void SchedulerNode::ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void SchedulerNode::ProcessScaleInDone(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -267,8 +270,9 @@ void SchedulerNode::ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::s
}
}
void SchedulerNode::ProcessSendEvent(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void SchedulerNode::ProcessSendEvent(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);

View File

@ -58,8 +58,10 @@ class SchedulerNode : public Node {
scheduler_recovery_(nullptr) {}
~SchedulerNode() override;
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
typedef void (SchedulerNode::*ResponseHandler)(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size);
bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) override;
bool Stop() override;
@ -73,24 +75,24 @@ class SchedulerNode : public Node {
void StartUpdateClusterStateTimer();
const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info);
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessHeartbeat(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessRegister(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessFetchMetadata(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Process scale_out_done messages from workers/servers
void ProcessScaleOutDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessScaleOutDone(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Process scale_in_done messages from workers/servers
void ProcessScaleInDone(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessScaleInDone(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Process scale_in_done messages from workers/servers
void ProcessSendEvent(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessSendEvent(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// After scheduler collects all registered message, it actively sends finish to the node connected by the client.
void SendMetadata(const std::shared_ptr<TcpClient> &client, uint32_t rank_id);

View File

@ -42,8 +42,8 @@ bool ServerNode::Start(const uint32_t &timeout) {
void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data,
size_t size) {
void ServerNode::Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
@ -59,7 +59,7 @@ void ServerNode::CreateTcpServer() {
std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
if (server_handler_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
@ -107,7 +107,7 @@ void ServerNode::Initialize() {
MS_LOG(INFO) << "[Server start]: 3. Server node crete tcp client to scheduler successful!";
}
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
void ServerNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -128,8 +128,8 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share
request_handler_(conn, meta, res, size);
}
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const void *data, size_t size) {
void ServerNode::ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
server_->SendMessage(conn, meta, Protos::RAW, data, size);

View File

@ -51,11 +51,13 @@ class ServerNode : public AbstractNode {
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
DataPtr data, size_t size)>;
using RequestHandler =
std::function<void(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const DataPtr &data, size_t size)>;
void set_handler(const RequestHandler &handler);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size);
std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port,
const std::shared_ptr<TaskExecutor> &task_executor);
@ -66,9 +68,9 @@ class ServerNode : public AbstractNode {
private:
void CreateTcpServer();
void Initialize();
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
void ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
void ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const void *data, size_t size);
RequestHandler request_handler_;

View File

@ -237,10 +237,11 @@ void ParameterServer::UpdateWeights() {
shapes.push_back(indices_shape);
if (original_optim_inputs_shape_.count(key) != 0) {
std::transform(
(*(original_optim_inputs_shape_[key])).begin(), (*(original_optim_inputs_shape_[key])).end(),
std::back_inserter(shapes),
[](std::shared_ptr<std::vector<size_t>> input_shapes) -> std::vector<size_t> { return *input_shapes; });
std::transform((*(original_optim_inputs_shape_[key])).begin(), (*(original_optim_inputs_shape_[key])).end(),
std::back_inserter(shapes),
[](const std::shared_ptr<std::vector<size_t>> &input_shapes) -> std::vector<size_t> {
return *input_shapes;
});
}
optimizer->ReInit(shapes);
optim_info->ComputeMean(shapes, worker_num_, pserver_num_, server_node_->rank_id());
@ -377,7 +378,7 @@ void ParameterServer::UpdateEmbeddings(const Key &key, const LookupIds &lookup_i
table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
}
inline bool ParameterServer::ReadyForUpdateWeights() {
inline bool ParameterServer::ReadyForUpdateWeights() const {
return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
}
@ -387,9 +388,7 @@ inline bool ParameterServer::ReadyForPush(const Key &key) {
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
}
MS_LOG(INFO) << "The grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size()
<< " the token:" << (tokens_[key] <= 0);
return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
return grad_accum_count_ < weights_.size() && tokens_[key] == 0;
}
inline bool ParameterServer::ReadyForPull(const Key &key) {
@ -504,8 +503,9 @@ void ParameterServer::ServerHandler::Init() {
commands_[kPullCmd] = "kPullCmd";
}
void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnection> conn,
std::shared_ptr<core::MessageMeta> meta, DataPtr data, size_t size) {
void ParameterServer::ServerHandler::operator()(const std::shared_ptr<core::TcpConnection> &conn,
const std::shared_ptr<core::MessageMeta> &meta, const DataPtr &data,
size_t size) {
auto output = std::make_shared<std::vector<unsigned char>>();
if (commands_.count(meta->user_cmd()) == 0) {
MS_LOG(EXCEPTION) << "The command:" << meta->user_cmd() << " is not supported!";
@ -530,10 +530,10 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect
.count();
}
void ParameterServer::ServerHandler::HandlePushReq(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res) {
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
input.ParseFromArray(data.get(), size);
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
Keys keys = {input.keys().begin(), input.keys().end()};
Values values = {input.values().begin(), input.values().end()};
Lengths lens = {input.len().begin(), input.len().end()};
@ -541,10 +541,10 @@ void ParameterServer::ServerHandler::HandlePushReq(DataPtr data, size_t size, Ve
ps_->AccumGrad(keys, values, lens);
}
void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res) {
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
input.ParseFromArray(data.get(), size);
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
KVMessage res_data;
*res_data.mutable_keys() = input.keys();
Key key = input.keys()[0];
@ -559,13 +559,11 @@ void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, Ve
}
}
void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
MS_LOG(WARNING) << "Parse data failed.";
}
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
int key_num = input.keys_size();
const float *data_ptr = input.values().data();
size_t pos = 0;
@ -586,13 +584,11 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size
}
}
void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
if (!input.ParseFromArray(data.get(), SizeToInt(size))) {
MS_LOG(WARNING) << "Parse data failed.";
}
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
int key_num = input.keys_size();
for (int i = 0; i < key_num; i++) {
Key key = input.keys()[i];
@ -606,11 +602,11 @@ void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, siz
}
}
void ParameterServer::ServerHandler::HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
input.ParseFromArray(data.get(), size);
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
const Key &key = input.keys()[0];
if (init_optim_info_[key]) {
return;
@ -623,10 +619,10 @@ void ParameterServer::ServerHandler::HandleInitInputsShape(DataPtr data, size_t
ps_->InitOptimInputsShape(keys, values, lens);
}
void ParameterServer::ServerHandler::HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &) {
std::unique_lock<std::mutex> lock(ps_->mutex());
EmbeddingTableMeta embedding_table_meta;
embedding_table_meta.ParseFromArray(data.get(), size);
CHECK_RETURN_TYPE(embedding_table_meta.ParseFromArray(data.get(), SizeToInt(size)));
const Key &key = embedding_table_meta.key();
MS_LOG(INFO) << "Initializing embedding table for key:" << key;
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
@ -659,10 +655,10 @@ void ParameterServer::ServerHandler::HandleInitEmbeddings(DataPtr data, size_t s
ps_->InitEmbeddingTable(key, shapes, param_init_info);
}
void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res) {
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
input.ParseFromArray(data.get(), size);
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
const Key &key = input.keys()[0];
bool ready = ps_->ReadyForPush(key);
MS_LOG(INFO) << "The ready is:" << ready;
@ -678,10 +674,10 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_
}
}
void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res) {
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
input.ParseFromArray(data.get(), size);
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
const Key &key = input.keys()[0];
bool ready = ps_->ReadyForPull(key);
KVMessage res_data;
@ -696,10 +692,10 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_
}
}
void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res) {
MS_EXCEPTION_IF_NULL(res);
EmbeddingTableLookup input;
input.ParseFromArray(data.get(), size);
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
const Key &key = input.key();
KVMessage res_data;
@ -717,18 +713,18 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t
}
}
void ParameterServer::ServerHandler::HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res) {
void ParameterServer::ServerHandler::HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
MS_EXCEPTION_IF_NULL(res);
KVMessage input;
input.ParseFromArray(data.get(), size);
CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
const Key &key = input.keys()[0];
const LookupIds &lookup_ids = {input.keys().begin() + 1, input.keys().end()};
const Values &update_vals = {input.values().begin(), input.values().end()};
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
}
void ParameterServer::ServerHandler::HandleFinalize(DataPtr, size_t, VectorPtr res) {
void ParameterServer::ServerHandler::HandleFinalize(const DataPtr &, size_t, const VectorPtr &res) {
MS_EXCEPTION_IF_NULL(res);
ps_->Finalize();
}

View File

@ -32,6 +32,8 @@
#include <list>
#include <map>
#include <functional>
#include <algorithm>
#include "ir/func_graph.h"
#include "backend/session/session_basic.h"
#include "backend/session/anf_runtime_algorithm.h"
@ -92,23 +94,23 @@ class ParameterServer {
explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
~ServerHandler() = default;
void Init();
void operator()(std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
size_t size);
void HandlePushReq(DataPtr data, size_t size, VectorPtr res);
void HandlePullReq(DataPtr data, size_t size, VectorPtr res);
void HandleInitWeights(DataPtr data, size_t size, VectorPtr res);
void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res);
void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res);
void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res);
void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res);
void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res);
void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res);
void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res);
void HandleFinalize(DataPtr data, size_t size, VectorPtr res);
void operator()(const std::shared_ptr<core::TcpConnection> &conn, const std::shared_ptr<core::MessageMeta> &meta,
const DataPtr &data, size_t size);
void HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res);
void HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res);
void HandleFinalize(const DataPtr &data, size_t size, const VectorPtr &res);
private:
ParameterServer *ps_;
typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res);
typedef void (ServerHandler::*RequestHandler)(const DataPtr &data, size_t size, const VectorPtr &res);
std::unordered_map<int, RequestHandler> handlers_;
std::unordered_map<int, std::string> commands_;
std::unordered_map<Key, bool> init_weights_;
@ -132,12 +134,12 @@ class ParameterServer {
WeightPtr weight(const Key &key);
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
bool ReadyForUpdateWeights();
bool ReadyForPush(const Key &key);
bool ReadyForPull(const Key &key);
void ResetGradAccumCount();
inline bool ReadyForUpdateWeights() const;
inline bool ReadyForPush(const Key &key);
inline bool ReadyForPull(const Key &key);
inline void ResetGradAccumCount();
const CNodePtr GetCNode(const std::string &name) const;
std::mutex &mutex();
inline std::mutex &mutex();
void GetEmbeddingTableParamPtr();
void SyncEmbeddingTables();

View File

@ -49,7 +49,7 @@ bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }
int64_t Util::optimizer_id(std::string name) {
int64_t Util::optimizer_id(const std::string &name) {
if (optimizer_to_ids.count(name) > 0) {
return optimizer_to_ids[name];
}
@ -70,7 +70,7 @@ std::string Util::optimizer_node_name(int64_t id) {
return "";
}
bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }
bool Util::is_optimizer(const std::string &name) { return optimizer_to_ids.count(name) > 0; }
int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);

View File

@ -44,10 +44,10 @@ class Util {
public:
static bool IsRoleOfPServer();
static bool IsRoleOfScheduler();
static int64_t optimizer_id(std::string name);
static int64_t optimizer_id(const std::string &name);
static std::string optimizer_name(int64_t id);
static std::string optimizer_node_name(int64_t id);
static bool is_optimizer(std::string name);
static bool is_optimizer(const std::string &name);
static int64_t LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num);
static std::map<int64_t, int64_t> AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num);
static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,

View File

@ -307,7 +307,9 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
}
std::vector<VectorPtr> resp;
worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
if (!worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp)) {
MS_LOG(ERROR) << "Worker send failed!";
}
int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
@ -315,7 +317,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
int64_t value_offset = 0;
for (size_t i = 0; i < resp.size(); ++i) {
KVMessage message;
message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size()));
for (auto j = 0; j < message.values_size(); j++) {
values->push_back(message.values(j));
}
@ -630,7 +632,7 @@ void Worker::BuildSparseValue(const std::vector<int> &lengths, const size_t grad
}
void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
int cmd, int64_t priority) {
int cmd, int64_t) {
KVMessage kvs;
*kvs.mutable_keys() = {keys.begin(), keys.end()};
*kvs.mutable_values() = {vals.begin(), vals.end()};
@ -682,7 +684,7 @@ void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *const va
}
void Worker::LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
const std::map<int64_t, int64_t> &attrs) {
const std::map<int64_t, int64_t> &) {
MS_EXCEPTION_IF_NULL(partition);
const Key &key = send.key();
@ -829,7 +831,7 @@ void Worker::SparsePartitioner(const KVMessage &send, PartitionKVMessages *parti
}
void Worker::RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs) {
const std::map<int64_t, int64_t> &) {
MS_EXCEPTION_IF_NULL(partition);
partition->resize(server_num_);
auto keys = send.keys();
@ -888,7 +890,7 @@ void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessag
const std::map<int64_t, int64_t> &attrs) {
MS_EXCEPTION_IF_NULL(partition);
const float *embedding_vals = send.values().data();
const int *lookup_ids = send.len().data();
const uint64_t *lookup_ids = send.len().data();
size_t val_size = send.values_size();
size_t id_size = send.len_size();
size_t embedding_dim = val_size / id_size;
@ -904,7 +906,7 @@ void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessag
auto &kvs = partition->at(i).second;
kvs.add_keys(key);
for (size_t j = 0; j < id_size; j++) {
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
auto lookup_id = lookup_ids[j];
if (lookup_id >= begin && lookup_id <= end) {
kvs.add_keys(lookup_id);
for (size_t k = 0; k < embedding_dim; k++) {
@ -922,7 +924,7 @@ void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessag
}
void Worker::BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs) {
const std::map<int64_t, int64_t> &) {
MS_EXCEPTION_IF_NULL(partition);
partition->resize(server_num_);
for (int64_t i = 0; i < server_num_; i++) {
@ -958,7 +960,7 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa
}
void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens) {
const std::map<int64_t, int64_t> &, std::vector<float> *vals, std::vector<int> *lens) {
MS_EXCEPTION_IF_NULL(vals);
PartitionKVMessages messages;
partitioner(send, &messages, {});
@ -986,7 +988,7 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
vals->clear();
for (size_t i = 0; i < resp.size(); ++i) {
KVMessage message;
message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), SizeToInt(resp.at(i)->size())));
std::copy(message.values().begin(), message.values().end(), std::back_inserter(*vals));
if (lens) {