diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc
index 03059b9ba99..c936d56e821 100644
--- a/mindspore/ccsrc/ps/core/abstract_node.cc
+++ b/mindspore/ccsrc/ps/core/abstract_node.cc
@@ -139,10 +139,32 @@ void AbstractNode::set_scale_in_done() {
                << " the node id:" << node_info_.node_id_ << "is send scale_in_done to scheduler!";
 }
 
+void AbstractNode::BroadcastEvent(const uint32_t &event) {
+  auto message_meta = std::make_shared();
+  message_meta->set_cmd(NodeCommand::SEND_EVENT);
+
+  EventMessage event_message;
+  event_message.set_event(event);
+  event_message.set_node_id(node_info_.node_id_);
+
+  if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF, event_message.SerializeAsString().data(),
+                       event_message.ByteSizeLong())) {
+    MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                      << " the node id:" << node_info_.node_id_ << " send event timeout!";
+  }
+
+  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+               << " the node id:" << node_info_.node_id_ << "is send event to scheduler!";
+}
+
 void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) {
   event_to_callback_.try_emplace(event, event_cb);
 }
 
+void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb) {
+  custom_event_to_callback_.try_emplace(event, event_cb);
+}
+
 bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
                         int command, const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(data);
@@ -531,6 +553,18 @@ void AbstractNode::ProcessScaleInDone(std::shared_ptr conn, std::
   current_cluster_state_ = ClusterState::CLUSTER_READY;
 }
 
