forked from mindspore-Ecosystem/mindspore
added scale out/in
This commit is contained in:
parent
3716d3654d
commit
00cbfdb5d9
|
@ -77,6 +77,7 @@ constexpr int64_t kInvalidID = -1;
|
|||
constexpr uint32_t kMaxMessageSize = static_cast<uint32_t>(100 * (uint32_t(1) << 20));
|
||||
constexpr char kServerNum[] = "server_num";
|
||||
constexpr char kWorkerNum[] = "worker_num";
|
||||
constexpr char kNodesIds[] = "node_ids";
|
||||
|
||||
constexpr int64_t kSubmitTaskIntervalInMs = 1;
|
||||
constexpr int64_t kMaxTaskNum = 10240;
|
||||
|
|
|
@ -103,7 +103,7 @@ void AbstractNode::set_ready_for_scale_in() {
|
|||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
|
||||
int command, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|||
MS_LOG(EXCEPTION) << "The number of rank ids, data and lens are not equal!";
|
||||
}
|
||||
for (size_t it = 0; it < rank_ids.size(); ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
|
@ -151,7 +151,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
|
|||
int command, VectorPtr *output, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(message);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
|
@ -201,7 +201,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|||
});
|
||||
|
||||
for (size_t it = 0; it < size; ++it) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it), worker_num_, server_num_)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
|
@ -226,7 +226,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|||
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
|
@ -242,7 +242,7 @@ uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const
|
|||
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
|
||||
const uint32_t &rank_id, VectorPtr *output) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
|
||||
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
|
||||
}
|
||||
|
||||
|
@ -281,6 +281,8 @@ int32_t AbstractNode::worker_num() const { return worker_num_; }
|
|||
|
||||
int32_t AbstractNode::server_num() const { return server_num_; }
|
||||
|
||||
ClusterState AbstractNode::cluster_state() const { return current_cluster_state_; }
|
||||
|
||||
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
|
||||
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
|
||||
|
@ -423,6 +425,12 @@ void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::sha
|
|||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
||||
ScaleOutMessage scale_out_message;
|
||||
scale_out_message.ParseFromArray(data, size);
|
||||
worker_num_ = scale_out_message.worker_num();
|
||||
server_num_ = scale_out_message.server_num();
|
||||
|
||||
server_->SendMessage(conn, meta, Protos::RAW, data, size);
|
||||
on_node_event_message_(ClusterEvent::READY_FOR_SCALE_OUT);
|
||||
current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT;
|
||||
|
@ -434,6 +442,12 @@ void AbstractNode::ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shar
|
|||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
||||
ScaleInMessage scale_in_message;
|
||||
scale_in_message.ParseFromArray(data, size);
|
||||
worker_num_ = scale_in_message.worker_num();
|
||||
server_num_ = scale_in_message.server_num();
|
||||
|
||||
server_->SendMessage(conn, meta, Protos::RAW, data, size);
|
||||
on_node_event_message_(ClusterEvent::READY_FOR_SCALE_IN);
|
||||
current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
|
||||
|
|
|
@ -77,6 +77,8 @@ class AbstractNode : public Node {
|
|||
int32_t worker_num() const;
|
||||
int32_t server_num() const;
|
||||
|
||||
ClusterState cluster_state() const;
|
||||
|
||||
protected:
|
||||
void Register(const std::shared_ptr<TcpClient> &client);
|
||||
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
|
||||
|
|
|
@ -121,11 +121,11 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
|
|||
MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
|
||||
}
|
||||
}
|
||||
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
|
||||
if (node_role == NodeRole::SERVER && (rank_id > PSContext::instance()->cluster_config().initial_server_num - 1)) {
|
||||
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
|
||||
const int32_t &total_server_num) {
|
||||
if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) {
|
||||
return false;
|
||||
} else if (node_role == NodeRole::WORKER &&
|
||||
(rank_id > PSContext::instance()->cluster_config().initial_worker_num - 1)) {
|
||||
} else if (node_role == NodeRole::WORKER && (rank_id > IntToUint(total_worker_num) - 1)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -53,6 +53,7 @@
|
|||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -76,7 +77,8 @@ class CommUtil {
|
|||
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
|
||||
static std::string GenerateUUID();
|
||||
static std::string NodeRoleToString(const NodeRole &role);
|
||||
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id);
|
||||
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
|
||||
const int32_t &total_server_num);
|
||||
static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds);
|
||||
static void LogCallback(int severity, const char *msg);
|
||||
|
||||
|
|
|
@ -26,6 +26,11 @@ SchedulerNode::~SchedulerNode() {
|
|||
|
||||
bool SchedulerNode::Start(const uint32_t &timeout) {
|
||||
MS_LOG(INFO) << "Start scheduler node!";
|
||||
if (PSContext::instance()->scheduler_manage_port() != 0) {
|
||||
MS_LOG(WARNING) << "Start the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip()
|
||||
<< ", the port:" << PSContext::instance()->scheduler_manage_port();
|
||||
StartRestfulServer(PSContext::instance()->scheduler_ip(), PSContext::instance()->scheduler_manage_port(), 1);
|
||||
}
|
||||
Initialize();
|
||||
StartUpdateClusterStateTimer();
|
||||
if (!WaitForStart(timeout)) {
|
||||
|
@ -33,6 +38,7 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
|
|||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Start the scheduler node is successful!";
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -61,6 +67,7 @@ void SchedulerNode::Initialize() {
|
|||
is_already_stopped_ = false;
|
||||
node_info_.node_id_ = CommUtil::GenerateUUID();
|
||||
node_info_.node_role_ = NodeRole::SCHEDULER;
|
||||
leader_scaler_ = std::make_unique<LeaderScaler>(this);
|
||||
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_;
|
||||
}
|
||||
|
@ -126,8 +133,9 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
|
|||
for (const auto &kvs : node_infos) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
SendMetadata(client);
|
||||
MS_LOG(INFO) << "Send meta data to" << kvs.first;
|
||||
}
|
||||
current_cluster_state_ = ClusterState::CLUSTER_READY;
|
||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
|
||||
wait_start_cond_.notify_all();
|
||||
}
|
||||
}
|
||||
|
@ -149,7 +157,7 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
|
|||
SendFinish(client);
|
||||
}
|
||||
is_finish_ = true;
|
||||
current_cluster_state_ = ClusterState::CLUSTER_FINISH;
|
||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_FINISH);
|
||||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
}
|
||||
|
@ -266,6 +274,11 @@ bool SchedulerNode::Stop() {
|
|||
client_thread_->join();
|
||||
is_ready_ = true;
|
||||
}
|
||||
if (PSContext::instance()->scheduler_manage_port() != 0) {
|
||||
MS_LOG(WARNING) << "Stop the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip()
|
||||
<< ", the port:" << PSContext::instance()->scheduler_manage_port();
|
||||
StopRestfulServer();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -280,6 +293,209 @@ bool SchedulerNode::Finish(const uint32_t &timeout) {
|
|||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
|
||||
RequestProcessResult status(RequestProcessResultCode::kSuccess);
|
||||
status = resp->ParsePostMessageToJson();
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t scale_worker_num = 0;
|
||||
status = resp->ParseValueFromKey(kWorkerNum, &scale_worker_num);
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t scale_server_num = 0;
|
||||
status = resp->ParseValueFromKey(kServerNum, &scale_server_num);
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
status = CheckIfClusterReady();
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t total_worker_num = scale_worker_num + node_manager_.worker_num();
|
||||
int32_t total_server_num = scale_server_num + node_manager_.server_num();
|
||||
|
||||
node_manager_.set_worker_num(total_worker_num);
|
||||
node_manager_.set_server_num(total_server_num);
|
||||
node_manager_.set_total_node_num(total_worker_num + total_server_num);
|
||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
for (const auto &kvs : node_infos) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
leader_scaler_->ScaleOutAsync(client, node_manager_);
|
||||
}
|
||||
node_manager_.ResetMetadata();
|
||||
MS_LOG(INFO) << "Scheduler send scale out successful.";
|
||||
|
||||
nlohmann::json js;
|
||||
js["message"] = "Cluster begin to scale out.";
|
||||
resp->AddRespString(js.dump());
|
||||
|
||||
resp->SetRespCode(HTTP_OK);
|
||||
resp->SendResponse();
|
||||
}
|
||||
|
||||
/*
|
||||
* The body format is:
|
||||
* {
|
||||
* "node_ids": [
|
||||
* {
|
||||
* "node_id": "423ljjfslkj5",
|
||||
* "rank_id": "0",
|
||||
* "role": "SERVER"
|
||||
* },
|
||||
* {
|
||||
* "node_id": "jklj3424kljj",
|
||||
* "rank_id": "1",
|
||||
* "role": "WORKER"
|
||||
* }
|
||||
* ]
|
||||
* }
|
||||
*/
|
||||
void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
|
||||
RequestProcessResult status(RequestProcessResultCode::kSuccess);
|
||||
status = resp->ParsePostMessageToJson();
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
}
|
||||
|
||||
status = CheckIfClusterReady();
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> node_ids;
|
||||
status = resp->ParseNodeIdsFromKey(kNodesIds, &node_ids);
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t scale_worker_num = 0;
|
||||
int32_t scale_server_num = 0;
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
for (auto const &val : node_ids) {
|
||||
if (node_infos.count(val)) {
|
||||
NodeInfo info = node_infos[val];
|
||||
if (info.node_role_ == NodeRole::WORKER) {
|
||||
scale_worker_num++;
|
||||
} else if (info.node_role_ == NodeRole::SERVER) {
|
||||
scale_server_num++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The scale worker num:" << scale_worker_num << ", the scale server num:" << scale_server_num;
|
||||
|
||||
int32_t total_worker_num = node_manager_.worker_num() - scale_worker_num;
|
||||
int32_t total_server_num = node_manager_.server_num() - scale_server_num;
|
||||
|
||||
node_manager_.set_worker_num(total_worker_num);
|
||||
node_manager_.set_server_num(total_server_num);
|
||||
node_manager_.set_total_node_num(total_worker_num + total_server_num);
|
||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
|
||||
for (const auto &kvs : node_infos) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
leader_scaler_->ScaleInAsync(client, node_manager_);
|
||||
}
|
||||
|
||||
node_manager_.ResetMetadata();
|
||||
nlohmann::json js;
|
||||
js["message"] = "Cluster begin to scale in.";
|
||||
resp->AddRespString(js.dump());
|
||||
|
||||
resp->SetRespCode(HTTP_OK);
|
||||
resp->SendResponse();
|
||||
}
|
||||
|
||||
/*
|
||||
* The return body format is:
|
||||
* {
|
||||
* "message": "Get nodes info successful.",
|
||||
* "node_ids": [
|
||||
* {
|
||||
* "node_id": "423ljjfslkj5",
|
||||
* "rank_id": "0",
|
||||
* "role": "SERVER"
|
||||
* },
|
||||
* {
|
||||
* "node_id": "jklj3424kljj",
|
||||
* "rank_id": "1",
|
||||
* "role": "WORKER"
|
||||
* }
|
||||
* ]
|
||||
* }
|
||||
*/
|
||||
void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp) {
|
||||
RequestProcessResult status(RequestProcessResultCode::kSuccess);
|
||||
|
||||
nlohmann::json js;
|
||||
js["message"] = "Get nodes info successful.";
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
for (const auto &kvs : node_infos) {
|
||||
std::unordered_map<std::string, std::string> res;
|
||||
res["node_id"] = kvs.second.node_id_;
|
||||
res["rank_id"] = kvs.second.rank_id_;
|
||||
res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_);
|
||||
js["node_ids"].push_back(res);
|
||||
}
|
||||
|
||||
resp->AddRespString(js.dump());
|
||||
|
||||
resp->SetRespCode(HTTP_OK);
|
||||
resp->SendResponse();
|
||||
}
|
||||
|
||||
RequestProcessResult SchedulerNode::CheckIfClusterReady() {
|
||||
RequestProcessResult result(RequestProcessResultCode::kSuccess);
|
||||
if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {
|
||||
std::string message = "The cluster is not ready.";
|
||||
ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message);
|
||||
return result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num) {
|
||||
MS_LOG(INFO) << "Scheduler start https server.";
|
||||
http_server_ = std::make_shared<HttpServer>(address, port, thread_num);
|
||||
|
||||
OnRequestReceive scale_out = std::bind(&SchedulerNode::ProcessScaleOut, this, std::placeholders::_1);
|
||||
callbacks_["/scaleout"] = scale_out;
|
||||
http_server_->RegisterRoute("/scaleout", &callbacks_["/scaleout"]);
|
||||
|
||||
OnRequestReceive scale_in = std::bind(&SchedulerNode::ProcessScaleIn, this, std::placeholders::_1);
|
||||
callbacks_["/scalein"] = scale_in;
|
||||
http_server_->RegisterRoute("/scalein", &callbacks_["/scalein"]);
|
||||
|
||||
OnRequestReceive nodes = std::bind(&SchedulerNode::ProcessGetNodesInfo, this, std::placeholders::_1);
|
||||
callbacks_["/nodes"] = nodes;
|
||||
http_server_->RegisterRoute("/nodes", &callbacks_["/nodes"]);
|
||||
|
||||
http_server_->InitServer();
|
||||
|
||||
http_server_->Start();
|
||||
restful_thread_ = std::make_unique<std::thread>([&]() { http_server_->Wait(); });
|
||||
}
|
||||
|
||||
void SchedulerNode::StopRestfulServer() {
|
||||
MS_LOG(INFO) << "Scheduler stop https server.";
|
||||
http_server_->Stop();
|
||||
if (restful_thread_->joinable()) {
|
||||
restful_thread_->join();
|
||||
}
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "ps/constants.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/communicator/http_server.h"
|
||||
#include "ps/core/leader_scaler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -51,7 +52,8 @@ class SchedulerNode : public Node {
|
|||
restful_thread_(nullptr),
|
||||
http_server_(nullptr),
|
||||
client_thread_(nullptr),
|
||||
is_client_started_(false) {}
|
||||
is_client_started_(false),
|
||||
leader_scaler_(nullptr) {}
|
||||
~SchedulerNode() override;
|
||||
|
||||
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
|
@ -82,6 +84,21 @@ class SchedulerNode : public Node {
|
|||
// // After scheduler collects all finish message, it actively sends finish message to workers and servers.
|
||||
void SendFinish(const std::shared_ptr<TcpClient> &client);
|
||||
|
||||
// Handle the scale out http request, then delegate to the leader scaler to process scale out asynchronously.
|
||||
void ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp);
|
||||
|
||||
// Handle the scale in http request, then delegate to the leader scaler to process scale in asynchronously.
|
||||
void ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp);
|
||||
|
||||
// Handle the get nodes info http request Synchronously.
|
||||
void ProcessGetNodesInfo(std::shared_ptr<HttpMessageHandler> resp);
|
||||
|
||||
// check whether the cluster is in the ready state.
|
||||
RequestProcessResult CheckIfClusterReady();
|
||||
|
||||
void StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num = 10);
|
||||
void StopRestfulServer();
|
||||
|
||||
std::shared_ptr<TcpServer> server_;
|
||||
std::unique_ptr<std::thread> scheduler_thread_;
|
||||
std::unique_ptr<std::thread> update_state_thread_;
|
||||
|
@ -97,6 +114,10 @@ class SchedulerNode : public Node {
|
|||
|
||||
std::unique_ptr<std::thread> client_thread_;
|
||||
std::atomic<bool> is_client_started_;
|
||||
|
||||
std::unique_ptr<LeaderScaler> leader_scaler_;
|
||||
|
||||
std::unordered_map<std::string, OnRequestReceive> callbacks_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -41,6 +41,7 @@ parser.add_argument("--fl_iteration_num", type=int, default=25)
|
|||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
parser.add_argument("--client_batch_size", type=int, default=32)
|
||||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -60,6 +61,7 @@ fl_iteration_num = args.fl_iteration_num
|
|||
client_epoch_num = args.client_epoch_num
|
||||
client_batch_size = args.client_batch_size
|
||||
client_learning_rate = args.client_learning_rate
|
||||
scheduler_manage_port = args.scheduler_manage_port
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -79,6 +81,7 @@ ctx = {
|
|||
"client_epoch_num": client_epoch_num,
|
||||
"client_batch_size": client_batch_size,
|
||||
"client_learning_rate": client_learning_rate,
|
||||
"scheduler_manage_port": scheduler_manage_port
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
|
||||
|
|
Loading…
Reference in New Issue