protobuf change bytes to Any

This commit is contained in:
chendongsheng 2021-01-16 11:18:17 +08:00
parent 3e662805f8
commit 09a15be893
10 changed files with 97 additions and 72 deletions

View File

@ -32,6 +32,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(register_message.SerializeAsString());
comm_message.set_user_cmd("");
if (!SendMessageSync(client, comm_message)) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
@ -54,11 +55,12 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
}
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) {
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) {
if (node_role != NodeRole::SERVER) {
MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
}
CommMessage &comm_message = const_cast<CommMessage &>(message);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
@ -69,9 +71,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient((*it).first.second);
client->SendMessage(comm_message);
}
@ -84,26 +84,26 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
on_node_event_message_ = on_node_event_message;
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
const uint32_t &timeout) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
CommMessage &comm_message = const_cast<CommMessage &>(message);
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
return SendMessageSync(client, comm_message, timeout);
}
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, const uint32_t &timeout) {
const std::vector<CommMessage> &data, const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
@ -121,9 +121,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(data.at(it));
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
@ -133,19 +132,21 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
return Wait(request_id, timeout);
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
std::string *output, const uint32_t &timeout) {
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
CommMessage *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
CommMessage &comm_message = const_cast<CommMessage &>(message);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
*output = res[rank_id].data();
*output = res[rank_id];
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
@ -156,9 +157,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
client->SendMessage(comm_message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@ -167,7 +166,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
}
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<std::string> *output,
const std::vector<CommMessage> &data, std::vector<CommMessage> *output,
const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output);
uint64_t request_id = ++next_request_id_;
@ -183,7 +182,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
for (size_t it = 0; it < len; ++it) {
(*output).push_back(res[rank_ids.at(it)].data());
(*output).push_back(res[rank_ids.at(it)]);
}
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
@ -200,9 +199,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(data.at(it));
auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
@ -223,37 +221,37 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
}
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
const std::string &message) {
const CommMessage &message) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
CommMessage &comm_message = const_cast<CommMessage &>(message);
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
return SendMessageAsync(client, comm_message);
}
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
const uint32_t &rank_id, std::string *output) {
const uint32_t &rank_id, CommMessage *output) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
*output = received_data_[std::make_pair(rank_id, rank_request_id)].data();
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
received_data_.erase(std::make_pair(rank_id, rank_request_id));
} else {
set_receive_callback(rank_id, rank_request_id, [=]() {
receive_callbacks_mutex_.lock();
*output = received_data_[std::make_pair(rank_id, 1)].data();
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_callbacks_mutex_.unlock();
});
@ -415,21 +413,12 @@ bool AbstractNode::InitClientToScheduler() {
uint16_t scheduler_port = ClusterConfig::scheduler_port();
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
case NodeCommand::HEARTBEAT:
ProcessHeartbeatResp(message);
break;
case NodeCommand::REGISTER:
ProcessRegisterResp(message);
break;
case NodeCommand::FETCH_SERVER:
ProcessFetchServersResp(message);
break;
case NodeCommand::FINISH:
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
if (handlers_.count(message.pb_meta().cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
if (handlers_[message.pb_meta().cmd()] != nullptr) {
const auto &handler_ptr = handlers_[message.pb_meta().cmd()];
(this->*handler_ptr)(message);
}
NotifyMessageArrival(message);
});
@ -607,6 +596,13 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
}
return rank_request_id;
}
void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
handlers_[NodeCommand::FINISH] = nullptr;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -34,23 +34,25 @@ class AbstractNode : public Node {
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
~AbstractNode() override = default;
bool Broadcast(const enum NodeRole &node_role, const std::string &message,
typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message);
bool Broadcast(const enum NodeRole &node_role, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output,
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
std::string *output);
CommMessage *output);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
@ -78,6 +80,7 @@ class AbstractNode : public Node {
void RunReceiveCallback(const CommMessage &message);
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
void InitCommandHandler();
std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_;
@ -115,6 +118,7 @@ class AbstractNode : public Node {
std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
std::mutex rank_request_ids_mutex;
timeval scheduler_time_;
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
};
} // namespace core
} // namespace ps

View File

