diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc
index 293a1fe1490..daac34850e0 100644
--- a/mindspore/ccsrc/ps/core/abstract_node.cc
+++ b/mindspore/ccsrc/ps/core/abstract_node.cc
@@ -32,6 +32,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
   CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
   comm_message.set_data(register_message.SerializeAsString());
+  comm_message.set_user_cmd("");
   if (!SendMessageSync(client, comm_message)) {
     MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                       << " the node id:" << node_info_.node_id_ << " register timeout!";
@@ -54,11 +55,12 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
   MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
 }
 
-bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) {
+bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) {
   if (node_role != NodeRole::SERVER) {
     MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
   }
 
+  CommMessage &comm_message = const_cast<CommMessage &>(message);
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
 
@@ -69,9 +71,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);
 
-    CommMessage comm_message;
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(message);
     auto client = GetOrCreateTcpClient((*it).first.second);
     client->SendMessage(comm_message);
   }
@@ -84,26 +84,26 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
   on_node_event_message_ = on_node_event_message;
 }
 
-bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
                         const uint32_t &timeout) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   MessageMeta message_meta;
   message_meta.set_cmd(NodeCommand::SEND_DATA);
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);
 
-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
 
   auto client = GetOrCreateTcpClient(rank_id);
   return SendMessageSync(client, comm_message, timeout);
 }
 
 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
-                        const std::vector<std::string> &data, const uint32_t &timeout) {
+                        const std::vector<CommMessage> &data, const uint32_t &timeout) {
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(data.size(), 0);
 
@@ -121,9 +121,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);
 
-    CommMessage comm_message;
+    CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(data.at(it));
 
     auto client = GetOrCreateTcpClient(rank_ids.at(it));
     client->SendMessage(comm_message);
@@ -133,19 +132,21 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
   return Wait(request_id, timeout);
 }
 
-bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
-                        std::string *output, const uint32_t &timeout) {
+bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
+                        CommMessage *output, const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(output);
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(1, 0);
 
   set_message_callback(request_id, [&]() {
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
-    *output = res[rank_id].data();
+    *output = res[rank_id];
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
   });
@@ -156,9 +157,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);
 
-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
 
   auto client = GetOrCreateTcpClient(rank_id);
   client->SendMessage(comm_message);
   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@@ -167,7 +166,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
 }
 
 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
-                        const std::vector<std::string> &data, std::vector<std::string> *output,
+                        const std::vector<CommMessage> &data, std::vector<CommMessage> *output,
                         const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(output);
   uint64_t request_id = ++next_request_id_;
@@ -183,7 +182,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
     for (size_t it = 0; it < len; ++it) {
-      (*output).push_back(res[rank_ids.at(it)].data());
+      (*output).push_back(res[rank_ids.at(it)]);
     }
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
@@ -200,9 +199,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);
 
-    CommMessage comm_message;
+    CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(data.at(it));
 
     auto client = GetOrCreateTcpClient(rank_ids.at(it));
     client->SendMessage(comm_message);
@@ -223,37 +221,37 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
 }
 
 uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
-                                           const std::string &message) {
+                                           const CommMessage &message) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   MessageMeta message_meta;
   message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);
 
-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
 
   auto client = GetOrCreateTcpClient(rank_id);
   return SendMessageAsync(client, comm_message);
 }
 
 std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
-                                                                   const uint32_t &rank_id, std::string *output) {
+                                                                   const uint32_t &rank_id, CommMessage *output) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
   uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
   if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
-    *output = received_data_[std::make_pair(rank_id, rank_request_id)].data();
+    *output = received_data_[std::make_pair(rank_id, rank_request_id)];
     received_data_.erase(std::make_pair(rank_id, rank_request_id));
   } else {
     set_receive_callback(rank_id, rank_request_id, [=]() {
       receive_callbacks_mutex_.lock();
-      *output = received_data_[std::make_pair(rank_id, 1)].data();
+      *output = received_data_[std::make_pair(rank_id, rank_request_id)];
       received_data_.erase(std::make_pair(rank_id, rank_request_id));
       receive_callbacks_mutex_.unlock();
     });
@@ -415,21 +413,12 @@ bool AbstractNode::InitClientToScheduler() {
   uint16_t scheduler_port = ClusterConfig::scheduler_port();
   client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
   client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
-    switch (message.pb_meta().cmd()) {
-      case NodeCommand::HEARTBEAT:
-        ProcessHeartbeatResp(message);
-        break;
-      case NodeCommand::REGISTER:
-        ProcessRegisterResp(message);
-        break;
-      case NodeCommand::FETCH_SERVER:
-        ProcessFetchServersResp(message);
-        break;
-      case NodeCommand::FINISH:
-        MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
-        break;
-      default:
-        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+    if (handlers_.count(message.pb_meta().cmd()) == 0) {
+      MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+    }
+    if (handlers_[message.pb_meta().cmd()] != nullptr) {
+      const auto &handler_ptr = handlers_[message.pb_meta().cmd()];
+      (this->*handler_ptr)(message);
     }
     NotifyMessageArrival(message);
   });
@@ -607,6 +596,13 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
   }
   return rank_request_id;
 }
+
+void AbstractNode::InitCommandHandler() {
+  handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
+  handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
+  handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
+  handlers_[NodeCommand::FINISH] = nullptr;
+}
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h
index eea8eb773da..448e36489ac 100644
--- a/mindspore/ccsrc/ps/core/abstract_node.h
+++ b/mindspore/ccsrc/ps/core/abstract_node.h
@@ -34,23 +34,25 @@ class AbstractNode : public Node {
   AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
   ~AbstractNode() override = default;
 
-  bool Broadcast(const enum NodeRole &node_role, const std::string &message,
+  typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message);
+
+  bool Broadcast(const enum NodeRole &node_role, const CommMessage &message,
                  const uint32_t &timeout = kCommTimeoutInSeconds);
   void set_event_callback(const OnNodeEventMessage &on_node_event_message);
 
-  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
             const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
+  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
            const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output,
+  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output,
            const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
-           std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
+  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
+           std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
 
   bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
 
-  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message);
+  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message);
   std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
-                                                       std::string *output);
+                                                       CommMessage *output);
   bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
 
  protected:
@@ -78,6 +80,7 @@ class AbstractNode : public Node {
   void RunReceiveCallback(const CommMessage &message);
   uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
   uint64_t NextActualRankRequestId(const uint32_t &rank_id);
+  void InitCommandHandler();
 
   std::unique_ptr<std::thread> heart_beat_thread_;
   std::unique_ptr<std::thread> client_to_scheduler_thread_;
@@ -115,6 +118,7 @@ class AbstractNode : public Node {
   std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
   std::mutex rank_request_ids_mutex;
   timeval scheduler_time_;
+  std::unordered_map<NodeCommand, ResponseHandler> handlers_;
 };
 }  // namespace core
 }  // namespace ps
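The change above replaces the `switch` over `NodeCommand` with a table of pointers to member functions (`handlers_`), looked up at dispatch time. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the same idea; the class name, enum values, and handler bodies are stand-ins for illustration, not the MindSpore types.

```cpp
// Minimal sketch of the pointer-to-member-function dispatch table used above.
// The class, enum values, and handler bodies are illustrative stand-ins only.
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

enum class NodeCommand { HEARTBEAT, REGISTER, FETCH_SERVER, FINISH };

class DemoNode {
 public:
  // Mirrors AbstractNode::ResponseHandler: a handler is a member function
  // that takes the received message.
  typedef void (DemoNode::*ResponseHandler)(const std::string &message);

  void InitCommandHandler() {
    handlers_[NodeCommand::HEARTBEAT] = &DemoNode::ProcessHeartbeat;
    handlers_[NodeCommand::REGISTER] = &DemoNode::ProcessRegister;
    handlers_[NodeCommand::FINISH] = nullptr;  // recognized, but intentionally a no-op
  }

  void Dispatch(NodeCommand cmd, const std::string &message) {
    if (handlers_.count(cmd) == 0) {
      throw std::runtime_error("The cmd is not supported!");
    }
    if (handlers_[cmd] != nullptr) {
      const auto &handler_ptr = handlers_[cmd];
      (this->*handler_ptr)(message);  // call through the member-function pointer
    }
  }

 private:
  void ProcessHeartbeat(const std::string &message) { std::cout << "heartbeat: " << message << "\n"; }
  void ProcessRegister(const std::string &message) { std::cout << "register: " << message << "\n"; }

  std::unordered_map<NodeCommand, ResponseHandler> handlers_;
};

int main() {
  DemoNode node;
  node.InitCommandHandler();
  node.Dispatch(NodeCommand::HEARTBEAT, "ping");
  node.Dispatch(NodeCommand::FINISH, "silently ignored");
  return 0;
}
```

A `nullptr` entry marks a command that is accepted but deliberately has no handler, which is how `NodeCommand::FINISH` is treated in `AbstractNode::InitCommandHandler()` above.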
diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto
index 4e24de8c580..81d10137120 100644
--- a/mindspore/ccsrc/ps/core/protos/comm.proto
+++ b/mindspore/ccsrc/ps/core/protos/comm.proto
@@ -95,5 +95,6 @@ message FinishMessage {
 message CommMessage {
   MessageMeta pb_meta = 1;
   bytes data = 2;
+  // User-defined commands
+  bytes user_cmd = 3;
 }
-
diff --git a/mindspore/ccsrc/ps/core/protos/ps.proto b/mindspore/ccsrc/ps/core/protos/ps.proto
index 9ae31a94c13..7f293663a12 100644
--- a/mindspore/ccsrc/ps/core/protos/ps.proto
+++ b/mindspore/ccsrc/ps/core/protos/ps.proto
@@ -14,17 +14,42 @@
  * limitations under the License.
  */
 syntax = "proto3";
-package mindspore.ps.core;
+package mindspore.ps;
 option optimize_for = LITE_RUNTIME;
 
-enum PSCommand {
+message Command {
+  CommandCode cmd = 1;
+}
+
+enum CommandCode {
   PUSH = 0;
   PULL = 1;
   INIT_EMBEDDING_TABLE = 2;
+  INIT_WEIGHT = 3;
+  INIT_WEIGHT_TO_OPTIM_ID = 4;
+  INIT_INPUTS_SHAPE = 5;
+  CHECK_READY_FOR_PUSH = 6;
+  CHECK_READY_FOR_PULL = 7;
+  EMBEDDING_LOOKUP = 8;
+  UPDATE_EMBEDDING = 9;
+  FINALIZE = 10;
 }
 
 message KVMessage {
-  PSCommand command = 1;
   repeated int32 keys = 2;
   repeated float values = 3;
+  repeated int32 len = 4;
+}
+
+message EmbeddingTableMeta {
+  uint64 key = 1;
+  repeated uint64 input_shape = 2;
+  repeated uint64 indices_shape = 3;
+  repeated uint64 output_shape = 4;
+}
+
+message EmbeddingTableLookup {
+  uint64 key = 2;
+  repeated int32 keys = 3;
+  repeated float values = 4;
 }
\ No newline at end of file
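comm.proto gains a `user_cmd` field, and ps.proto's old `PSCommand command` field moves out of `KVMessage` into a standalone `Command` message. The diff itself only clears `user_cmd` in `AbstractNode::Register()`; how other call sites are meant to populate it is not shown here. Assuming the intent is to carry a serialized `mindspore.ps.Command` alongside the payload, a caller-side sketch could look like the following; the function name and the pairing of `Command` with `KVMessage` are assumptions, not part of this change.

```cpp
// Hypothetical caller-side sketch, not part of this diff: pack a KVMessage
// payload and a serialized Command into the new user_cmd field.
#include <string>

#include "proto/comm.pb.h"  // mindspore.ps.core.CommMessage
#include "proto/ps.pb.h"    // mindspore.ps.Command / CommandCode / KVMessage

mindspore::ps::core::CommMessage BuildPushMessage() {
  mindspore::ps::KVMessage kv_message;
  kv_message.add_keys(1);
  kv_message.add_values(0.5f);
  kv_message.add_len(1);

  mindspore::ps::Command command;
  command.set_cmd(mindspore::ps::PUSH);

  mindspore::ps::core::CommMessage comm_message;
  // data carries the payload; user_cmd carries the user-level command, while
  // the transport-level NodeCommand is still written into pb_meta by Send().
  comm_message.set_data(kv_message.SerializeAsString());
  comm_message.set_user_cmd(command.SerializeAsString());
  return comm_message;
}
```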
diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc
index d84fc77dc47..a3a38519fbd 100644
--- a/mindspore/ccsrc/ps/core/scheduler_node.cc
+++ b/mindspore/ccsrc/ps/core/scheduler_node.cc
@@ -67,6 +67,7 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
 }
 
 void SchedulerNode::Initialize() {
+  InitCommandHandler();
   CreateTcpServer();
   is_already_stopped_ = false;
   node_info_.node_id_ = CommUtil::GenerateUUID();
@@ -75,6 +76,13 @@ void SchedulerNode::Initialize() {
                << ", the node id is:" << node_info_.node_id_;
 }
 
+void SchedulerNode::InitCommandHandler() {
+  handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat;
+  handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
+  handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
+  handlers_[NodeCommand::FETCH_SERVER] = &SchedulerNode::ProcessFetchServers;
+}
+
 void SchedulerNode::CreateTcpServer() {
   node_manager_.InitNodeNum();
 
@@ -82,22 +90,11 @@ void SchedulerNode::CreateTcpServer() {
   uint32_t scheduler_port = ClusterConfig::scheduler_port();
   server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
   server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
-    switch (message->pb_meta().cmd()) {
-      case NodeCommand::HEARTBEAT:
-        ProcessHeartbeat(server_, conn, message);
-        break;
-      case NodeCommand::REGISTER:
-        ProcessRegister(server_, conn, message);
-        break;
-      case NodeCommand::FINISH:
-        ProcessFinish(server_, conn, message);
-        break;
-      case NodeCommand::FETCH_SERVER:
-        ProcessFetchServers(server_, conn, message);
-        break;
-      default:
-        MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
+    if (handlers_.count(message->pb_meta().cmd()) == 0) {
+      MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
     }
+    const auto &handler_ptr = handlers_[message->pb_meta().cmd()];
+    (this->*handler_ptr)(server_, conn, message);
   });
 
   server_->Init();
diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h
index a476caae53e..1c89d2398dd 100644
--- a/mindspore/ccsrc/ps/core/scheduler_node.h
+++ b/mindspore/ccsrc/ps/core/scheduler_node.h
@@ -25,29 +25,32 @@
 #include <vector>
 #include <thread>
 #include <mutex>
+#include <unordered_map>
 
 #include "ps/core/cluster_config.h"
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/node_manager.h"
 #include "ps/core/node.h"
-#include "utils/log_adapter.h"
 
 namespace mindspore {
 namespace ps {
 namespace core {
-
 class SchedulerNode : public Node {
  public:
   SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
   ~SchedulerNode() override;
 
+  typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server,
+                                                 std::shared_ptr<TcpConnection> conn,
+                                                 std::shared_ptr<CommMessage> message);
+
   bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
   bool Stop() override;
   bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
 
 private:
   void Initialize();
+  void InitCommandHandler();
   void CreateTcpServer();
   void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                         std::shared_ptr<CommMessage> message);
@@ -62,6 +65,7 @@ class SchedulerNode : public Node {
   std::shared_ptr<TcpServer> server_;
   std::unique_ptr<std::thread> scheduler_thread_;
   std::unique_ptr<std::thread> update_state_thread_;
+  std::unordered_map<NodeCommand, ResponseHandler> handlers_;
   NodeManager node_manager_;
 };
diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc
index 08d0b280b80..28d09570678 100644
--- a/mindspore/ccsrc/ps/core/server_node.cc
+++ b/mindspore/ccsrc/ps/core/server_node.cc
@@ -92,6 +92,7 @@ void ServerNode::Initialize() {
   node_info_.port_ = server_->BoundPort();
   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << " is generate uuid is:" << node_info_.node_id_;
+  InitCommandHandler();
   if (!InitClientToScheduler()) {
     MS_LOG(EXCEPTION) << "Server node init client timeout!";
   }
diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h
index 2a0d70e82b6..086358f56e5 100644
--- a/mindspore/ccsrc/ps/core/server_node.h
+++ b/mindspore/ccsrc/ps/core/server_node.h
@@ -24,13 +24,10 @@
 #include <thread>
 #include <utility>
 
-#include "proto/comm.pb.h"
-#include "proto/ps.pb.h"
 #include "ps/core/cluster_config.h"
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/abstract_node.h"
-#include "utils/log_adapter.h"
 
 namespace mindspore {
 namespace ps {
diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc
index ee162e070b4..1870a499241 100644
--- a/mindspore/ccsrc/ps/core/worker_node.cc
+++ b/mindspore/ccsrc/ps/core/worker_node.cc
@@ -50,6 +50,7 @@ void WorkerNode::Initialize() {
   node_info_.node_role_ = NodeRole::WORKER;
   MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_;
+  InitCommandHandler();
   if (!InitClientToScheduler()) {
     MS_LOG(EXCEPTION) << "Worker node init client timeout!";
   }
diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h
index a1343aa3623..8608ae430a9 100644
--- a/mindspore/ccsrc/ps/core/worker_node.h
+++ b/mindspore/ccsrc/ps/core/worker_node.h
@@ -28,7 +28,6 @@
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/abstract_node.h"
-#include "utils/log_adapter.h"
 
 namespace mindspore {
 namespace ps {
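With all `Send`, `Broadcast`, and `Collective*` overloads now exchanging `CommMessage` objects instead of raw strings, call sites build the protobuf message themselves and receive the reply as a full `CommMessage`. A hedged sketch of an updated call site follows, assuming a `WorkerNode` that exposes the inherited public `Send` overload; the function name, rank id, and payload are placeholders.

```cpp
// Hypothetical call-site sketch for the CommMessage-based Send API.
#include <string>

#include "proto/comm.pb.h"
#include "ps/core/worker_node.h"

bool PushToServer(mindspore::ps::core::WorkerNode *worker, const std::string &payload) {
  mindspore::ps::core::CommMessage request;
  request.set_data(payload);

  mindspore::ps::core::CommMessage response;
  // Send to server rank 0 and block until the reply arrives or the default
  // timeout expires; the reply is now a whole CommMessage, not a std::string.
  if (!worker->Send(mindspore::ps::core::NodeRole::SERVER, 0, request, &response)) {
    return false;  // timed out
  }
  // response.data() and the new response.user_cmd() are both available here.
  return true;
}
```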