forked from mindspore-Ecosystem/mindspore

protobuf change bytes to Any

parent 3e662805f8
commit 09a15be893
@@ -32,6 +32,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
   CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
   comm_message.set_data(register_message.SerializeAsString());
+  comm_message.set_user_cmd("");
   if (!SendMessageSync(client, comm_message)) {
     MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                       << " the node id:" << node_info_.node_id_ << " register timeout!";
@@ -54,11 +55,12 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
   MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
 }

-bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) {
+bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) {
   if (node_role != NodeRole::SERVER) {
     MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
   }

+  CommMessage &comm_message = const_cast<CommMessage &>(message);
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);

@@ -69,9 +71,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);

-    CommMessage comm_message;
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(message);
     auto client = GetOrCreateTcpClient((*it).first.second);
     client->SendMessage(comm_message);
   }
@@ -84,26 +84,26 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
   on_node_event_message_ = on_node_event_message;
 }

-bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
                         const uint32_t &timeout) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }

+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   MessageMeta message_meta;
   message_meta.set_cmd(NodeCommand::SEND_DATA);
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);

-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
   auto client = GetOrCreateTcpClient(rank_id);
   return SendMessageSync(client, comm_message, timeout);
 }

 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
-                        const std::vector<std::string> &data, const uint32_t &timeout) {
+                        const std::vector<CommMessage> &data, const uint32_t &timeout) {
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(data.size(), 0);

@@ -121,9 +121,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);

-    CommMessage comm_message;
+    CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(data.at(it));

     auto client = GetOrCreateTcpClient(rank_ids.at(it));
     client->SendMessage(comm_message);
@@ -133,19 +132,21 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
   return Wait(request_id, timeout);
 }

-bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
-                        std::string *output, const uint32_t &timeout) {
+bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
+                        CommMessage *output, const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(output);
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }

+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(1, 0);
   set_message_callback(request_id, [&]() {
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
-    *output = res[rank_id].data();
+    *output = res[rank_id];
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
   });
@@ -156,9 +157,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);

-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
   auto client = GetOrCreateTcpClient(rank_id);
   client->SendMessage(comm_message);
   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@@ -167,7 +166,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
 }

 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
-                        const std::vector<std::string> &data, std::vector<std::string> *output,
+                        const std::vector<CommMessage> &data, std::vector<CommMessage> *output,
                         const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(output);
   uint64_t request_id = ++next_request_id_;
@@ -183,7 +182,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
     for (size_t it = 0; it < len; ++it) {
-      (*output).push_back(res[rank_ids.at(it)].data());
+      (*output).push_back(res[rank_ids.at(it)]);
     }
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
@@ -200,9 +199,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);

-    CommMessage comm_message;
+    CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(data.at(it));

     auto client = GetOrCreateTcpClient(rank_ids.at(it));
     client->SendMessage(comm_message);
@@ -223,37 +221,37 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
 }

 uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
-                                           const std::string &message) {
+                                           const CommMessage &message) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }

+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   MessageMeta message_meta;
   message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);

-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
   auto client = GetOrCreateTcpClient(rank_id);
   return SendMessageAsync(client, comm_message);
 }

 std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
-                                                                   const uint32_t &rank_id, std::string *output) {
+                                                                   const uint32_t &rank_id, CommMessage *output) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }

   uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
   if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
-    *output = received_data_[std::make_pair(rank_id, rank_request_id)].data();
+    *output = received_data_[std::make_pair(rank_id, rank_request_id)];
     received_data_.erase(std::make_pair(rank_id, rank_request_id));
   } else {
     set_receive_callback(rank_id, rank_request_id, [=]() {
       receive_callbacks_mutex_.lock();
-      *output = received_data_[std::make_pair(rank_id, 1)].data();
+      *output = received_data_[std::make_pair(rank_id, rank_request_id)];
       received_data_.erase(std::make_pair(rank_id, rank_request_id));
       receive_callbacks_mutex_.unlock();
     });
@@ -415,21 +413,12 @@ bool AbstractNode::InitClientToScheduler() {
   uint16_t scheduler_port = ClusterConfig::scheduler_port();
   client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
   client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
-    switch (message.pb_meta().cmd()) {
-      case NodeCommand::HEARTBEAT:
-        ProcessHeartbeatResp(message);
-        break;
-      case NodeCommand::REGISTER:
-        ProcessRegisterResp(message);
-        break;
-      case NodeCommand::FETCH_SERVER:
-        ProcessFetchServersResp(message);
-        break;
-      case NodeCommand::FINISH:
-        MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
-        break;
-      default:
-        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+    if (handlers_.count(message.pb_meta().cmd()) == 0) {
+      MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
     }
+    if (handlers_[message.pb_meta().cmd()] != nullptr) {
+      const auto &handler_ptr = handlers_[message.pb_meta().cmd()];
+      (this->*handler_ptr)(message);
+    }
     NotifyMessageArrival(message);
   });
@@ -607,6 +596,13 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
   }
   return rank_request_id;
 }
+
+void AbstractNode::InitCommandHandler() {
+  handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
+  handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
+  handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
+  handlers_[NodeCommand::FINISH] = nullptr;
+}
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
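Reviewer note: the core of this change in abstract_node.cc is swapping the switch in InitClientToScheduler for a command table populated once by InitCommandHandler. Below is a minimal, self-contained sketch of that pointer-to-member dispatch pattern; the Node/CommMessage/NodeCommand stubs are placeholders for illustration, not MindSpore code.

#include <iostream>
#include <stdexcept>
#include <unordered_map>

enum NodeCommand { HEARTBEAT, REGISTER, FINISH };

struct CommMessage {
  NodeCommand cmd;
};

class Node {
 public:
  // Same shape as the ResponseHandler typedef added to abstract_node.h.
  typedef void (Node::*ResponseHandler)(const CommMessage &message);

  void InitCommandHandler() {
    handlers_[HEARTBEAT] = &Node::ProcessHeartbeat;
    handlers_[REGISTER] = &Node::ProcessRegister;
    handlers_[FINISH] = nullptr;  // known command that needs no handler body
  }

  void OnMessage(const CommMessage &message) {
    // Unknown command: hard error, mirroring the MS_LOG(EXCEPTION) branch.
    if (handlers_.count(message.cmd) == 0) {
      throw std::runtime_error("The cmd is not supported!");
    }
    // Registered-but-null handler: accepted and deliberately ignored.
    if (handlers_[message.cmd] != nullptr) {
      const auto &handler_ptr = handlers_[message.cmd];
      (this->*handler_ptr)(message);  // invoke through pointer-to-member
    }
  }

 private:
  void ProcessHeartbeat(const CommMessage &) { std::cout << "heartbeat\n"; }
  void ProcessRegister(const CommMessage &) { std::cout << "register\n"; }

  std::unordered_map<NodeCommand, ResponseHandler> handlers_;
};

int main() {
  Node node;
  node.InitCommandHandler();
  node.OnMessage({HEARTBEAT});
  node.OnMessage({FINISH});  // accepted, but no output: FINISH maps to nullptr
  return 0;
}

The nullptr entry mirrors handlers_[NodeCommand::FINISH] = nullptr above: FINISH is still a valid command, but the old switch's MS_LOG(INFO) response for it is dropped rather than ported into a handler.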
@@ -34,23 +34,25 @@ class AbstractNode : public Node {
   AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
   ~AbstractNode() override = default;

-  bool Broadcast(const enum NodeRole &node_role, const std::string &message,
+  typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message);
+
+  bool Broadcast(const enum NodeRole &node_role, const CommMessage &message,
                  const uint32_t &timeout = kCommTimeoutInSeconds);
   void set_event_callback(const OnNodeEventMessage &on_node_event_message);

-  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
             const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
+  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
             const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output,
+  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output,
             const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
-            std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
+  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
+            std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
   bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

-  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message);
+  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message);
   std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
-                                                       std::string *output);
+                                                       CommMessage *output);
   bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

  protected:
@@ -78,6 +80,7 @@ class AbstractNode : public Node {
   void RunReceiveCallback(const CommMessage &message);
   uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
   uint64_t NextActualRankRequestId(const uint32_t &rank_id);
+  void InitCommandHandler();

   std::unique_ptr<std::thread> heart_beat_thread_;
   std::unique_ptr<std::thread> client_to_scheduler_thread_;
@@ -115,6 +118,7 @@ class AbstractNode : public Node {
   std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
   std::mutex rank_request_ids_mutex;
   timeval scheduler_time_;
+  std::unordered_map<NodeCommand, ResponseHandler> handlers_;
 };
 }  // namespace core
 }  // namespace ps
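Note on the new signatures: every overload now takes const CommMessage & yet immediately const_casts it to stamp pb_meta, so the caller's message is modified through a const reference. A self-contained sketch of that calling convention, with stub structs standing in for the generated classes:

#include <cassert>
#include <string>

// Minimal stand-ins for the generated MessageMeta/CommMessage types.
struct MessageMeta {
  int cmd = 0;
};
struct CommMessage {
  MessageMeta pb_meta;
  std::string data;
};

// Mirrors the diff: take const&, cast away constness, fill in routing metadata.
bool Send(const CommMessage &message) {
  CommMessage &comm_message = const_cast<CommMessage &>(message);
  comm_message.pb_meta.cmd = 1;  // e.g. NodeCommand::SEND_DATA in the real code
  // ... hand comm_message to the TCP client here ...
  return true;
}

int main() {
  CommMessage msg;
  msg.data = "payload bytes";
  Send(msg);
  assert(msg.pb_meta.cmd == 1);  // caller's object was stamped via the const ref
  return 0;
}

The gain over the old std::string overloads is that no copy into a fresh CommMessage is made per send; the cost is that callers must not assume the message is left untouched.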
@@ -95,5 +95,6 @@ message FinishMessage {
 message CommMessage {
   MessageMeta pb_meta = 1;
   bytes data = 2;
+  // User-defined commands
   bytes user_cmd = 3;
 }
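The data field is an envelope: it carries the serialized bytes of a second, command-specific message, exactly as Register() does with register_message.SerializeAsString(). A sketch of hypothetical PackKV/UnpackKV helpers, assuming the classes protoc generates from these files (the C++ namespace of CommMessage is an assumption; comm.proto's package line is not shown in this diff):

#include <string>

#include "proto/comm.pb.h"  // CommMessage, MessageMeta
#include "proto/ps.pb.h"    // KVMessage (see the ps.proto hunk below)

using mindspore::ps::KVMessage;          // package mindspore.ps after this commit
using mindspore::ps::core::CommMessage;  // assumed package for comm.proto

CommMessage PackKV() {
  KVMessage kv;  // the command-specific payload
  kv.add_keys(1);
  kv.add_values(0.5f);

  CommMessage envelope;
  envelope.set_data(kv.SerializeAsString());  // payload rides as opaque bytes
  envelope.set_user_cmd("");                  // the user-defined command slot
  return envelope;
}

bool UnpackKV(const CommMessage &envelope, KVMessage *kv) {
  // `bytes` carries no type information: the receiver must already know,
  // from pb_meta/user_cmd, which message type to parse out of `data`.
  return kv->ParseFromString(envelope.data());
}

That typelessness is presumably what the commit title is pointing at: google.protobuf.Any stores a type URL next to the bytes, while the definition here still uses plain bytes.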
@@ -14,17 +14,42 @@
  * limitations under the License.
  */
 syntax = "proto3";
-package mindspore.ps.core;
+package mindspore.ps;
 option optimize_for = LITE_RUNTIME;

-enum PSCommand {
+message Command {
+  CommandCode cmd = 1;
+}
+
+enum CommandCode {
   PUSH = 0;
   PULL = 1;
+  INIT_EMBEDDING_TABLE = 2;
+  INIT_WEIGHT = 3;
+  INIT_WEIGHT_TO_OPTIM_ID = 4;
+  INIT_INPUTS_SHAPE = 5;
+  CHECK_READY_FOR_PUSH = 6;
+  CHECK_READY_FOR_PULL = 7;
+  EMBEDDING_LOOKUP = 8;
+  UPDATE_EMBEDDING = 9;
+  FINALIZE = 10;
 }

 message KVMessage {
-  PSCommand command = 1;
   repeated int32 keys = 2;
   repeated float values = 3;
   repeated int32 len = 4;
 }
+
+message EmbeddingTableMeta {
+  uint64 key = 1;
+  repeated uint64 input_shape = 2;
+  repeated uint64 indices_shape = 3;
+  repeated uint64 output_shape = 4;
+}
+
+message EmbeddingTableLookup {
+  uint64 key = 2;
+  repeated int32 keys = 3;
+  repeated float values = 4;
+}
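A usage sketch for KVMessage, under the assumption (not documented in the .proto) that the repeated fields are parallel arrays, with len[i] giving how many of the flattened values belong to keys[i]:

#include "proto/ps.pb.h"

// Hypothetical helper: encodes key 7 -> {0.1, 0.2} and key 9 -> {0.3},
// assuming `len` segments the flattened `values` array per key.
mindspore::ps::KVMessage MakeKVMessage() {
  mindspore::ps::KVMessage kv;

  kv.add_keys(7);
  kv.add_len(2);  // key 7 owns the next two values
  kv.add_values(0.1f);
  kv.add_values(0.2f);

  kv.add_keys(9);
  kv.add_len(1);  // key 9 owns one value
  kv.add_values(0.3f);

  return kv;
}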
@@ -67,6 +67,7 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
 }

 void SchedulerNode::Initialize() {
+  InitCommandHandler();
   CreateTcpServer();
   is_already_stopped_ = false;
   node_info_.node_id_ = CommUtil::GenerateUUID();
@@ -75,6 +76,13 @@ void SchedulerNode::Initialize() {
                << ", the node id is:" << node_info_.node_id_;
 }

+void SchedulerNode::InitCommandHandler() {
+  handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat;
+  handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
+  handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
+  handlers_[NodeCommand::FETCH_SERVER] = &SchedulerNode::ProcessFetchServers;
+}
+
 void SchedulerNode::CreateTcpServer() {
   node_manager_.InitNodeNum();

@@ -82,22 +90,11 @@ void SchedulerNode::CreateTcpServer() {
   uint32_t scheduler_port = ClusterConfig::scheduler_port();
   server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
   server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
-    switch (message->pb_meta().cmd()) {
-      case NodeCommand::HEARTBEAT:
-        ProcessHeartbeat(server_, conn, message);
-        break;
-      case NodeCommand::REGISTER:
-        ProcessRegister(server_, conn, message);
-        break;
-      case NodeCommand::FINISH:
-        ProcessFinish(server_, conn, message);
-        break;
-      case NodeCommand::FETCH_SERVER:
-        ProcessFetchServers(server_, conn, message);
-        break;
-      default:
-        MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
+    if (handlers_.count(message->pb_meta().cmd()) == 0) {
+      MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
     }
+    const auto &handler_ptr = handlers_[message->pb_meta().cmd()];
+    (this->*handler_ptr)(server_, conn, message);
   });

   server_->Init();
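The scheduler gets the same table-dispatch rewrite, but with a wider handler signature: the extra server and connection arguments are simply forwarded at the invocation site. A stub sketch of the three-parameter pointer-to-member variant (illustrative types only, not the real TcpServer/TcpConnection):

#include <memory>
#include <stdexcept>
#include <unordered_map>

struct TcpServer {};
struct TcpConnection {};
struct CommMessage { int cmd = 0; };

class Scheduler {
 public:
  // Three-parameter analogue of the ResponseHandler typedef in scheduler_node.h.
  typedef void (Scheduler::*ResponseHandler)(std::shared_ptr<TcpServer> server,
                                             std::shared_ptr<TcpConnection> conn,
                                             std::shared_ptr<CommMessage> message);

  void InitCommandHandler() { handlers_[0] = &Scheduler::ProcessHeartbeat; }

  void OnMessage(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                 std::shared_ptr<CommMessage> message) {
    if (handlers_.count(message->cmd) == 0) {
      throw std::runtime_error("The cmd is not supported!");
    }
    const auto &handler_ptr = handlers_[message->cmd];
    (this->*handler_ptr)(server, conn, message);  // forward all three arguments
  }

 private:
  void ProcessHeartbeat(std::shared_ptr<TcpServer>, std::shared_ptr<TcpConnection>,
                        std::shared_ptr<CommMessage>) { /* respond here */ }

  std::unordered_map<int, ResponseHandler> handlers_;
};

int main() {
  Scheduler scheduler;
  scheduler.InitCommandHandler();
  scheduler.OnMessage(std::make_shared<TcpServer>(), std::make_shared<TcpConnection>(),
                      std::make_shared<CommMessage>());
  return 0;
}

Unlike AbstractNode, SchedulerNode registers a non-null handler for every command it accepts, so its callback invokes unconditionally after the membership check.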
@@ -25,29 +25,32 @@
 #include <vector>
 #include <thread>
 #include <mutex>
+#include <unordered_map>

 #include "ps/core/cluster_config.h"
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/node_manager.h"
 #include "ps/core/node.h"
 #include "utils/log_adapter.h"

 namespace mindspore {
 namespace ps {
 namespace core {

 class SchedulerNode : public Node {
  public:
   SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
   ~SchedulerNode() override;

+  typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                                                 std::shared_ptr<CommMessage> message);

   bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
   bool Stop() override;
   bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;

  private:
   void Initialize();
+  void InitCommandHandler();
   void CreateTcpServer();
   void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                         std::shared_ptr<CommMessage> message);
@@ -62,6 +65,7 @@ class SchedulerNode : public Node {
   std::shared_ptr<TcpServer> server_;
   std::unique_ptr<std::thread> scheduler_thread_;
   std::unique_ptr<std::thread> update_state_thread_;
+  std::unordered_map<NodeCommand, ResponseHandler> handlers_;

   NodeManager node_manager_;
 };
@@ -92,6 +92,7 @@ void ServerNode::Initialize() {
   node_info_.port_ = server_->BoundPort();
   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << " is generate uuid is:" << node_info_.node_id_;
+  InitCommandHandler();
   if (!InitClientToScheduler()) {
     MS_LOG(EXCEPTION) << "Server node init client timeout!";
   }
@@ -24,13 +24,10 @@
 #include <thread>
 #include <utility>

 #include "proto/comm.pb.h"
 #include "proto/ps.pb.h"
 #include "ps/core/cluster_config.h"
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/abstract_node.h"
 #include "utils/log_adapter.h"

 namespace mindspore {
 namespace ps {
@@ -50,6 +50,7 @@ void WorkerNode::Initialize() {
   node_info_.node_role_ = NodeRole::WORKER;
   MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_;
+  InitCommandHandler();
   if (!InitClientToScheduler()) {
     MS_LOG(EXCEPTION) << "Worker node init client timeout!";
   }
@@ -28,7 +28,6 @@
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/abstract_node.h"
 #include "utils/log_adapter.h"

 namespace mindspore {
 namespace ps {