@ -95,5 +95,6 @@ message FinishMessage {
message CommMessage {
MessageMeta pb_meta = 1;
bytes data = 2;
// User-defined commands
bytes user_cmd = 3;
}

View File

@ -14,17 +14,42 @@
* limitations under the License.
*/
syntax = "proto3";
package mindspore.ps.core;
package mindspore.ps;
option optimize_for = LITE_RUNTIME;
enum PSCommand {
message Command {
CommandCode cmd = 1;
}
enum CommandCode {
PUSH = 0;
PULL = 1;
INIT_EMBEDDING_TABLE = 2;
INIT_WEIGHT = 3;
INIT_WEIGHT_TO_OPTIM_ID = 4;
INIT_INPUTS_SHAPE = 5;
CHECK_READY_FOR_PUSH = 6;
CHECK_READY_FOR_PULL = 7;
EMBEDDING_LOOKUP = 8;
UPDATE_EMBEDDING = 9;
FINALIZE = 10;
}
message KVMessage {
PSCommand command = 1;
repeated int32 keys = 2;
repeated float values = 3;
repeated int32 len = 4;
}
message EmbeddingTableMeta {
uint64 key = 1;
repeated uint64 input_shape = 2;
repeated uint64 indices_shape = 3;
repeated uint64 output_shape = 4;
}
message EmbeddingTableLookup {
uint64 key = 2;
repeated int32 keys = 3;
repeated float values = 4;
}

View File

@ -67,6 +67,7 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
}
void SchedulerNode::Initialize() {
InitCommandHandler();
CreateTcpServer();
is_already_stopped_ = false;
node_info_.node_id_ = CommUtil::GenerateUUID();
@ -75,6 +76,13 @@ void SchedulerNode::Initialize() {
<< ", the node id is:" << node_info_.node_id_;
}
void SchedulerNode::InitCommandHandler() {
handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat;
handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
handlers_[NodeCommand::FETCH_SERVER] = &SchedulerNode::ProcessFetchServers;
}
void SchedulerNode::CreateTcpServer() {
node_manager_.InitNodeNum();
@ -82,22 +90,11 @@ void SchedulerNode::CreateTcpServer() {
uint32_t scheduler_port = ClusterConfig::scheduler_port();
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
switch (message->pb_meta().cmd()) {
case NodeCommand::HEARTBEAT:
ProcessHeartbeat(server_, conn, message);
break;
case NodeCommand::REGISTER:
ProcessRegister(server_, conn, message);
break;
case NodeCommand::FINISH:
ProcessFinish(server_, conn, message);
break;
case NodeCommand::FETCH_SERVER:
ProcessFetchServers(server_, conn, message);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
if (handlers_.count(message->pb_meta().cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
}
const auto &handler_ptr = handlers_[message->pb_meta().cmd()];
(this->*handler_ptr)(server_, conn, message);
});
server_->Init();

View File

@ -25,29 +25,32 @@
#include <vector>
#include <thread>
#include <mutex>
#include <unordered_map>
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
#include "ps/core/node_manager.h"
#include "ps/core/node.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace core {
class SchedulerNode : public Node {
public:
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
~SchedulerNode() override;
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
private:
void Initialize();
void InitCommandHandler();
void CreateTcpServer();
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
@ -62,6 +65,7 @@ class SchedulerNode : public Node {
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_;
std::unique_ptr<std::thread> update_state_thread_;
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
NodeManager node_manager_;
};

View File

@ -92,6 +92,7 @@ void ServerNode::Initialize() {
node_info_.port_ = server_->BoundPort();
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " is generate uuid is:" << node_info_.node_id_;
InitCommandHandler();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Server node init client timeout!";
}

View File

@ -24,13 +24,10 @@
#include <thread>
#include <utility>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
#include "ps/core/abstract_node.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {

View File

@ -50,6 +50,7 @@ void WorkerNode::Initialize() {
node_info_.node_role_ = NodeRole::WORKER;
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_;
InitCommandHandler();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Worker node init client timeout!";
}

View File

@ -28,7 +28,6 @@
#include "ps/core/tcp_client.h"
#include "ps/core/tcp_server.h"
#include "ps/core/abstract_node.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {