!17454 added scale out/in

From: @anancds
Reviewed-by: @cristoval,@limingqi107
Signed-off-by:
mindspore-ci-bot 2021-06-01 17:47:32 +08:00 committed by Gitee
commit 78ab4b00ec
8 changed files with 273 additions and 14 deletions

View File

@ -77,6 +77,7 @@ constexpr int64_t kInvalidID = -1;
constexpr uint32_t kMaxMessageSize = static_cast<uint32_t>(100 * (uint32_t(1) << 20));
constexpr char kServerNum[] = "server_num";
constexpr char kWorkerNum[] = "worker_num";
constexpr char kNodesIds[] = "node_ids";
constexpr int64_t kSubmitTaskIntervalInMs = 1;
constexpr int64_t kMaxTaskNum = 10240;

View File

@ -103,7 +103,7 @@ void AbstractNode::set_ready_for_scale_in() {
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
int command, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(data);
- if (!CommUtil::ValidateRankId(node_role, rank_id)) {
+ if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@ -126,7 +126,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
MS_LOG(EXCEPTION) << "The number of rank ids, data and lens are not equal!";
}
for (size_t it = 0; it < rank_ids.size(); ++it) {
- if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
+ if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@ -151,7 +151,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
int command, VectorPtr *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(message);
MS_EXCEPTION_IF_NULL(output);
- if (!CommUtil::ValidateRankId(node_role, rank_id)) {
+ if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@ -201,7 +201,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
});
for (size_t it = 0; it < size; ++it) {
- if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
+ if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@ -226,7 +226,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(data);
- if (!CommUtil::ValidateRankId(node_role, rank_id)) {
+ if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@ -242,7 +242,7 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
const uint32_t &rank_id, VectorPtr *output) {
MS_EXCEPTION_IF_NULL(output);
- if (!CommUtil::ValidateRankId(node_role, rank_id)) {
+ if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@ -281,6 +281,8 @@ int32_t AbstractNode::worker_num() const { return worker_num_; }
int32_t AbstractNode::server_num() const { return server_num_; }
ClusterState AbstractNode::cluster_state() const { return current_cluster_state_; }
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
@ -423,6 +425,12 @@ void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::sha
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
ScaleOutMessage scale_out_message;
scale_out_message.ParseFromArray(data, size);
worker_num_ = scale_out_message.worker_num();
server_num_ = scale_out_message.server_num();
server_->SendMessage(conn, meta, Protos::RAW, data, size);
on_node_event_message_(ClusterEvent::READY_FOR_SCALE_OUT);
current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT;
@ -434,6 +442,12 @@ void AbstractNode::ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shar
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
ScaleInMessage scale_in_message;
scale_in_message.ParseFromArray(data, size);
worker_num_ = scale_in_message.worker_num();
server_num_ = scale_in_message.server_num();
server_->SendMessage(conn, meta, Protos::RAW, data, size);
on_node_event_message_(ClusterEvent::READY_FOR_SCALE_IN);
current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;

View File

@ -77,6 +77,8 @@ class AbstractNode : public Node {
int32_t worker_num() const;
int32_t server_num() const;
ClusterState cluster_state() const;
protected:
void Register(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client);

View File

@ -121,11 +121,11 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
}
}
- bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
- if (node_role == NodeRole::SERVER && (rank_id > PSContext::instance()->cluster_config().initial_server_num - 1)) {
+ bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
+                               const int32_t &total_server_num) {
+ if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) {
return false;
- } else if (node_role == NodeRole::WORKER &&
-            (rank_id > PSContext::instance()->cluster_config().initial_worker_num - 1)) {
+ } else if (node_role == NodeRole::WORKER && (rank_id > IntToUint(total_worker_num) - 1)) {
return false;
}
return true;
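
The rule is simply that a rank id must be smaller than the current total for its role. A hypothetical Python mirror of the check (illustration only; unlike the C++ above it compares rank_id < total, sidestepping the unsigned wrap-around that IntToUint(total) - 1 would produce if a total were ever 0):

# Hypothetical mirror of the updated ValidateRankId: ranks are now validated
# against the *current* worker/server totals rather than the initial
# cluster_config() values, so the check stays correct after scale out/in.
def validate_rank_id(node_role, rank_id, total_worker_num, total_server_num):
    if node_role == "SERVER":
        return rank_id < total_server_num
    if node_role == "WORKER":
        return rank_id < total_worker_num
    return True  # other roles are not range-checked here

assert validate_rank_id("SERVER", 2, 2, 3)      # servers occupy ranks 0..2
assert not validate_rank_id("WORKER", 2, 2, 3)  # workers occupy ranks 0..1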

View File

@ -53,6 +53,7 @@
#include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"
#include "ps/ps_context.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
namespace ps {
@ -76,7 +77,8 @@ class CommUtil {
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
static std::string GenerateUUID();
static std::string NodeRoleToString(const NodeRole &role);
- static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id);
+ static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
+                            const int32_t &total_server_num);
static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds);
static void LogCallback(int severity, const char *msg);

View File

@ -26,6 +26,11 @@ SchedulerNode::~SchedulerNode() {
bool SchedulerNode::Start(const uint32_t &timeout) {
MS_LOG(INFO) << "Start scheduler node!";
if (PSContext::instance()->scheduler_manage_port() != 0) {
MS_LOG(WARNING) << "Start the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip()
<< ", the port:" << PSContext::instance()->scheduler_manage_port();
StartRestfulServer(PSContext::instance()->scheduler_ip(), PSContext::instance()->scheduler_manage_port(), 1);
}
Initialize();
StartUpdateClusterStateTimer();
if (!WaitForStart(timeout)) {
@ -33,6 +38,7 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
return false;
}
MS_LOG(INFO) << "Start the scheduler node is successful!";
return true;
}
@ -61,6 +67,7 @@ void SchedulerNode::Initialize() {
is_already_stopped_ = false;
node_info_.node_id_ = CommUtil::GenerateUUID();
node_info_.node_role_ = NodeRole::SCHEDULER;
leader_scaler_ = std::make_unique<LeaderScaler>(this);
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_;
}
@ -126,8 +133,9 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
SendMetadata(client);
MS_LOG(INFO) << "Send meta data to" << kvs.first;
}
- current_cluster_state_ = ClusterState::CLUSTER_READY;
+ node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
wait_start_cond_.notify_all();
}
}
@ -149,7 +157,7 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
SendFinish(client);
}
is_finish_ = true;
- current_cluster_state_ = ClusterState::CLUSTER_FINISH;
+ node_manager_.UpdateClusterState(ClusterState::CLUSTER_FINISH);
wait_finish_cond_.notify_all();
}
}
@ -266,6 +274,11 @@ bool SchedulerNode::Stop() {
client_thread_->join();
is_ready_ = true;
}
if (PSContext::instance()->scheduler_manage_port() != 0) {
MS_LOG(WARNING) << "Stop the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip()
<< ", the port:" << PSContext::instance()->scheduler_manage_port();
StopRestfulServer();
}
return true;
}
@ -280,6 +293,209 @@ bool SchedulerNode::Finish(const uint32_t &timeout) {
});
return true;
}

void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
  RequestProcessResult status(RequestProcessResultCode::kSuccess);
  status = resp->ParsePostMessageToJson();
  if (status != RequestProcessResultCode::kSuccess) {
    resp->ErrorResponse(HTTP_BADREQUEST, status);
    return;
  }

  int32_t scale_worker_num = 0;
  status = resp->ParseValueFromKey(kWorkerNum, &scale_worker_num);
  if (status != RequestProcessResultCode::kSuccess) {
    resp->ErrorResponse(HTTP_BADREQUEST, status);
    return;
  }

  int32_t scale_server_num = 0;
  status = resp->ParseValueFromKey(kServerNum, &scale_server_num);
  if (status != RequestProcessResultCode::kSuccess) {
    resp->ErrorResponse(HTTP_BADREQUEST, status);
    return;
  }

  status = CheckIfClusterReady();
  if (status != RequestProcessResultCode::kSuccess) {
    resp->ErrorResponse(HTTP_BADREQUEST, status);
    return;
  }

  int32_t total_worker_num = scale_worker_num + node_manager_.worker_num();
  int32_t total_server_num = scale_server_num + node_manager_.server_num();
  node_manager_.set_worker_num(total_worker_num);
  node_manager_.set_server_num(total_server_num);
  node_manager_.set_total_node_num(total_worker_num + total_server_num);

  node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
  auto node_infos = node_manager_.nodes_info();
  for (const auto &kvs : node_infos) {
    auto client = GetOrCreateClient(kvs.second);
    leader_scaler_->ScaleOutAsync(client, node_manager_);
  }
  node_manager_.ResetMetadata();
  MS_LOG(INFO) << "The scheduler sends the scale out message successfully.";

  nlohmann::json js;
  js["message"] = "Cluster begin to scale out.";
  resp->AddRespString(js.dump());
  resp->SetRespCode(HTTP_OK);
  resp->SendResponse();
}
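
For reference, a hypothetical client-side sketch of driving this new handler over HTTP (not part of the commit; it assumes the scheduler is reachable on 127.0.0.1 and that --scheduler_manage_port keeps the default of 11202 used by the test script at the end of this commit):

# Hypothetical /scaleout request. The "worker_num" and "server_num" keys
# match the kWorkerNum/kServerNum constants parsed by ProcessScaleOut above.
import json
import urllib.request

req = urllib.request.Request(
    "http://127.0.0.1:11202/scaleout",
    data=json.dumps({"worker_num": 1, "server_num": 1}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    # Expected success body: {"message": "Cluster begin to scale out."}
    print(resp.read().decode("utf-8"))
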
/*
* The body format is:
* {
* "node_ids": [
* {
* "node_id": "423ljjfslkj5",
* "rank_id": "0",
* "role": "SERVER"
* },
* {
* "node_id": "jklj3424kljj",
* "rank_id": "1",
* "role": "WORKER"
* }
* ]
* }
*/
void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
  RequestProcessResult status(RequestProcessResultCode::kSuccess);
  status = resp->ParsePostMessageToJson();
  if (status != RequestProcessResultCode::kSuccess) {
    resp->ErrorResponse(HTTP_BADREQUEST, status);
    return;
  }

  status = CheckIfClusterReady();
  if (status != RequestProcessResultCode::kSuccess) {
    resp->ErrorResponse(HTTP_BADREQUEST, status);
    return;
  }

  std::vector<std::string> node_ids;
  status = resp->ParseNodeIdsFromKey(kNodesIds, &node_ids);
  if (status != RequestProcessResultCode::kSuccess) {
    resp->ErrorResponse(HTTP_BADREQUEST, status);
    return;
  }

  int32_t scale_worker_num = 0;
  int32_t scale_server_num = 0;
  auto node_infos = node_manager_.nodes_info();
  for (auto const &val : node_ids) {
    if (node_infos.count(val)) {
      NodeInfo info = node_infos[val];
      if (info.node_role_ == NodeRole::WORKER) {
        scale_worker_num++;
      } else if (info.node_role_ == NodeRole::SERVER) {
        scale_server_num++;
      }
    }
  }
  MS_LOG(INFO) << "The scale worker num:" << scale_worker_num << ", the scale server num:" << scale_server_num;

  int32_t total_worker_num = node_manager_.worker_num() - scale_worker_num;
  int32_t total_server_num = node_manager_.server_num() - scale_server_num;
  node_manager_.set_worker_num(total_worker_num);
  node_manager_.set_server_num(total_server_num);
  node_manager_.set_total_node_num(total_worker_num + total_server_num);

  node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
  for (const auto &kvs : node_infos) {
    auto client = GetOrCreateClient(kvs.second);
    leader_scaler_->ScaleInAsync(client, node_manager_);
  }
  node_manager_.ResetMetadata();

  nlohmann::json js;
  js["message"] = "Cluster begin to scale in.";
  resp->AddRespString(js.dump());
  resp->SetRespCode(HTTP_OK);
  resp->SendResponse();
}
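
A matching hypothetical sketch for /scalein. One discrepancy worth flagging: the body-format comment above shows node_ids as objects, but ProcessScaleIn reads the key via ParseNodeIdsFromKey into a std::vector<std::string> and looks each entry up in node_manager_.nodes_info(), so this sketch assumes plain node-id strings:

# Hypothetical /scalein request. "node_ids" matches the kNodesIds constant;
# the ids are placeholders, and a real caller would take them from /nodes.
import json
import urllib.request

payload = {"node_ids": ["423ljjfslkj5", "jklj3424kljj"]}
req = urllib.request.Request(
    "http://127.0.0.1:11202/scalein",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    # Expected success body: {"message": "Cluster begin to scale in."}
    print(resp.read().decode("utf-8"))
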
/*
* The return body format is:
* {
* "message": "Get nodes info successful.",
* "node_ids": [
* {
* "node_id": "423ljjfslkj5",
* "rank_id": "0",
* "role": "SERVER"
* },
* {
* "node_id": "jklj3424kljj",
* "rank_id": "1",
* "role": "WORKER"
* }
* ]
* }
*/
void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp) {
  RequestProcessResult status(RequestProcessResultCode::kSuccess);
  nlohmann::json js;
  js["message"] = "Get nodes info successful.";
  auto node_infos = node_manager_.nodes_info();
  for (const auto &kvs : node_infos) {
    std::unordered_map<std::string, std::string> res;
    res["node_id"] = kvs.second.node_id_;
    res["rank_id"] = std::to_string(kvs.second.rank_id_);
    res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_);
    js["node_ids"].push_back(res);
  }

  resp->AddRespString(js.dump());
  resp->SetRespCode(HTTP_OK);
  resp->SendResponse();
}
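
And a hypothetical query for the endpoint above (an assumption: the handler never inspects the HTTP method, so a plain GET is presumed to work):

# Hypothetical /nodes query; the JSON response follows the return body
# format documented above ("message" plus a "node_ids" array).
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:11202/nodes") as resp:
    print(resp.read().decode("utf-8"))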

RequestProcessResult SchedulerNode::CheckIfClusterReady() {
  RequestProcessResult result(RequestProcessResultCode::kSuccess);
  if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {
    std::string message = "The cluster is not ready.";
    ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message);
    return result;
  }
  return result;
}

void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num) {
  MS_LOG(INFO) << "Scheduler starts the http server.";
  http_server_ = std::make_shared<HttpServer>(address, port, thread_num);

  OnRequestReceive scale_out = std::bind(&SchedulerNode::ProcessScaleOut, this, std::placeholders::_1);
  callbacks_["/scaleout"] = scale_out;
  http_server_->RegisterRoute("/scaleout", &callbacks_["/scaleout"]);

  OnRequestReceive scale_in = std::bind(&SchedulerNode::ProcessScaleIn, this, std::placeholders::_1);
  callbacks_["/scalein"] = scale_in;
  http_server_->RegisterRoute("/scalein", &callbacks_["/scalein"]);

  OnRequestReceive nodes = std::bind(&SchedulerNode::ProcessGetNodesInfo, this, std::placeholders::_1);
  callbacks_["/nodes"] = nodes;
  http_server_->RegisterRoute("/nodes", &callbacks_["/nodes"]);

  http_server_->InitServer();
  http_server_->Start();
  restful_thread_ = std::make_unique<std::thread>([&]() { http_server_->Wait(); });
}

void SchedulerNode::StopRestfulServer() {
  MS_LOG(INFO) << "Scheduler stops the http server.";
  http_server_->Stop();
  if (restful_thread_->joinable()) {
    restful_thread_->join();
  }
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -38,6 +38,7 @@
#include "ps/constants.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/http_server.h"
#include "ps/core/leader_scaler.h"
namespace mindspore {
namespace ps {
@ -51,7 +52,8 @@ class SchedulerNode : public Node {
restful_thread_(nullptr),
http_server_(nullptr),
client_thread_(nullptr),
- is_client_started_(false) {}
+ is_client_started_(false),
+ leader_scaler_(nullptr) {}
~SchedulerNode() override;
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
@ -82,6 +84,21 @@ class SchedulerNode : public Node {
// After the scheduler collects all finish messages, it actively sends a finish message to workers and servers.
void SendFinish(const std::shared_ptr<TcpClient> &client);
// Handle the scale out http request, then delegate to the leader scaler to process scale out asynchronously.
void ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp);
// Handle the scale in http request, then delegate to the leader scaler to process scale in asynchronously.
void ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp);
// Handle the get nodes info http request synchronously.
void ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp);
// Check whether the cluster is in the ready state.
RequestProcessResult CheckIfClusterReady();
void StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num = 10);
void StopRestfulServer();
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_;
std::unique_ptr<std::thread> update_state_thread_;
@ -97,6 +114,10 @@ class SchedulerNode : public Node {
std::unique_ptr<std::thread> client_thread_;
std::atomic<bool> is_client_started_;
std::unique_ptr<LeaderScaler> leader_scaler_;
std::unordered_map<std::string, OnRequestReceive> callbacks_;
};
} // namespace core
} // namespace ps

View File

@ -41,6 +41,7 @@ parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -60,6 +61,7 @@ fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
scheduler_manage_port = args.scheduler_manage_port
ctx = {
"enable_fl": True,
@ -79,6 +81,7 @@ ctx = {
"client_epoch_num": client_epoch_num,
"client_batch_size": client_batch_size,
"client_learning_rate": client_learning_rate,
"scheduler_manage_port": scheduler_manage_port
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
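
A hedged note on how the new entry takes effect: the diff cuts off before the call, but assuming this script, like other MindSpore federated-learning test scripts, forwards ctx to the fl context, the flow would be:

# Assumed continuation of the script (not shown in this diff): this call is
# what routes "scheduler_manage_port" into PSContext, which in turn lets
# SchedulerNode::Start() bring up the RESTful server on that port.
context.set_fl_context(**ctx)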