!17052 added scale out

From: @anancds
Reviewed-by: @cristoval, @limingqi107
Signed-off-by:
mindspore-ci-bot 2021-05-28 15:05:39 +08:00 committed by Gitee
commit 40ca285ab3
15 changed files with 145 additions and 38 deletions

View File

@ -359,6 +359,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
.def("set_client_learning_rate", &PSContext::set_client_learning_rate,
"Set federated learning client learning rate.")
.def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
"Set scheduler manage port used to scale out/in.")
.def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")

View File

@ -35,6 +35,7 @@ constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
constexpr char kEnvSchedulerManagePort[] = "MS_SCHED_MANAGE_PORT";
constexpr char kCommTypeOfIBVerbs[] = "ibverbs";
constexpr char kRoleOfPServer[] = "server";

View File

@ -90,6 +90,16 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &mess
void AbstractNode::set_event_callback(const OnNodeEventMessage &event) { on_node_event_message_ = event; }
void AbstractNode::set_ready_for_scale_out() {
Register(client_to_scheduler_);
connected_nodes_.clear();
}
void AbstractNode::set_ready_for_scale_in() {
Register(client_to_scheduler_);
connected_nodes_.clear();
}
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
int command, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(data);
@ -267,6 +277,10 @@ bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, cons
return res;
}
int32_t AbstractNode::worker_num() const { return worker_num_; }
int32_t AbstractNode::server_num() const { return server_num_; }
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
@ -375,7 +389,6 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
SendMetadataMessage send_meta_message;
send_meta_message.ParseFromArray(data, size);
nodes_address_.clear();
MS_LOG(ERROR) << "send metadata size:" << send_meta_message.servers_meta().size();
for (const auto &it : send_meta_message.servers_meta()) {
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
@ -383,6 +396,16 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
server_->SendMessage(conn, meta, Protos::RAW, data, size);
is_ready_ = true;
wait_start_cond_.notify_all();
if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT) {
MS_LOG(WARNING) << "Trigger cluster scale out done event.";
on_node_event_message_(ClusterEvent::CLUSTER_SCALE_OUT_DONE);
}
if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN) {
MS_LOG(WARNING) << "Trigger cluster scale in done event.";
on_node_event_message_(ClusterEvent::CLUSTER_SCALE_IN_DONE);
}
current_cluster_state_ = ClusterState::CLUSTER_READY;
}
void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
@ -395,6 +418,28 @@ void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::share
wait_finish_cond_.notify_all();
}
void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
server_->SendMessage(conn, meta, Protos::RAW, data, size);
on_node_event_message_(ClusterEvent::READY_FOR_SCALE_OUT);
current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT;
is_ready_ = false;
}
void AbstractNode::ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
server_->SendMessage(conn, meta, Protos::RAW, data, size);
on_node_event_message_(ClusterEvent::READY_FOR_SCALE_IN);
current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
is_ready_ = false;
}
bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::FINISH);
@ -613,9 +658,10 @@ void AbstractNode::InitServerHandler() {
server_handler_[NodeCommand::FINISH] = &AbstractNode::ProcessFinish;
server_handler_[NodeCommand::SEND_DATA] = nullptr;
server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = nullptr;
server_handler_[NodeCommand::SCALE_OUT] = &AbstractNode::ProcessScaleOut;
}
void AbstractNode::InitNode(const NodeRole &role) {
void AbstractNode::InitNodeInfo(const NodeRole &role) {
node_info_.node_id_ = CommUtil::GenerateUUID();
node_info_.node_role_ = role;
node_info_.ip_ = server_->BoundIp();
@ -624,6 +670,11 @@ void AbstractNode::InitNode(const NodeRole &role) {
<< " is generate uuid is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
<< ", the port:" << server_->BoundPort();
}
void AbstractNode::InitNodeNum() {
worker_num_ = PSContext::instance()->cluster_config().initial_worker_num;
server_num_ = PSContext::instance()->cluster_config().initial_server_num;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -38,7 +38,9 @@ class AbstractNode : public Node {
client_to_scheduler_thread_(nullptr),
client_to_scheduler_(nullptr),
server_(nullptr),
server_thread_(nullptr) {}
server_thread_(nullptr),
worker_num_(-1),
server_num_(-1) {}
~AbstractNode() override = default;
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
@ -52,6 +54,10 @@ class AbstractNode : public Node {
const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &event);
// When the business layer finishes scale out, it should call this function
void set_ready_for_scale_out();
// When the business layer finishes scale in, it should call this function
void set_ready_for_scale_in();
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
const uint32_t &timeout = kCommTimeoutInSeconds);
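The two comments above state that the business layer must call set_ready_for_scale_out/in once its own migration work is done, but the call site is outside this commit. A minimal, hedged sketch of how a node built on AbstractNode might wire this through the event callback; the lambda body and the concrete node object are assumptions, not part of this change:

// Hedged sketch: reacting to the scale events introduced in this commit.
// `node` is assumed to be a concrete AbstractNode subclass (e.g. ServerNode).
node.set_event_callback([&node](const ClusterEvent &event) {
  if (event == ClusterEvent::READY_FOR_SCALE_OUT) {
    // The business layer re-partitions its state here (assumption), then reports
    // readiness, which re-registers the node with the scheduler.
    node.set_ready_for_scale_out();
  } else if (event == ClusterEvent::READY_FOR_SCALE_IN) {
    node.set_ready_for_scale_in();
  } else if (event == ClusterEvent::CLUSTER_SCALE_OUT_DONE ||
             event == ClusterEvent::CLUSTER_SCALE_IN_DONE) {
    // New metadata has been delivered to every node; training can resume.
  }
});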
@ -68,6 +74,9 @@ class AbstractNode : public Node {
VectorPtr *output);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
int32_t worker_num() const;
int32_t server_num() const;
protected:
void Register(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
@ -82,6 +91,12 @@ class AbstractNode : public Node {
void ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void UpdateSchedulerTime();
bool CheckSchedulerTimeout() const;
@ -98,7 +113,11 @@ class AbstractNode : public Node {
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
void InitCommandHandler();
void InitServerHandler();
void InitNode(const NodeRole &role);
// When initializing the node, the node info should be initialized.
void InitNodeInfo(const NodeRole &role);
// Initialize worker num and server num by cluster config.
void InitNodeNum();
std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_;
@ -136,11 +155,12 @@ class AbstractNode : public Node {
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
std::unordered_map<NodeCommand, ServerHandler> server_handler_;
std::unordered_map<ClusterEvent, bool> is_event_send_;
std::mutex is_event_send_mutex_;
// Workers and servers launch the server to process commands: FINISH, SCALE_OUT, SCALE_IN, SEND_METADATA
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;
int32_t worker_num_;
int32_t server_num_;
};
} // namespace core
} // namespace ps

View File

@ -47,11 +47,6 @@ class Node {
Node()
: is_ready_(false),
is_finish_(false),
is_node_ready_scale_out_(false),
is_node_ready_scale_in_(false),
is_cluster_ready_scale_out_(false),
is_cluster_ready_scale_in_(false),
update_local_servers_(false),
is_already_stopped_(true),
is_already_finished_(false),
next_request_id_(0),
@ -93,15 +88,6 @@ class Node {
std::atomic<bool> is_ready_;
std::atomic<bool> is_finish_;
std::atomic<bool> is_node_ready_scale_out_;
std::atomic<bool> is_node_ready_scale_in_;
std::atomic<bool> is_cluster_ready_scale_out_;
std::atomic<bool> is_cluster_ready_scale_in_;
// Determine whether to update the ip and port of the locally cached servers.
std::atomic<bool> update_local_servers_;
std::atomic<bool> is_already_stopped_;
std::atomic<bool> is_already_finished_;
std::atomic_uint64_t next_request_id_;
@ -117,6 +103,7 @@ class Node {
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;
// Worker and server receive the node state and cluster state from the scheduler.
NodeState current_node_state_;
ClusterState current_cluster_state_;
};

View File

@ -80,12 +80,9 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {
void NodeManager::UpdateHeartbeat(const std::string &node_id) {
std::lock_guard<std::mutex> lock(heartbeat_mutex_);
NodeInfo node_info = nodes_info_[node_id];
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
heartbeats_[node_id] = current_time;
MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id
<< ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
}
void NodeManager::UpdateNodeScaleInState(const std::string &node_id) { heartbeats_scale_in_nodes_.insert(node_id); }
@ -116,7 +113,9 @@ void NodeManager::UpdateCluster() {
for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
timeout_nodes_info_[it->first] = nodes_info_[it->first];
if (nodes_info_.count(it->first)) {
timeout_nodes_info_[it->first] = nodes_info_[it->first];
}
}
}
if (!timeout_nodes_info_.empty()) {
@ -146,9 +145,9 @@ void NodeManager::CheckClusterTimeout() {
void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }
bool NodeManager::CheckRegisterNum() { return SizeToInt(nodes_info_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesRegistered() { return SizeToInt(nodes_info_.size()) == total_node_num_; }
bool NodeManager::CheckFinishNum() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesFinished() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() { return nodes_info_; }
@ -172,9 +171,24 @@ ClusterState NodeManager::GetClusterState() {
return cluster_state_;
}
void NodeManager::ResetMetadata() {
MS_LOG(WARNING) << "Reset metadata.";
nodes_info_.clear();
next_worker_rank_id_ = -1;
next_server_rank_id_ = -1;
}
void NodeManager::set_total_node_num(const int32_t &node_num) { total_node_num_ = node_num; }
const int32_t &NodeManager::total_node_num() { return total_node_num_; }
void NodeManager::set_worker_num(const int32_t &worker_num) { meta_data_->worker_num = worker_num; }
void NodeManager::set_server_num(const int32_t &server_num) { meta_data_->server_num = server_num; }
int32_t NodeManager::worker_num() { return UintToInt(meta_data_->worker_num); }
int32_t NodeManager::server_num() { return UintToInt(meta_data_->server_num); }
} // namespace core
} // namespace ps
} // namespace mindspore
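For context, a hedged sketch of how the scheduler side could drive this new NodeManager interface when a scale-out request arrives; the request handler and the node counts are hypothetical, only the NodeManager calls come from this commit:

// Hedged sketch (scheduler side): reset metadata and publish the new cluster size
// before all nodes re-register. The function and its arguments are hypothetical.
void OnScaleOutRequest(NodeManager *node_manager, int32_t new_worker_num, int32_t new_server_num) {
  node_manager->ResetMetadata();  // clears nodes_info_ so every node re-registers
  node_manager->set_worker_num(new_worker_num);
  node_manager->set_server_num(new_server_num);
  node_manager->set_total_node_num(new_worker_num + new_server_num);
  node_manager->UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
}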

View File

@ -68,21 +68,29 @@ class NodeManager {
// When workers and servers register with the scheduler, the scheduler collects the number of registered
// nodes and determines whether the number of registered workers and servers is equal to total_node_num_.
bool CheckRegisterNum();
bool IsAllNodesRegistered();
// When workers and servers send a finish message to the scheduler, the scheduler collects the number of
// finished nodes and determines whether it is equal to total_node_num_.
bool CheckFinishNum();
bool IsAllNodesFinished();
std::unordered_map<std::string, NodeInfo> &nodes_info();
void set_total_node_num(const int32_t &node_num);
const int32_t &total_node_num();
void set_worker_num(const int32_t &worker_num);
void set_server_num(const int32_t &server_num);
int32_t worker_num();
int32_t server_num();
void UpdateNodeState(const NodeState &state);
void UpdateClusterState(const ClusterState &state);
NodeState GetNodeState();
ClusterState GetClusterState();
// When the scheduler receives the scale out or scale in message, the metadata needs to be reset, because all nodes
// will re-register.
void ResetMetadata();
private:
std::mutex node_mutex_;
std::mutex cluster_mutex_;
@ -107,6 +115,7 @@ class NodeManager {
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
std::unordered_set<std::string> finish_nodes_id_;
// Cluster metadata information can be dynamically changed
std::unique_ptr<ClusterMetadata> meta_data_;
NodeState node_state_;

View File

@ -29,6 +29,8 @@ enum NodeCommand {
COLLECTIVE_SEND_DATA = 6;
// The scheduler actively sends metadata to the worker and server
SEND_METADATA = 7;
SCALE_OUT = 8;
SCALE_IN = 9;
}
enum NodeRole {
@ -120,3 +122,11 @@ message CommMessage {
MessageMeta pb_meta = 1;
bytes data = 2;
}
// The scheduler will broadcast the worker/server numbers after scale out/in to all nodes.
message ScaleOutMessage {
// the worker number after scale out/in
int32 worker_num = 1;
// the server number after scale out/in
int32 server_num = 2;
}
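A hedged sketch of filling this message with the generated protobuf API; the numbers are placeholders and the broadcast step is an assumption:

// Hedged sketch: building the scale out/in payload on the scheduler side.
ScaleOutMessage scale_out_message;
scale_out_message.set_worker_num(8);  // placeholder worker number after scaling
scale_out_message.set_server_num(4);  // placeholder server number after scaling
std::string payload = scale_out_message.SerializeAsString();
// `payload` would be sent with NodeCommand::SCALE_OUT or SCALE_IN (assumption).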

View File

@ -120,7 +120,7 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
register_resp_message.ByteSizeLong());
if (node_manager_.CheckRegisterNum()) {
if (node_manager_.IsAllNodesRegistered()) {
is_ready_ = true;
auto node_infos = node_manager_.nodes_info();
for (const auto &kvs : node_infos) {
@ -142,7 +142,7 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
node_manager_.AddFinishNode(*finish_message);
MS_LOG(INFO) << "Process finish message from node id:" << *finish_message;
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
if (node_manager_.CheckFinishNum()) {
if (node_manager_.IsAllNodesFinished()) {
auto node_infos = node_manager_.nodes_info();
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
@ -177,8 +177,6 @@ void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client) {
SendMetadataMessage send_metadata_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
MS_LOG(ERROR) << "the list size:" << servers_meta_list.size();
*send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
if (!SendMessageAsync(client, message_meta, Protos::PROTOBUF, send_metadata_message.SerializeAsString().data(),
@ -238,7 +236,6 @@ const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInf
std::string ip = node_info.ip_;
uint16_t port = node_info.port_;
auto client = std::make_shared<TcpClient>(ip, port);
MS_LOG(ERROR) << "the ip:" << node_info.ip_ << ", the port:" << node_info.port_;
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) { NotifyMessageArrival(meta); });
client->Init();

View File

@ -89,6 +89,7 @@ class SchedulerNode : public Node {
NodeManager node_manager_;
// This thread will start an HTTP server.
std::unique_ptr<std::thread> restful_thread_;
std::shared_ptr<HttpServer> http_server_;

View File

@ -83,7 +83,8 @@ void ServerNode::Initialize() {
InitServerHandler();
CreateTcpServer();
is_already_stopped_ = false;
InitNode(NodeRole::SERVER);
InitNodeInfo(NodeRole::SERVER);
InitNodeNum();
InitCommandHandler();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Server node init client timeout!";

View File

@ -41,7 +41,8 @@ void WorkerNode::Initialize() {
is_already_stopped_ = false;
InitServerHandler();
CreateTcpServer();
InitNode(NodeRole::WORKER);
InitNodeInfo(NodeRole::WORKER);
InitNodeNum();
InitCommandHandler();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Worker node init client timeout!";

View File

@ -52,6 +52,7 @@ void PSContext::SetPSEnable(bool enabled) {
server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10);
scheduler_manage_port_ = std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, 10);
cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
} else {
MS_LOG(INFO) << "PS mode is disabled.";
@ -312,5 +313,9 @@ core::ClusterConfig &PSContext::cluster_config() {
}
return *cluster_config_;
}
void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
} // namespace ps
} // namespace mindspore
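A hedged sketch of the two ways the manage port can reach PSContext: via the MS_SCHED_MANAGE_PORT environment variable consumed by SetPSEnable, or via the setter that the Python-facing code in this commit calls; the port value is a placeholder:

// Hedged sketch (POSIX setenv; the port number is a placeholder).
setenv("MS_SCHED_MANAGE_PORT", "11202", 1);  // read by SetPSEnable through kEnvSchedulerManagePort
PSContext::instance()->SetPSEnable(true);
// Or set it directly, which is what set_ps_context(scheduler_manage_port=...) ends up calling:
PSContext::instance()->set_scheduler_manage_port(11202);
uint16_t manage_port = PSContext::instance()->scheduler_manage_port();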

View File

@ -145,6 +145,9 @@ class PSContext {
core::ClusterConfig &cluster_config();
void set_scheduler_manage_port(uint16_t sched_port);
uint16_t scheduler_manage_port() const;
private:
PSContext()
: ps_enabled_(false),
@ -172,7 +175,8 @@ class PSContext {
client_batch_size_(32),
client_learning_rate_(0.001),
secure_aggregation_(false),
cluster_config_(nullptr) {}
cluster_config_(nullptr),
scheduler_manage_port_(0) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@ -231,6 +235,9 @@ class PSContext {
// The cluster config is read from environment variables; the value does not change.
std::unique_ptr<core::ClusterConfig> cluster_config_;
// The port used by the scheduler to receive HTTP requests for scale out or scale in.
uint16_t scheduler_manage_port_;
};
} // namespace ps
} // namespace mindspore

View File

@ -52,7 +52,8 @@ _set_ps_context_func_map = {
"client_epoch_num": ps_context().set_client_epoch_num,
"client_batch_size": ps_context().set_client_batch_size,
"client_learning_rate": ps_context().set_client_learning_rate,
"enable_ps_ssl": ps_context().set_enable_ssl
"enable_ps_ssl": ps_context().set_enable_ssl,
"scheduler_manage_port": ps_context().set_scheduler_manage_port
}
_get_ps_context_func_map = {