+void AbstractNode::ProcessEvent(std::shared_ptr conn, std::shared_ptr meta,
+                                const Protos &protos, const void *data, size_t size) {
+  MS_EXCEPTION_IF_NULL(conn);
+  MS_EXCEPTION_IF_NULL(meta);
+  MS_EXCEPTION_IF_NULL(data);
+  EventRespMessage event_resp_message;
+  event_resp_message.ParseFromArray(data, size);
+  uint32_t event = event_resp_message.event();
+  server_->SendMessage(conn, meta, Protos::RAW, data, size);
+  OnCustomEventCallback(event);
+}
+
 void AbstractNode::ProcessScaleOut(std::shared_ptr conn, std::shared_ptr meta,
                                    const Protos &protos, const void *data, size_t size) {
   MS_EXCEPTION_IF_NULL(conn);
@@ -790,6 +824,7 @@ void AbstractNode::InitCommandHandler() {
   handlers_[NodeCommand::FINISH] = nullptr;
   handlers_[NodeCommand::SCALE_OUT_DONE] = nullptr;
   handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
+  handlers_[NodeCommand::SEND_EVENT] = nullptr;
 }
 
 void AbstractNode::InitServerHandler() {
@@ -801,6 +836,7 @@ void AbstractNode::InitServerHandler() {
   server_handler_[NodeCommand::SCALE_IN] = &AbstractNode::ProcessScaleIn;
   server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
   server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
+  server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
 }
 
 void AbstractNode::InitNodeInfo(const NodeRole &role) {
@@ -826,6 +862,15 @@ void AbstractNode::OnEventCallback(const ClusterEvent &event) {
     event_to_callback_[event]();
   }
 }
+
+void AbstractNode::OnCustomEventCallback(const uint32_t &event) {
+  if (!custom_event_to_callback_.count(event)) {
+    MS_LOG(WARNING) << "The event callback of " << event << " is not set.";
+  } else {
+    MS_LOG(INFO) << "Trigger the event:" << event;
+    custom_event_to_callback_[event]();
+  }
+}
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h
index 05ed110f438..cd96f2b8e5f 100644
--- a/mindspore/ccsrc/ps/core/abstract_node.h
+++ b/mindspore/ccsrc/ps/core/abstract_node.h
@@ -68,8 +68,13 @@ class AbstractNode : public Node {
   // Send scale_in_done instructions to the scheduler.
   void set_scale_in_done();
 
+  // The worker/server sends the event to the scheduler, and then the scheduler broadcasts this event to all nodes.
+  void BroadcastEvent(const uint32_t &event);
+
   // Set the callback corresponding to the event.
   void RegisterEventCallback(const ClusterEvent &event, const EventCallback &event_cb);
+  // Set the callback corresponding to the custom event.
+  void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb);
 
   bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
             const uint32_t &timeout = kCommTimeoutInSeconds);
@@ -129,6 +134,10 @@ class AbstractNode : public Node {
   void ProcessScaleInDone(std::shared_ptr conn, std::shared_ptr meta, const Protos &protos,
                           const void *data, size_t size);
 
+  // The worker/server processes the SEND_EVENT message from the scheduler and triggers the custom event callback.
+  void ProcessEvent(std::shared_ptr conn, std::shared_ptr meta, const Protos &protos,
+                    const void *data, size_t size);
+
   void StartHeartbeatTimer(const std::shared_ptr &client);
   void UpdateSchedulerTime();
   bool CheckSchedulerTimeout() const;
@@ -153,6 +162,8 @@ class AbstractNode : public Node {
 
   // Trigger the callback corresponding to the event.
   void OnEventCallback(const ClusterEvent &event);
+  // Trigger the callback corresponding to the custom event.
+  void OnCustomEventCallback(const uint32_t &event);
 
   std::unique_ptr heart_beat_thread_;
   std::unique_ptr client_to_scheduler_thread_;
@@ -201,6 +212,13 @@ class AbstractNode : public Node {
   // Each ClusterEvent corresponds to a EventCallback to process the event.
   std::map event_to_callback_;
 
+  // Each custom event corresponds to a EventCallback to process the event.
+  // This event is sent to the scheduler, and then the scheduler broadcasts this event to all nodes.
+  // for example:
+  // In order to ensure the consistency of the cluster, the server broadcasts an iteration_end event to notify all other
+  // nodes to modify the iteration status
+  std::map custom_event_to_callback_;
+
   // Scaler for worker/server node.
   std::unique_ptr follower_scaler_;
 };
diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto
index 86cfa1a15da..43763b4a59b 100644
--- a/mindspore/ccsrc/ps/core/protos/comm.proto
+++ b/mindspore/ccsrc/ps/core/protos/comm.proto
@@ -37,6 +37,8 @@ enum NodeCommand {
   SCALE_OUT_DONE = 10;
   // This command is used to synchronize the scale in status of the cluster
   SCALE_IN_DONE = 11;
+  // This command is used to send user defined event.
+  SEND_EVENT = 12;
 }
 
 enum NodeRole {
@@ -160,3 +162,14 @@ message ScaleOutDoneMessage {
 message ScaleInDoneMessage {
   string node_id = 1;
 }
+
+// This message is sent to the scheduler by the node identified by node_id to request broadcasting a user defined event
+message EventMessage {
+  uint32 event = 1;
+  string node_id = 2;
+}
+
+// The scheduler broadcasts the event to all nodes through this message
+message EventRespMessage {
+  uint32 event = 1;
+}
diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc
index 61acd28cf38..365cd19a558 100644
--- a/mindspore/ccsrc/ps/core/scheduler_node.cc
+++ b/mindspore/ccsrc/ps/core/scheduler_node.cc
@@ -81,6 +81,7 @@ void SchedulerNode::InitCommandHandler() {
   handlers_[NodeCommand::FETCH_METADATA] = &SchedulerNode::ProcessFetchMetadata;
   handlers_[NodeCommand::SCALE_OUT_DONE] = &SchedulerNode::ProcessScaleOutDone;
   handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone;
+  handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
 }
 
 void SchedulerNode::CreateTcpServer() {
@@ -230,6 +231,27 @@ void SchedulerNode::ProcessScaleInDone(std::shared_ptr server, std::s
   }
 }
 
+void SchedulerNode::ProcessSendEvent(std::shared_ptr server, std::shared_ptr conn,
+                                     std::shared_ptr meta, const void *data, size_t size) {
+  MS_EXCEPTION_IF_NULL(server);
+  MS_EXCEPTION_IF_NULL(conn);
+  MS_EXCEPTION_IF_NULL(meta);
+  MS_EXCEPTION_IF_NULL(data);
+  EventMessage event_message;
+  event_message.ParseFromArray(data, size);
+  std::string node_id = event_message.node_id();
+  uint32_t event = event_message.event();
+  MS_LOG(INFO) << "The scheduler process a event message from node id:" << node_id;
+
+  server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
+
+  auto node_infos = node_manager_.nodes_info();
+  for (const auto &kvs : node_infos) {
+    auto client = GetOrCreateClient(kvs.second);
+    SendEvent(client, event);
+  }
+}
+
 void SchedulerNode::SendMetadata(const std::shared_ptr &client) {
   MS_EXCEPTION_IF_NULL(client);
   auto message_meta = std::make_shared();
@@ -304,6 +326,24 @@ void SchedulerNode::SendScaleInDone(const std::shared_ptr &client) {
                << " the node id:" << node_info_.node_id_ << "is sending scale_in_done to workers and servers!";
 }
 
+void SchedulerNode::SendEvent(const std::shared_ptr &client, const uint32_t &event) {
+  MS_EXCEPTION_IF_NULL(client);
+  auto message_meta = std::make_shared();
+  message_meta->set_cmd(NodeCommand::SEND_EVENT);
+
+  EventRespMessage event_resp_message;
+  event_resp_message.set_event(event);
+
+  if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
+                       event_resp_message.ByteSizeLong())) {
+    MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+                      << " the node id:" << node_info_.node_id_ << " send event resp timeout!";
+  }
+
+  MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
+               << " the node id:" << node_info_.node_id_ << "is sending event resp to workers and servers!";
+}
+
 void SchedulerNode::StartUpdateClusterStateTimer() {
   MS_LOG(WARNING) << "The scheduler start a heartbeat timer!";
   update_state_thread_ = std::make_unique([&]() {
diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h
index 1a567f93cd2..9593354af5e 100644
--- a/mindspore/ccsrc/ps/core/scheduler_node.h
+++ b/mindspore/ccsrc/ps/core/scheduler_node.h
@@ -85,6 +85,9 @@ class SchedulerNode : public Node {
   // Process scale_in_done messages from workers/servers
   void ProcessScaleInDone(std::shared_ptr server, std::shared_ptr conn,
                           std::shared_ptr meta, const void *data, size_t size);
+  // Process SEND_EVENT messages from workers/servers: reply to the sender, then send the event to every known node.
+  void ProcessSendEvent(std::shared_ptr server, std::shared_ptr conn,
+                        std::shared_ptr meta, const void *data, size_t size);
 
   // After scheduler collects all registered message, it actively sends finish to the node connected by the client.
   void SendMetadata(const std::shared_ptr &client);
@@ -98,6 +101,8 @@ class SchedulerNode : public Node {
   // After scheduler collects all scale_in_done message, it actively sends scale_out_done to the node connected by the
   // client.
   void SendScaleInDone(const std::shared_ptr &client);
+  // After the scheduler receives a SEND_EVENT message, it sends the event to the node connected by the client.
+  void SendEvent(const std::shared_ptr &client, const uint32_t &event);
 
   // Handle the scale out http request, then delegate to the leader scaler to process scale out asynchronously.
   void ProcessScaleOut(std::shared_ptr resp);