diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index c936d56e821..7c2be33137f 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -50,11 +50,6 @@ void AbstractNode::ProcessRegisterResp(std::shared_ptr meta, const << " is not match the current node id:" << node_info_.node_id_; } - if (register_resp_message.rank_id() < 0) { - MS_LOG(EXCEPTION) << "The rank id is wrong."; - } - node_info_.rank_id_ = register_resp_message.rank_id(); - // Receive the Register message, indicating that the scheduler is alive, so update the time point at which the // scheduler is alive UpdateSchedulerTime(); @@ -497,9 +492,13 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr conn, std: send_meta_message.ParseFromArray(data, size); worker_num_ = send_meta_message.worker_num(); server_num_ = send_meta_message.server_num(); + if (send_meta_message.rank_id() < 0) { + MS_LOG(EXCEPTION) << "The rank id is wrong."; + } + node_info_.rank_id_ = send_meta_message.rank_id(); current_cluster_state_ = send_meta_message.cluster_state(); MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_ - << ", cluster state is:" << current_cluster_state_; + << ", cluster state is:" << current_cluster_state_ << ", the rank id:" << node_info_.rank_id_; nodes_address_.clear(); for (const auto &it : send_meta_message.servers_meta()) { diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 43763b4a59b..0156290a22c 100644 --- a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -73,7 +73,6 @@ message RegisterMessage { message RegisterRespMessage { string node_id = 1; - uint32 rank_id = 2; } message HeartbeatMessage { @@ -123,6 +122,8 @@ message SendMetadataMessage { int32 server_num = 3; // the current cluster state. ClusterState cluster_state = 4; + // The rank id of the node that received this message. + uint32 rank_id = 5; } message FinishMessage { @@ -163,13 +164,13 @@ message ScaleInDoneMessage { string node_id = 1; } -// This message is sent to the scheduler to notify the completion of scale out +// This message is sent by the worker/server to the scheduler, and the scheduler is broadcast the event to all other nodes. message EventMessage { uint32 event = 1; string node_id = 2; } -// schedulerd broadcasts the event to all other nodes through this message +// scheduler broadcasts the event to all other nodes through this message message EventRespMessage { uint32 event = 1; } diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index 365cd19a558..4a909514487 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -127,7 +127,6 @@ void SchedulerNode::ProcessRegister(std::shared_ptr server, std::shar RegisterRespMessage register_resp_message; register_resp_message.set_node_id(node_id); - register_resp_message.set_rank_id(rank_id); server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(), register_resp_message.ByteSizeLong()); @@ -137,8 +136,9 @@ void SchedulerNode::ProcessRegister(std::shared_ptr server, std::shar auto node_infos = node_manager_.nodes_info(); for (const auto &kvs : node_infos) { auto client = GetOrCreateClient(kvs.second); - SendMetadata(client); - MS_LOG(INFO) << "Send meta data to" << kvs.first; + SendMetadata(client, kvs.second.rank_id_); + MS_LOG(INFO) << "Send meta data to node id:" << kvs.first + << ", The rank id of the node that received this message is:" << kvs.second.rank_id_; } wait_start_cond_.notify_all(); } @@ -252,7 +252,7 @@ void SchedulerNode::ProcessSendEvent(std::shared_ptr server, std::sha } } -void SchedulerNode::SendMetadata(const std::shared_ptr &client) { +void SchedulerNode::SendMetadata(const std::shared_ptr &client, uint32_t rank_id) { MS_EXCEPTION_IF_NULL(client); auto message_meta = std::make_shared(); message_meta->set_cmd(NodeCommand::SEND_METADATA); @@ -262,6 +262,7 @@ void SchedulerNode::SendMetadata(const std::shared_ptr &client) { send_metadata_message.set_worker_num(node_manager_.worker_num()); send_metadata_message.set_server_num(node_manager_.server_num()); send_metadata_message.set_cluster_state(node_manager_.GetClusterState()); + send_metadata_message.set_rank_id(rank_id); *send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index 9593354af5e..c0510632ae9 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -90,7 +90,7 @@ class SchedulerNode : public Node { std::shared_ptr meta, const void *data, size_t size); // After scheduler collects all registered message, it actively sends finish to the node connected by the client. - void SendMetadata(const std::shared_ptr &client); + void SendMetadata(const std::shared_ptr &client, uint32_t rank_id); // After scheduler collects all finish message, it actively sends finish to the node connected by the client. void SendFinish(const std::shared_ptr &client);