From 35f6d3912717bd178362fa451317a778a4a9c1eb Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Thu, 20 May 2021 10:52:48 +0800 Subject: [PATCH] added cluster metadata --- mindspore/ccsrc/ps/core/abstract_node.cc | 10 +-- mindspore/ccsrc/ps/core/cluster_config.h | 73 +++++++++++++++++ mindspore/ccsrc/ps/core/cluster_metadata.cc | 79 ------------------- mindspore/ccsrc/ps/core/cluster_metadata.h | 54 +++---------- mindspore/ccsrc/ps/core/comm_util.cc | 5 +- mindspore/ccsrc/ps/core/comm_util.h | 2 + .../ccsrc/ps/core/communicator/tcp_client.h | 4 +- .../ps/core/communicator/tcp_communicator.h | 1 + .../ccsrc/ps/core/communicator/tcp_server.h | 1 + mindspore/ccsrc/ps/core/node.h | 4 +- mindspore/ccsrc/ps/core/node_manager.cc | 12 +-- mindspore/ccsrc/ps/core/scheduler_node.cc | 11 +-- mindspore/ccsrc/ps/core/scheduler_node.h | 4 +- mindspore/ccsrc/ps/core/server_node.cc | 1 - mindspore/ccsrc/ps/core/server_node.h | 4 +- mindspore/ccsrc/ps/core/worker_node.h | 4 +- mindspore/ccsrc/ps/ps_context.cc | 4 +- mindspore/ccsrc/ps/ps_context.h | 6 ++ mindspore/ccsrc/ps/scheduler.cc | 3 - .../ps/core/cluster_available_timeout_test.cc | 7 +- tests/ut/cpp/ps/core/cluster_metadata_test.cc | 18 ++--- tests/ut/cpp/ps/core/common_util_test.cc | 2 +- 22 files changed, 142 insertions(+), 167 deletions(-) create mode 100644 mindspore/ccsrc/ps/core/cluster_config.h delete mode 100644 mindspore/ccsrc/ps/core/cluster_metadata.cc diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 7e68c53aecb..4dbc374b343 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -294,7 +294,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) } else { UpdateSchedulerTime(); } - std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval())); + std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval)); } 
}); heart_beat_thread_->detach(); @@ -326,7 +326,7 @@ void AbstractNode::UpdateSchedulerTime() { bool AbstractNode::CheckSchedulerTimeout() const { struct timeval current_time {}; (void)gettimeofday(&current_time, nullptr); - if (scheduler_time_.tv_sec + ClusterMetadata::instance()->scheduler_timeout() < current_time.tv_sec) { + if (scheduler_time_.tv_sec + PSContext::instance()->cluster_config().scheduler_timeout < current_time.tv_sec) { return true; } return false; @@ -411,8 +411,8 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) { } bool AbstractNode::InitClientToScheduler() { - std::string scheduler_host = ClusterMetadata::instance()->scheduler_host(); - uint16_t scheduler_port = ClusterMetadata::instance()->scheduler_port(); + std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host; + uint16_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port; client_to_scheduler_ = std::make_shared(scheduler_host, scheduler_port); client_to_scheduler_->SetMessageCallback( [&](std::shared_ptr meta, const Protos &protos, const void *data, size_t size) { @@ -438,7 +438,7 @@ bool AbstractNode::InitClientToScheduler() { client_to_scheduler_thread_->detach(); client_to_scheduler_->set_disconnected_callback([&]() { - std::this_thread::sleep_for(std::chrono::milliseconds(ClusterMetadata::instance()->connect_interval())); + std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval)); if (is_ready_.load() == false) { client_to_scheduler_->Init(); } diff --git a/mindspore/ccsrc/ps/core/cluster_config.h b/mindspore/ccsrc/ps/core/cluster_config.h new file mode 100644 index 00000000000..ab2fc3a1b70 --- /dev/null +++ b/mindspore/ccsrc/ps/core/cluster_config.h @@ -0,0 +1,73 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ +#define MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ + +#include +#include +#include +#include + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +namespace core { +/* + * Configuration information read through environment variables and configuration files, generally immutable + */ +struct ClusterConfig { + ClusterConfig() + : initial_worker_num(0), + initial_server_num(0), + heartbeat_interval(3), + scheduler_host(""), + scheduler_port(0), + heartbeat_timeout(30), + cluster_available_timeout(300), + connect_interval(100), + scheduler_timeout(30) {} + + void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string host, const uint16_t &port) { + initial_worker_num = worker_num; + initial_server_num = server_num; + scheduler_host = host; + scheduler_port = port; + } + + // Configure through environment variables:MS_WORKER_NUM + uint32_t initial_worker_num; + // Configure through environment variables:MS_SERVER_NUM + uint32_t initial_server_num; + + // The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds. + uint32_t heartbeat_interval; + std::string scheduler_host; + uint16_t scheduler_port; + // The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds. + uint32_t heartbeat_timeout; + // Timeout period for cluster preparation is 300 seconds. + uint32_t cluster_available_timeout; + // The timeout period for the client to connect to the server is 100ms. 
+ uint32_t connect_interval; + // After the scheduler exits, the worker and server keep working for this many seconds (default 30). + uint32_t scheduler_timeout; +}; +} // namespace core +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ diff --git a/mindspore/ccsrc/ps/core/cluster_metadata.cc b/mindspore/ccsrc/ps/core/cluster_metadata.cc deleted file mode 100644 index 6d8bea0dc0f..00000000000 --- a/mindspore/ccsrc/ps/core/cluster_metadata.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "ps/core/cluster_metadata.h" -#include - -namespace mindspore { -namespace ps { -namespace core { -std::shared_ptr ClusterMetadata::instance() { - static std::shared_ptr metadata_instance = nullptr; - if (metadata_instance == nullptr) { - metadata_instance.reset(new (std::nothrow) ClusterMetadata()); - } - return metadata_instance; -} - -void ClusterMetadata::Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, - const uint16_t &scheduler_port) { - worker_num_ = worker_num; - server_num_ = server_num; - scheduler_host_ = std::make_unique(scheduler_host); - scheduler_port_ = scheduler_port; -} - -uint32_t ClusterMetadata::total_worker_num() { return worker_num_; } - -uint32_t ClusterMetadata::total_server_num() { return server_num_; } - -uint32_t ClusterMetadata::heartbeat_interval() { return heartbeat_interval_; } - -void ClusterMetadata::set_heartbeat_interval(const uint32_t &heartbeat_interval) { - heartbeat_interval_ = heartbeat_interval; -} - -std::string ClusterMetadata::scheduler_host() { - MS_EXCEPTION_IF_NULL(scheduler_host_); - return *scheduler_host_; -} - -uint16_t ClusterMetadata::scheduler_port() { return scheduler_port_; } - -uint32_t ClusterMetadata::heartbeat_timeout() { return heartbeat_timeout_; } - -void ClusterMetadata::set_heartbeat_timeout(const uint32_t &heartbeat_timeout) { - heartbeat_interval_ = heartbeat_timeout; -} - -uint32_t ClusterMetadata::cluster_available_timeout() { return cluster_available_timeout_; } - -void ClusterMetadata::set_cluster_available_timeout(const uint32_t &cluster_available_timeout) { - cluster_available_timeout_ = cluster_available_timeout; -} - -uint32_t ClusterMetadata::connect_interval() { return connect_interval_; } - -void ClusterMetadata::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; } - -uint32_t ClusterMetadata::scheduler_timeout() { return scheduler_timeout_; } - -void 
ClusterMetadata::set_scheduler_timeout(const uint32_t &scheduler_timeout) { - scheduler_timeout_ = scheduler_timeout; -} -} // namespace core -} // namespace ps -} // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/cluster_metadata.h b/mindspore/ccsrc/ps/core/cluster_metadata.h index 0fbf039a9db..85213a40a85 100644 --- a/mindspore/ccsrc/ps/core/cluster_metadata.h +++ b/mindspore/ccsrc/ps/core/cluster_metadata.h @@ -27,55 +27,19 @@ namespace mindspore { namespace ps { namespace core { -class ClusterMetadata { - public: - ~ClusterMetadata() = default; - ClusterMetadata(ClusterMetadata const &) = delete; - ClusterMetadata &operator=(const ClusterMetadata &) = delete; - static std::shared_ptr instance(); +/* + * The metadata information of the cluster, stored in the scheduler, is generally used for scale out and scale in. + */ +struct ClusterMetadata { + ClusterMetadata() : worker_num_(0), server_num_(0) {} - void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, - const uint16_t &scheduler_port); - uint32_t total_worker_num(); - uint32_t total_server_num(); - uint32_t heartbeat_interval(); - void set_heartbeat_interval(const uint32_t &heartbeat_interval); - std::string scheduler_host(); - uint16_t scheduler_port(); - uint32_t heartbeat_timeout(); - void set_heartbeat_timeout(const uint32_t &heartbeat_timeout); - uint32_t cluster_available_timeout(); - void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); - uint32_t connect_interval(); - void set_connect_interval(const uint32_t &connect_interval); - uint32_t scheduler_timeout(); - void set_scheduler_timeout(const uint32_t &scheduler_timeout); + void Init(const uint32_t &worker_num, const uint32_t &server_num) { + worker_num_ = worker_num; + server_num_ = server_num; + } - private: - ClusterMetadata() - : worker_num_(0), - server_num_(0), - heartbeat_interval_(3), - scheduler_host_(nullptr), - scheduler_port_(0), - heartbeat_timeout_(30), - 
cluster_available_timeout_(300), - connect_interval_(100), - scheduler_timeout_(30) {} uint32_t worker_num_; uint32_t server_num_; - // The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds. - uint32_t heartbeat_interval_; - std::unique_ptr scheduler_host_; - uint16_t scheduler_port_; - // The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds. - uint32_t heartbeat_timeout_; - // Timeout period for cluster preparation is 300 seconds. - uint32_t cluster_available_timeout_; - // The timeout period for the client to connect to the server is 100ms. - uint32_t connect_interval_; - // When the scheduler exits, the worker and server can continue to work for 5 hours - uint32_t scheduler_timeout_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index 8824d06a88d..7ede9b27a80 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -122,9 +122,10 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) { } } bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) { - if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->total_server_num() - 1)) { + if (node_role == NodeRole::SERVER && (rank_id > PSContext::instance()->cluster_config().initial_server_num - 1)) { return false; - } else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->total_worker_num() - 1)) { + } else if (node_role == NodeRole::WORKER && + (rank_id > PSContext::instance()->cluster_config().initial_worker_num - 1)) { return false; } return true; diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index 2a7e6feb229..bdd95839a72 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -50,6 +50,8 @@ #include "proto/comm.pb.h" #include 
"proto/ps.pb.h" #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" +#include "ps/ps_context.h" #include "utils/log_adapter.h" namespace mindspore { diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_client.h b/mindspore/ccsrc/ps/core/communicator/tcp_client.h index 84704b840cf..2a366badcf1 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_client.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_client.h @@ -34,6 +34,7 @@ #include #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" #include "utils/convert_utils_base.h" #include "ps/core/comm_util.h" #include "ps/core/communicator/ssl_wrapper.h" @@ -58,7 +59,8 @@ class TcpClient { std::string GetServerAddress() const; void set_disconnected_callback(const OnDisconnected &disconnected); void set_connected_callback(const OnConnected &connected); - bool WaitConnected(const uint32_t &connected_timeout = ClusterMetadata::instance()->cluster_available_timeout()); + bool WaitConnected( + const uint32_t &connected_timeout = PSContext::instance()->cluster_config().cluster_available_timeout); void Init(); void StartWithDelay(int seconds); void Stop(); diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h index 286fcbf59b6..6ce4344b3a4 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h @@ -25,6 +25,7 @@ #include "proto/ps.pb.h" #include "ps/core/server_node.h" #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" #include "ps/ps_context.h" #include "ps/core/communicator/task_executor.h" #include "ps/core/communicator/communicator_base.h" diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_server.h b/mindspore/ccsrc/ps/core/communicator/tcp_server.h index 0a0edf25da0..9977c89bf46 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_server.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_server.h @@ -38,6 
+38,7 @@ #include "ps/core/communicator/tcp_message_handler.h" #include "ps/core/communicator/ssl_wrapper.h" #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" #include "utils/convert_utils_base.h" #include "ps/core/comm_util.h" #include "ps/constants.h" diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h index d2fe18e0eae..338d95f9180 100644 --- a/mindspore/ccsrc/ps/core/node.h +++ b/mindspore/ccsrc/ps/core/node.h @@ -31,6 +31,8 @@ #include #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" +#include "ps/ps_context.h" #include "ps/core/node_info.h" #include "ps/core/communicator/tcp_client.h" #include "ps/core/communicator/tcp_server.h" @@ -54,7 +56,7 @@ class Node { using OnNodeEventMessage = std::function; using MessageCallback = std::function; - virtual bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) = 0; + virtual bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) = 0; virtual bool Stop() = 0; virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc index db4267d2465..08f621a0a25 100644 --- a/mindspore/ccsrc/ps/core/node_manager.cc +++ b/mindspore/ccsrc/ps/core/node_manager.cc @@ -20,7 +20,8 @@ namespace mindspore { namespace ps { namespace core { void NodeManager::InitNodeNum() { - total_node_num_ = ClusterMetadata::instance()->total_server_num() + ClusterMetadata::instance()->total_worker_num(); + total_node_num_ = PSContext::instance()->cluster_config().initial_server_num + + PSContext::instance()->cluster_config().initial_worker_num; } int NodeManager::NextRankId(const RegisterMessage &register_message) { @@ -39,7 +40,7 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) { uint32_t port = register_message.port(); rank_id = ++next_server_rank_id_; - if 
(IntToUint(rank_id) >= ClusterMetadata::instance()->total_server_num()) { + if (IntToUint(rank_id) >= PSContext::instance()->cluster_config().initial_server_num) { MS_LOG(WARNING) << "The rank id is greater than the number of servers."; rank_id = -1; --next_server_rank_id_; @@ -55,7 +56,7 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) { << " assign rank id:" << rank_id; } else if (register_message.role() == NodeRole::WORKER) { rank_id = ++next_worker_rank_id_; - if (IntToUint(rank_id) >= ClusterMetadata::instance()->total_worker_num()) { + if (IntToUint(rank_id) >= PSContext::instance()->cluster_config().initial_worker_num) { MS_LOG(WARNING) << "The rank id is greater than the number of workers."; rank_id = -1; --next_worker_rank_id_; @@ -104,7 +105,7 @@ void NodeManager::UpdateClusterState() { (void)gettimeofday(&current_time, nullptr); timeout_nodes_info_.clear(); for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) { - if (it->second.tv_sec + ClusterMetadata::instance()->heartbeat_timeout() < current_time.tv_sec) { + if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) { MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!"; timeout_nodes_info_[it->first] = nodes_info_[it->first]; } @@ -130,7 +131,8 @@ void NodeManager::UpdateClusterState() { void NodeManager::CheckClusterTimeout() { if (total_node_num_ != nodes_info_.size()) { - MS_LOG(WARNING) << "The cluster is not ready after " << ClusterMetadata::instance()->cluster_available_timeout() + MS_LOG(WARNING) << "The cluster is not ready after " + << PSContext::instance()->cluster_config().cluster_available_timeout << " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to " << nodes_info_.size(); current_node_num_ = nodes_info_.size(); diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index 926cb474caa..101454214ec 100644 --- 
a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -87,8 +87,8 @@ void SchedulerNode::InitCommandHandler() { void SchedulerNode::CreateTcpServer() { node_manager_.InitNodeNum(); - std::string scheduler_host = ClusterMetadata::instance()->scheduler_host(); - uint32_t scheduler_port = ClusterMetadata::instance()->scheduler_port(); + std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host; + uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port; server_ = std::make_shared(scheduler_host, scheduler_port); server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr meta, const Protos &protos, const void *data, size_t size) { @@ -168,19 +168,20 @@ void SchedulerNode::StartUpdateClusterStateTimer() { // 1. update cluster timeout if (!node_manager_.is_cluster_ready() && (std::chrono::steady_clock::now() - start_time > - std::chrono::seconds(ClusterMetadata::instance()->cluster_available_timeout()))) { + std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) { node_manager_.CheckClusterTimeout(); } // 2. 
update cluster state - std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval())); + std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval)); node_manager_.UpdateClusterState(); if (node_manager_.is_cluster_ready()) { is_ready_ = true; wait_start_cond_.notify_all(); } if (node_manager_.is_cluster_finish()) { - std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval() * 2)); + std::this_thread::sleep_for( + std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval * 2)); is_finish_ = true; wait_finish_cond_.notify_all(); } diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index ba18549fd55..cf0790db185 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -28,6 +28,8 @@ #include #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" +#include "ps/ps_context.h" #include "ps/core/communicator/tcp_client.h" #include "ps/core/communicator/tcp_server.h" #include "ps/core/node_manager.h" @@ -44,7 +46,7 @@ class SchedulerNode : public Node { typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr server, std::shared_ptr conn, std::shared_ptr meta, const void *data, size_t size); - bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override; + bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) override; bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index e83bc371e9b..7d5fecdeb90 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -147,7 +147,6 @@ std::shared_ptr ServerNode::GetOrCreateTcpComm(const std::stri MS_LOG(INFO) 
<< "Create Tcp communicator."; auto tcp_comm = std::make_shared(task_executor, this); MS_EXCEPTION_IF_NULL(tcp_comm); - ClusterMetadata::instance()->Init(worker_num, server_num, scheduler_ip, scheduler_port); MS_LOG(INFO) << "Initialize cluster metadata for server. Worker number:" << worker_num << ", Server number:" << server_num << ", Scheduler ip:" << scheduler_ip << ", Scheduler port:" << scheduler_port; diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h index a1cb8f0f17b..dad0d6ecfa9 100644 --- a/mindspore/ccsrc/ps/core/server_node.h +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -27,6 +27,8 @@ #include #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" +#include "ps/ps_context.h" #include "ps/core/communicator/tcp_client.h" #include "ps/core/communicator/tcp_server.h" #include "ps/core/abstract_node.h" @@ -45,7 +47,7 @@ class ServerNode : public AbstractNode { ~ServerNode() override = default; - bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override; + bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) override; bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h index 19490b79fd1..03b344aa2fd 100644 --- a/mindspore/ccsrc/ps/core/worker_node.h +++ b/mindspore/ccsrc/ps/core/worker_node.h @@ -25,6 +25,8 @@ #include #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" +#include "ps/ps_context.h" #include "ps/core/communicator/tcp_client.h" #include "ps/core/communicator/tcp_server.h" #include "ps/core/abstract_node.h" @@ -37,7 +39,7 @@ class WorkerNode : public AbstractNode { WorkerNode() = default; ~WorkerNode() override = default; - bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override; + bool 
Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) override; bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index fa45991eccf..89f91eb0201 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -52,7 +52,7 @@ void PSContext::SetPSEnable(bool enabled) { server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10); scheduler_host_ = common::GetEnv(kEnvSchedulerHost); scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10); - core::ClusterMetadata::instance()->Init(worker_num_, server_num_, scheduler_host_, scheduler_port_); + cluster_config_.Init(worker_num_, server_num_, scheduler_host_, scheduler_port_); } else { MS_LOG(INFO) << "PS mode is disabled."; is_worker_ = false; @@ -311,5 +311,7 @@ bool PSContext::secure_aggregation() const { return secure_aggregation_; } bool PSContext::enable_ssl() const { return enable_ssl_; } void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; } + +core::ClusterConfig &PSContext::cluster_config() { return cluster_config_; } } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 56cf14029bc..9e2cc634992 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -22,6 +22,7 @@ #include #include "ps/constants.h" #include "ps/core/cluster_metadata.h" +#include "ps/core/cluster_config.h" namespace mindspore { namespace ps { @@ -147,6 +148,8 @@ class PSContext { void set_secure_aggregation(bool secure_aggregation); bool secure_aggregation() const; + core::ClusterConfig &cluster_config(); + private: PSContext() : ps_enabled_(false), @@ -229,6 +232,9 @@ class PSContext { // Whether to use secure aggregation algorithm. Used in federated learning for now. 
bool secure_aggregation_; + + // The cluster config read through environment variables, the value does not change. + core::ClusterConfig cluster_config_; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/scheduler.cc b/mindspore/ccsrc/ps/scheduler.cc index 08726a66e58..acb1ca785cc 100755 --- a/mindspore/ccsrc/ps/scheduler.cc +++ b/mindspore/ccsrc/ps/scheduler.cc @@ -20,9 +20,6 @@ namespace mindspore { namespace ps { void Scheduler::Run() { MS_LOG(INFO) << "Start scheduler."; - core::ClusterMetadata::instance()->Init( - PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), - PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); scheduler_node_.Start(); scheduler_node_.Finish(); scheduler_node_.Stop(); diff --git a/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc b/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc index 4e66bd4c001..880eee9619e 100644 --- a/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc +++ b/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc @@ -31,12 +31,9 @@ class TestClusterAvailableTimeout : public UT::Common { }; TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) { - ClusterMetadata::instance()->Init(1, 1, "127.0.0.1", 9999); - ClusterMetadata::instance()->set_cluster_available_timeout(3); + PSContext::instance()->cluster_config().Init(1, 1, "127.0.0.1", 9999); + MS_LOG(INFO) << "The timeout is:" << PSContext::instance()->cluster_config().cluster_available_timeout; SchedulerNode node; - node.Start(); - node.Finish(); - node.Stop(); } } // namespace core } // namespace ps diff --git a/tests/ut/cpp/ps/core/cluster_metadata_test.cc b/tests/ut/cpp/ps/core/cluster_metadata_test.cc index 694093e2dff..671899272a8 100644 --- a/tests/ut/cpp/ps/core/cluster_metadata_test.cc +++ b/tests/ut/cpp/ps/core/cluster_metadata_test.cc @@ -19,26 +19,24 @@ #include "common/common_test.h" #include "ps/core/cluster_metadata.h" +#include 
"ps/core/cluster_config.h" +#include "ps/ps_context.h" namespace mindspore { namespace ps { namespace core { -class TestClusterMetadata : public UT::Common { +class TestClusterConfig : public UT::Common { public: - TestClusterMetadata() = default; - virtual ~TestClusterMetadata() = default; + TestClusterConfig() = default; + virtual ~TestClusterConfig() = default; void SetUp() override {} void TearDown() override {} }; -TEST_F(TestClusterMetadata, HeartbeatInterval) { - ClusterMetadata::instance()->Init(2, 2, "127.0.0.1", 8080); - EXPECT_TRUE(ClusterMetadata::instance()->heartbeat_interval() == 3); - ClusterMetadata::instance()->set_heartbeat_interval(100); - EXPECT_TRUE(ClusterMetadata::instance()->heartbeat_interval() == 100); - EXPECT_STREQ(ClusterMetadata::instance()->scheduler_host().c_str(), "127.0.0.1"); - EXPECT_TRUE(ClusterMetadata::instance()->scheduler_port() == 8080); +TEST_F(TestClusterConfig, HeartbeatInterval) { + PSContext::instance()->cluster_config().Init(2, 2, "127.0.0.1", 8080); + PSContext::instance()->cluster_config().heartbeat_interval = 100; } } // namespace core } // namespace ps diff --git a/tests/ut/cpp/ps/core/common_util_test.cc b/tests/ut/cpp/ps/core/common_util_test.cc index 262cabe53e2..b103aaf978f 100644 --- a/tests/ut/cpp/ps/core/common_util_test.cc +++ b/tests/ut/cpp/ps/core/common_util_test.cc @@ -53,7 +53,7 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { } TEST_F(TestCommUtil, ValidateRankId) { - ClusterMetadata::instance()->Init(3, 2, "127.0.0.1", 9999); + PSContext::instance()->cluster_config().Init(3, 2, "127.0.0.1", 9999); EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));