forked from mindspore-Ecosystem/mindspore
added cluster metadata
This commit is contained in:
parent
b281711030
commit
35f6d39127
|
@ -294,7 +294,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
|
|||
} else {
|
||||
UpdateSchedulerTime();
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval()));
|
||||
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
|
||||
}
|
||||
});
|
||||
heart_beat_thread_->detach();
|
||||
|
@ -326,7 +326,7 @@ void AbstractNode::UpdateSchedulerTime() {
|
|||
bool AbstractNode::CheckSchedulerTimeout() const {
|
||||
struct timeval current_time {};
|
||||
(void)gettimeofday(¤t_time, nullptr);
|
||||
if (scheduler_time_.tv_sec + ClusterMetadata::instance()->scheduler_timeout() < current_time.tv_sec) {
|
||||
if (scheduler_time_.tv_sec + PSContext::instance()->cluster_config().scheduler_timeout < current_time.tv_sec) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -411,8 +411,8 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
|
|||
}
|
||||
|
||||
bool AbstractNode::InitClientToScheduler() {
|
||||
std::string scheduler_host = ClusterMetadata::instance()->scheduler_host();
|
||||
uint16_t scheduler_port = ClusterMetadata::instance()->scheduler_port();
|
||||
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
|
||||
uint16_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
|
||||
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
|
||||
client_to_scheduler_->SetMessageCallback(
|
||||
[&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
|
@ -438,7 +438,7 @@ bool AbstractNode::InitClientToScheduler() {
|
|||
client_to_scheduler_thread_->detach();
|
||||
|
||||
client_to_scheduler_->set_disconnected_callback([&]() {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(ClusterMetadata::instance()->connect_interval()));
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval));
|
||||
if (is_ready_.load() == false) {
|
||||
client_to_scheduler_->Init();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
/*
|
||||
* Configuration information read through environment variables and configuration files, generally immutable
|
||||
*/
|
||||
struct ClusterConfig {
|
||||
ClusterConfig()
|
||||
: initial_worker_num(0),
|
||||
initial_server_num(0),
|
||||
heartbeat_interval(3),
|
||||
scheduler_host(""),
|
||||
scheduler_port(0),
|
||||
heartbeat_timeout(30),
|
||||
cluster_available_timeout(300),
|
||||
connect_interval(100),
|
||||
scheduler_timeout(30) {}
|
||||
|
||||
void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string host, const uint16_t &port) {
|
||||
initial_worker_num = worker_num;
|
||||
initial_server_num = server_num;
|
||||
scheduler_host = host;
|
||||
scheduler_port = port;
|
||||
}
|
||||
|
||||
// Configure through environment variables:MS_WORKER_NUM
|
||||
uint32_t initial_worker_num;
|
||||
// Configure through environment variables:MS_SERVER_NUM
|
||||
uint32_t initial_server_num;
|
||||
|
||||
// The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds.
|
||||
uint32_t heartbeat_interval;
|
||||
std::string scheduler_host;
|
||||
uint16_t scheduler_port;
|
||||
// The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds.
|
||||
uint32_t heartbeat_timeout;
|
||||
// Timeout period for cluster preparation is 300 seconds.
|
||||
uint32_t cluster_available_timeout;
|
||||
// The timeout period for the client to connect to the server is 100ms.
|
||||
uint32_t connect_interval;
|
||||
// When the scheduler exits, the worker and server can continue to work for 5 hours
|
||||
uint32_t scheduler_timeout;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_
|
|
@ -1,79 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
std::shared_ptr<ClusterMetadata> ClusterMetadata::instance() {
|
||||
static std::shared_ptr<ClusterMetadata> metadata_instance = nullptr;
|
||||
if (metadata_instance == nullptr) {
|
||||
metadata_instance.reset(new (std::nothrow) ClusterMetadata());
|
||||
}
|
||||
return metadata_instance;
|
||||
}
|
||||
|
||||
void ClusterMetadata::Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
|
||||
const uint16_t &scheduler_port) {
|
||||
worker_num_ = worker_num;
|
||||
server_num_ = server_num;
|
||||
scheduler_host_ = std::make_unique<std::string>(scheduler_host);
|
||||
scheduler_port_ = scheduler_port;
|
||||
}
|
||||
|
||||
uint32_t ClusterMetadata::total_worker_num() { return worker_num_; }
|
||||
|
||||
uint32_t ClusterMetadata::total_server_num() { return server_num_; }
|
||||
|
||||
uint32_t ClusterMetadata::heartbeat_interval() { return heartbeat_interval_; }
|
||||
|
||||
void ClusterMetadata::set_heartbeat_interval(const uint32_t &heartbeat_interval) {
|
||||
heartbeat_interval_ = heartbeat_interval;
|
||||
}
|
||||
|
||||
std::string ClusterMetadata::scheduler_host() {
|
||||
MS_EXCEPTION_IF_NULL(scheduler_host_);
|
||||
return *scheduler_host_;
|
||||
}
|
||||
|
||||
uint16_t ClusterMetadata::scheduler_port() { return scheduler_port_; }
|
||||
|
||||
uint32_t ClusterMetadata::heartbeat_timeout() { return heartbeat_timeout_; }
|
||||
|
||||
void ClusterMetadata::set_heartbeat_timeout(const uint32_t &heartbeat_timeout) {
|
||||
heartbeat_interval_ = heartbeat_timeout;
|
||||
}
|
||||
|
||||
uint32_t ClusterMetadata::cluster_available_timeout() { return cluster_available_timeout_; }
|
||||
|
||||
void ClusterMetadata::set_cluster_available_timeout(const uint32_t &cluster_available_timeout) {
|
||||
cluster_available_timeout_ = cluster_available_timeout;
|
||||
}
|
||||
|
||||
uint32_t ClusterMetadata::connect_interval() { return connect_interval_; }
|
||||
|
||||
void ClusterMetadata::set_connect_interval(const uint32_t &connect_interval) { connect_interval_ = connect_interval; }
|
||||
|
||||
uint32_t ClusterMetadata::scheduler_timeout() { return scheduler_timeout_; }
|
||||
|
||||
void ClusterMetadata::set_scheduler_timeout(const uint32_t &scheduler_timeout) {
|
||||
scheduler_timeout_ = scheduler_timeout;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -27,55 +27,19 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class ClusterMetadata {
|
||||
public:
|
||||
~ClusterMetadata() = default;
|
||||
ClusterMetadata(ClusterMetadata const &) = delete;
|
||||
ClusterMetadata &operator=(const ClusterMetadata &) = delete;
|
||||
static std::shared_ptr<ClusterMetadata> instance();
|
||||
/*
|
||||
* The metadata information of the cluster, stored in the scheduler, is generally used for scale out and scale in.
|
||||
*/
|
||||
struct ClusterMetadata {
|
||||
ClusterMetadata() : worker_num_(0), server_num_(0) {}
|
||||
|
||||
void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
|
||||
const uint16_t &scheduler_port);
|
||||
uint32_t total_worker_num();
|
||||
uint32_t total_server_num();
|
||||
uint32_t heartbeat_interval();
|
||||
void set_heartbeat_interval(const uint32_t &heartbeat_interval);
|
||||
std::string scheduler_host();
|
||||
uint16_t scheduler_port();
|
||||
uint32_t heartbeat_timeout();
|
||||
void set_heartbeat_timeout(const uint32_t &heartbeat_timeout);
|
||||
uint32_t cluster_available_timeout();
|
||||
void set_cluster_available_timeout(const uint32_t &cluster_available_timeout);
|
||||
uint32_t connect_interval();
|
||||
void set_connect_interval(const uint32_t &connect_interval);
|
||||
uint32_t scheduler_timeout();
|
||||
void set_scheduler_timeout(const uint32_t &scheduler_timeout);
|
||||
void Init(const uint32_t &worker_num, const uint32_t &server_num) {
|
||||
worker_num_ = worker_num;
|
||||
server_num_ = server_num;
|
||||
}
|
||||
|
||||
private:
|
||||
ClusterMetadata()
|
||||
: worker_num_(0),
|
||||
server_num_(0),
|
||||
heartbeat_interval_(3),
|
||||
scheduler_host_(nullptr),
|
||||
scheduler_port_(0),
|
||||
heartbeat_timeout_(30),
|
||||
cluster_available_timeout_(300),
|
||||
connect_interval_(100),
|
||||
scheduler_timeout_(30) {}
|
||||
uint32_t worker_num_;
|
||||
uint32_t server_num_;
|
||||
// The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds.
|
||||
uint32_t heartbeat_interval_;
|
||||
std::unique_ptr<std::string> scheduler_host_;
|
||||
uint16_t scheduler_port_;
|
||||
// The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds.
|
||||
uint32_t heartbeat_timeout_;
|
||||
// Timeout period for cluster preparation is 300 seconds.
|
||||
uint32_t cluster_available_timeout_;
|
||||
// The timeout period for the client to connect to the server is 100ms.
|
||||
uint32_t connect_interval_;
|
||||
// When the scheduler exits, the worker and server can continue to work for 5 hours
|
||||
uint32_t scheduler_timeout_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -122,9 +122,10 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
|
|||
}
|
||||
}
|
||||
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
|
||||
if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->total_server_num() - 1)) {
|
||||
if (node_role == NodeRole::SERVER && (rank_id > PSContext::instance()->cluster_config().initial_server_num - 1)) {
|
||||
return false;
|
||||
} else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->total_worker_num() - 1)) {
|
||||
} else if (node_role == NodeRole::WORKER &&
|
||||
(rank_id > PSContext::instance()->cluster_config().initial_worker_num - 1)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -50,6 +50,8 @@
|
|||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include <condition_variable>
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
#include "ps/core/communicator/ssl_wrapper.h"
|
||||
|
@ -58,7 +59,8 @@ class TcpClient {
|
|||
std::string GetServerAddress() const;
|
||||
void set_disconnected_callback(const OnDisconnected &disconnected);
|
||||
void set_connected_callback(const OnConnected &connected);
|
||||
bool WaitConnected(const uint32_t &connected_timeout = ClusterMetadata::instance()->cluster_available_timeout());
|
||||
bool WaitConnected(
|
||||
const uint32_t &connected_timeout = PSContext::instance()->cluster_config().cluster_available_timeout);
|
||||
void Init();
|
||||
void StartWithDelay(int seconds);
|
||||
void Stop();
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/server_node.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/task_executor.h"
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "ps/core/communicator/tcp_message_handler.h"
|
||||
#include "ps/core/communicator/ssl_wrapper.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
#include "ps/constants.h"
|
||||
|
|
|
@ -31,6 +31,8 @@
|
|||
#include <tuple>
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/node_info.h"
|
||||
#include "ps/core/communicator/tcp_client.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
|
@ -54,7 +56,7 @@ class Node {
|
|||
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
|
||||
using MessageCallback = std::function<void()>;
|
||||
|
||||
virtual bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) = 0;
|
||||
virtual bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) = 0;
|
||||
virtual bool Stop() = 0;
|
||||
virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0;
|
||||
|
||||
|
|
|
@ -20,7 +20,8 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace core {
|
||||
void NodeManager::InitNodeNum() {
|
||||
total_node_num_ = ClusterMetadata::instance()->total_server_num() + ClusterMetadata::instance()->total_worker_num();
|
||||
total_node_num_ = PSContext::instance()->cluster_config().initial_server_num +
|
||||
PSContext::instance()->cluster_config().initial_worker_num;
|
||||
}
|
||||
|
||||
int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
||||
|
@ -39,7 +40,7 @@ int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
|||
uint32_t port = register_message.port();
|
||||
|
||||
rank_id = ++next_server_rank_id_;
|
||||
if (IntToUint(rank_id) >= ClusterMetadata::instance()->total_server_num()) {
|
||||
if (IntToUint(rank_id) >= PSContext::instance()->cluster_config().initial_server_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of servers.";
|
||||
rank_id = -1;
|
||||
--next_server_rank_id_;
|
||||
|
@ -55,7 +56,7 @@ int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
|||
<< " assign rank id:" << rank_id;
|
||||
} else if (register_message.role() == NodeRole::WORKER) {
|
||||
rank_id = ++next_worker_rank_id_;
|
||||
if (IntToUint(rank_id) >= ClusterMetadata::instance()->total_worker_num()) {
|
||||
if (IntToUint(rank_id) >= PSContext::instance()->cluster_config().initial_worker_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of workers.";
|
||||
rank_id = -1;
|
||||
--next_worker_rank_id_;
|
||||
|
@ -104,7 +105,7 @@ void NodeManager::UpdateClusterState() {
|
|||
(void)gettimeofday(¤t_time, nullptr);
|
||||
timeout_nodes_info_.clear();
|
||||
for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
|
||||
if (it->second.tv_sec + ClusterMetadata::instance()->heartbeat_timeout() < current_time.tv_sec) {
|
||||
if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
|
||||
MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
|
||||
timeout_nodes_info_[it->first] = nodes_info_[it->first];
|
||||
}
|
||||
|
@ -130,7 +131,8 @@ void NodeManager::UpdateClusterState() {
|
|||
|
||||
void NodeManager::CheckClusterTimeout() {
|
||||
if (total_node_num_ != nodes_info_.size()) {
|
||||
MS_LOG(WARNING) << "The cluster is not ready after " << ClusterMetadata::instance()->cluster_available_timeout()
|
||||
MS_LOG(WARNING) << "The cluster is not ready after "
|
||||
<< PSContext::instance()->cluster_config().cluster_available_timeout
|
||||
<< " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
|
||||
<< nodes_info_.size();
|
||||
current_node_num_ = nodes_info_.size();
|
||||
|
|
|
@ -87,8 +87,8 @@ void SchedulerNode::InitCommandHandler() {
|
|||
void SchedulerNode::CreateTcpServer() {
|
||||
node_manager_.InitNodeNum();
|
||||
|
||||
std::string scheduler_host = ClusterMetadata::instance()->scheduler_host();
|
||||
uint32_t scheduler_port = ClusterMetadata::instance()->scheduler_port();
|
||||
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
|
||||
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
|
||||
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
|
@ -168,19 +168,20 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
|
|||
// 1. update cluster timeout
|
||||
if (!node_manager_.is_cluster_ready() &&
|
||||
(std::chrono::steady_clock::now() - start_time >
|
||||
std::chrono::seconds(ClusterMetadata::instance()->cluster_available_timeout()))) {
|
||||
std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) {
|
||||
node_manager_.CheckClusterTimeout();
|
||||
}
|
||||
|
||||
// 2. update cluster state
|
||||
std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval()));
|
||||
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
|
||||
node_manager_.UpdateClusterState();
|
||||
if (node_manager_.is_cluster_ready()) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
}
|
||||
if (node_manager_.is_cluster_finish()) {
|
||||
std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval() * 2));
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval * 2));
|
||||
is_finish_ = true;
|
||||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
|
|
|
@ -28,6 +28,8 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/tcp_client.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
#include "ps/core/node_manager.h"
|
||||
|
@ -44,7 +46,7 @@ class SchedulerNode : public Node {
|
|||
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
||||
bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override;
|
||||
bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) override;
|
||||
bool Stop() override;
|
||||
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
|
||||
|
|
|
@ -147,7 +147,6 @@ std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateTcpComm(const std::stri
|
|||
MS_LOG(INFO) << "Create Tcp communicator.";
|
||||
auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this);
|
||||
MS_EXCEPTION_IF_NULL(tcp_comm);
|
||||
ClusterMetadata::instance()->Init(worker_num, server_num, scheduler_ip, scheduler_port);
|
||||
MS_LOG(INFO) << "Initialize cluster metadata for server. Worker number:" << worker_num
|
||||
<< ", Server number:" << server_num << ", Scheduler ip:" << scheduler_ip
|
||||
<< ", Scheduler port:" << scheduler_port;
|
||||
|
|
|
@ -27,6 +27,8 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/tcp_client.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
#include "ps/core/abstract_node.h"
|
||||
|
@ -45,7 +47,7 @@ class ServerNode : public AbstractNode {
|
|||
|
||||
~ServerNode() override = default;
|
||||
|
||||
bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override;
|
||||
bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) override;
|
||||
bool Stop() override;
|
||||
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
|
||||
|
|
|
@ -25,6 +25,8 @@
|
|||
#include <algorithm>
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/tcp_client.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
#include "ps/core/abstract_node.h"
|
||||
|
@ -37,7 +39,7 @@ class WorkerNode : public AbstractNode {
|
|||
WorkerNode() = default;
|
||||
~WorkerNode() override = default;
|
||||
|
||||
bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override;
|
||||
bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) override;
|
||||
bool Stop() override;
|
||||
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ void PSContext::SetPSEnable(bool enabled) {
|
|||
server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
|
||||
scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
|
||||
scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10);
|
||||
core::ClusterMetadata::instance()->Init(worker_num_, server_num_, scheduler_host_, scheduler_port_);
|
||||
cluster_config_.Init(worker_num_, server_num_, scheduler_host_, scheduler_port_);
|
||||
} else {
|
||||
MS_LOG(INFO) << "PS mode is disabled.";
|
||||
is_worker_ = false;
|
||||
|
@ -311,5 +311,7 @@ bool PSContext::secure_aggregation() const { return secure_aggregation_; }
|
|||
bool PSContext::enable_ssl() const { return enable_ssl_; }
|
||||
|
||||
void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
|
||||
|
||||
core::ClusterConfig &PSContext::cluster_config() { return cluster_config_; }
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <memory>
|
||||
#include "ps/constants.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -147,6 +148,8 @@ class PSContext {
|
|||
void set_secure_aggregation(bool secure_aggregation);
|
||||
bool secure_aggregation() const;
|
||||
|
||||
core::ClusterConfig &cluster_config();
|
||||
|
||||
private:
|
||||
PSContext()
|
||||
: ps_enabled_(false),
|
||||
|
@ -229,6 +232,9 @@ class PSContext {
|
|||
|
||||
// Whether to use secure aggregation algorithm. Used in federated learning for now.
|
||||
bool secure_aggregation_;
|
||||
|
||||
// The cluster config read through environment variables, the value does not change.
|
||||
core::ClusterConfig cluster_config_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,9 +20,6 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
void Scheduler::Run() {
|
||||
MS_LOG(INFO) << "Start scheduler.";
|
||||
core::ClusterMetadata::instance()->Init(
|
||||
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
|
||||
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
|
||||
scheduler_node_.Start();
|
||||
scheduler_node_.Finish();
|
||||
scheduler_node_.Stop();
|
||||
|
|
|
@ -31,12 +31,9 @@ class TestClusterAvailableTimeout : public UT::Common {
|
|||
};
|
||||
|
||||
TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) {
|
||||
ClusterMetadata::instance()->Init(1, 1, "127.0.0.1", 9999);
|
||||
ClusterMetadata::instance()->set_cluster_available_timeout(3);
|
||||
PSContext::instance()->cluster_config().Init(1, 1, "127.0.0.1", 9999);
|
||||
MS_LOG(INFO) << "The timeout is:" << PSContext::instance()->cluster_config().cluster_available_timeout;
|
||||
SchedulerNode node;
|
||||
node.Start();
|
||||
node.Finish();
|
||||
node.Stop();
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -19,26 +19,24 @@
|
|||
|
||||
#include "common/common_test.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class TestClusterMetadata : public UT::Common {
|
||||
class TestClusterConfig : public UT::Common {
|
||||
public:
|
||||
TestClusterMetadata() = default;
|
||||
virtual ~TestClusterMetadata() = default;
|
||||
TestClusterConfig() = default;
|
||||
virtual ~TestClusterConfig() = default;
|
||||
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(TestClusterMetadata, HeartbeatInterval) {
|
||||
ClusterMetadata::instance()->Init(2, 2, "127.0.0.1", 8080);
|
||||
EXPECT_TRUE(ClusterMetadata::instance()->heartbeat_interval() == 3);
|
||||
ClusterMetadata::instance()->set_heartbeat_interval(100);
|
||||
EXPECT_TRUE(ClusterMetadata::instance()->heartbeat_interval() == 100);
|
||||
EXPECT_STREQ(ClusterMetadata::instance()->scheduler_host().c_str(), "127.0.0.1");
|
||||
EXPECT_TRUE(ClusterMetadata::instance()->scheduler_port() == 8080);
|
||||
TEST_F(TestClusterConfig, HeartbeatInterval) {
|
||||
PSContext::instance()->cluster_config().Init(2, 2, "127.0.0.1", 8080);
|
||||
PSContext::instance()->cluster_config().heartbeat_interval = 100;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -53,7 +53,7 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
|
|||
}
|
||||
|
||||
TEST_F(TestCommUtil, ValidateRankId) {
|
||||
ClusterMetadata::instance()->Init(3, 2, "127.0.0.1", 9999);
|
||||
PSContext::instance()->cluster_config().Init(3, 2, "127.0.0.1", 9999);
|
||||
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2));
|
||||
EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3));
|
||||
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));
|
||||
|
|
Loading…
Reference in New Issue