forked from mindspore-Ecosystem/mindspore

!17052 added scale out

From: @anancds
Reviewed-by: @cristoval, @limingqi107
Signed-off-by:

commit 40ca285ab3
@@ -359,6 +359,8 @@ PYBIND11_MODULE(_c_expression, m) {
    .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
    .def("set_client_learning_rate", &PSContext::set_client_learning_rate,
         "Set federated learning client learning rate.")
    .def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
         "Set scheduler manage port used to scale out/in.")
    .def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");

  (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
@@ -35,6 +35,7 @@ constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
constexpr char kEnvSchedulerManagePort[] = "MS_SCHED_MANAGE_PORT";

constexpr char kCommTypeOfIBVerbs[] = "ibverbs";
constexpr char kRoleOfPServer[] = "server";
@@ -90,6 +90,16 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &mess

void AbstractNode::set_event_callback(const OnNodeEventMessage &event) { on_node_event_message_ = event; }

void AbstractNode::set_ready_for_scale_out() {
  Register(client_to_scheduler_);
  connected_nodes_.clear();
}

void AbstractNode::set_ready_for_scale_in() {
  Register(client_to_scheduler_);
  connected_nodes_.clear();
}

bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
                        int command, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(data);
@@ -267,6 +277,10 @@ bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, cons
  return res;
}

int32_t AbstractNode::worker_num() const { return worker_num_; }

int32_t AbstractNode::server_num() const { return server_num_; }

void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
  MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
               << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
@@ -375,7 +389,6 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
  SendMetadataMessage send_meta_message;
  send_meta_message.ParseFromArray(data, size);
  nodes_address_.clear();
  MS_LOG(ERROR) << "send metadata size:" << send_meta_message.servers_meta().size();
  for (const auto &it : send_meta_message.servers_meta()) {
    nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
    MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
@@ -383,6 +396,16 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  is_ready_ = true;
  wait_start_cond_.notify_all();

  if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_OUT) {
    MS_LOG(WARNING) << "Trigger cluster scale out done event.";
    on_node_event_message_(ClusterEvent::CLUSTER_SCALE_OUT_DONE);
  }
  if (current_cluster_state_ == ClusterState::CLUSTER_SCALE_IN) {
    MS_LOG(WARNING) << "Trigger cluster scale in done event.";
    on_node_event_message_(ClusterEvent::CLUSTER_SCALE_IN_DONE);
  }
  current_cluster_state_ = ClusterState::CLUSTER_READY;
}

void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
@@ -395,6 +418,28 @@ void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::share
  wait_finish_cond_.notify_all();
}

void AbstractNode::ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
                                   const Protos &protos, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  on_node_event_message_(ClusterEvent::READY_FOR_SCALE_OUT);
  current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT;
  is_ready_ = false;
}

void AbstractNode::ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
                                  const Protos &protos, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  on_node_event_message_(ClusterEvent::READY_FOR_SCALE_IN);
  current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
  is_ready_ = false;
}

bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
  auto meta = std::make_shared<MessageMeta>();
  meta->set_cmd(NodeCommand::FINISH);
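The handlers above only flag the state change; the business layer is expected to react to the READY_FOR_SCALE_OUT / READY_FOR_SCALE_IN events and, once its own redistribution work is done, call set_ready_for_scale_out() / set_ready_for_scale_in() so the node re-registers with the scheduler and drops stale connections. A minimal sketch of that wiring, assuming OnNodeEventMessage is a std::function-style callback that receives the ClusterEvent (the callback signature and the `node` variable are illustrative, not taken from this commit):

// Hypothetical wiring on a worker/server; `node` is an instance of some AbstractNode subclass.
node.set_event_callback([&node](const ClusterEvent &event) {
  if (event == ClusterEvent::READY_FOR_SCALE_OUT) {
    // ... business layer redistributes its data/parameters here ...
    node.set_ready_for_scale_out();  // re-register with the scheduler, clear cached connections
  } else if (event == ClusterEvent::READY_FOR_SCALE_IN) {
    // ... business layer drains work from the nodes being removed ...
    node.set_ready_for_scale_in();
  } else if (event == ClusterEvent::CLUSTER_SCALE_OUT_DONE || event == ClusterEvent::CLUSTER_SCALE_IN_DONE) {
    // Cluster is back to CLUSTER_READY; normal traffic can resume.
  }
});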
@@ -613,9 +658,10 @@ void AbstractNode::InitServerHandler() {
  server_handler_[NodeCommand::FINISH] = &AbstractNode::ProcessFinish;
  server_handler_[NodeCommand::SEND_DATA] = nullptr;
  server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = nullptr;
  server_handler_[NodeCommand::SCALE_OUT] = &AbstractNode::ProcessScaleOut;
}

void AbstractNode::InitNode(const NodeRole &role) {
void AbstractNode::InitNodeInfo(const NodeRole &role) {
  node_info_.node_id_ = CommUtil::GenerateUUID();
  node_info_.node_role_ = role;
  node_info_.ip_ = server_->BoundIp();
@@ -624,6 +670,11 @@ void AbstractNode::InitNode(const NodeRole &role) {
               << " is generate uuid is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
               << ", the port:" << server_->BoundPort();
}

void AbstractNode::InitNodeNum() {
  worker_num_ = PSContext::instance()->cluster_config().initial_worker_num;
  server_num_ = PSContext::instance()->cluster_config().initial_server_num;
}
} // namespace core
} // namespace ps
} // namespace mindspore
@@ -38,7 +38,9 @@ class AbstractNode : public Node {
        client_to_scheduler_thread_(nullptr),
        client_to_scheduler_(nullptr),
        server_(nullptr),
        server_thread_(nullptr) {}
        server_thread_(nullptr),
        worker_num_(-1),
        server_num_(-1) {}
  ~AbstractNode() override = default;

  typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
@@ -52,6 +54,10 @@ class AbstractNode : public Node {
            const uint32_t &timeout = kCommTimeoutInSeconds);

  void set_event_callback(const OnNodeEventMessage &event);
  // When the business layer finishes scale out, it should call this function.
  void set_ready_for_scale_out();
  // When the business layer finishes scale in, it should call this function.
  void set_ready_for_scale_in();

  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
            const uint32_t &timeout = kCommTimeoutInSeconds);
@@ -68,6 +74,9 @@ class AbstractNode : public Node {
                      VectorPtr *output);
  bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

  int32_t worker_num() const;
  int32_t server_num() const;

 protected:
  void Register(const std::shared_ptr<TcpClient> &client);
  bool Heartbeat(const std::shared_ptr<TcpClient> &client);
@@ -82,6 +91,12 @@ class AbstractNode : public Node {
  void ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
                     const void *data, size_t size);

  void ProcessScaleOut(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
                       const void *data, size_t size);

  void ProcessScaleIn(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
                      const void *data, size_t size);

  void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
  void UpdateSchedulerTime();
  bool CheckSchedulerTimeout() const;
@@ -98,7 +113,11 @@ class AbstractNode : public Node {
  uint64_t NextActualRankRequestId(const uint32_t &rank_id);
  void InitCommandHandler();
  void InitServerHandler();
  void InitNode(const NodeRole &role);

  // When initializing the node, the node info should be initialized.
  void InitNodeInfo(const NodeRole &role);
  // Initialize worker num and server num by cluster config.
  void InitNodeNum();

  std::unique_ptr<std::thread> heart_beat_thread_;
  std::unique_ptr<std::thread> client_to_scheduler_thread_;
@@ -136,11 +155,12 @@ class AbstractNode : public Node {
  std::unordered_map<NodeCommand, ResponseHandler> handlers_;
  std::unordered_map<NodeCommand, ServerHandler> server_handler_;

  std::unordered_map<ClusterEvent, bool> is_event_send_;
  std::mutex is_event_send_mutex_;

  // Workers and servers launch the server to process the commands: FINISH, SCALE_OUT, SCALE_IN, SEND_METADATA.
  std::shared_ptr<TcpServer> server_;
  std::unique_ptr<std::thread> server_thread_;

  int32_t worker_num_;
  int32_t server_num_;
};
} // namespace core
} // namespace ps
@@ -47,11 +47,6 @@ class Node {
  Node()
      : is_ready_(false),
        is_finish_(false),
        is_node_ready_scale_out_(false),
        is_node_ready_scale_in_(false),
        is_cluster_ready_scale_out_(false),
        is_cluster_ready_scale_in_(false),
        update_local_servers_(false),
        is_already_stopped_(true),
        is_already_finished_(false),
        next_request_id_(0),
@@ -93,15 +88,6 @@ class Node {
  std::atomic<bool> is_ready_;
  std::atomic<bool> is_finish_;

  std::atomic<bool> is_node_ready_scale_out_;
  std::atomic<bool> is_node_ready_scale_in_;

  std::atomic<bool> is_cluster_ready_scale_out_;
  std::atomic<bool> is_cluster_ready_scale_in_;

  // Determine whether to update the ip and port of the locally cached servers.
  std::atomic<bool> update_local_servers_;

  std::atomic<bool> is_already_stopped_;
  std::atomic<bool> is_already_finished_;
  std::atomic_uint64_t next_request_id_;
@@ -117,6 +103,7 @@ class Node {
  std::mutex message_tracker_mutex_;
  std::condition_variable message_tracker_cond_;

  // Workers and servers receive the node state and cluster state from the scheduler.
  NodeState current_node_state_;
  ClusterState current_cluster_state_;
};
@@ -80,12 +80,9 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {

void NodeManager::UpdateHeartbeat(const std::string &node_id) {
  std::lock_guard<std::mutex> lock(heartbeat_mutex_);
  NodeInfo node_info = nodes_info_[node_id];
  struct timeval current_time {};
  (void)gettimeofday(&current_time, nullptr);
  heartbeats_[node_id] = current_time;
  MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info.node_role_) << ", the node id:" << node_id
                << ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
}

void NodeManager::UpdateNodeScaleInState(const std::string &node_id) { heartbeats_scale_in_nodes_.insert(node_id); }
@@ -116,7 +113,9 @@ void NodeManager::UpdateCluster() {
  for (auto it = heartbeats_.begin(); it != heartbeats_.end(); ++it) {
    if (it->second.tv_sec + PSContext::instance()->cluster_config().heartbeat_timeout < current_time.tv_sec) {
      MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
      timeout_nodes_info_[it->first] = nodes_info_[it->first];
      if (nodes_info_.count(it->first)) {
        timeout_nodes_info_[it->first] = nodes_info_[it->first];
      }
    }
  }
  if (!timeout_nodes_info_.empty()) {
@@ -146,9 +145,9 @@ void NodeManager::CheckClusterTimeout() {

void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }

bool NodeManager::CheckRegisterNum() { return SizeToInt(nodes_info_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesRegistered() { return SizeToInt(nodes_info_.size()) == total_node_num_; }

bool NodeManager::CheckFinishNum() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesFinished() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }

std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() { return nodes_info_; }
@@ -172,9 +171,24 @@ ClusterState NodeManager::GetClusterState() {
  return cluster_state_;
}

void NodeManager::ResetMetadata() {
  MS_LOG(WARNING) << "Reset metadata.";
  nodes_info_.clear();
  next_worker_rank_id_ = -1;
  next_server_rank_id_ = -1;
}

void NodeManager::set_total_node_num(const int32_t &node_num) { total_node_num_ = node_num; }

const int32_t &NodeManager::total_node_num() { return total_node_num_; }

void NodeManager::set_worker_num(const int32_t &worker_num) { meta_data_->worker_num = worker_num; }

void NodeManager::set_server_num(const int32_t &server_num) { meta_data_->server_num = server_num; }

int32_t NodeManager::worker_num() { return UintToInt(meta_data_->worker_num); }

int32_t NodeManager::server_num() { return UintToInt(meta_data_->server_num); }
} // namespace core
} // namespace ps
} // namespace mindspore
@@ -68,21 +68,29 @@ class NodeManager {

  // When workers and servers register with the scheduler, the scheduler will collect the number of registered
  // nodes and determine whether the registered number of workers and servers is equal to total_node_num_.
  bool CheckRegisterNum();
  bool IsAllNodesRegistered();
  // When workers and servers send a finish message to the scheduler, the scheduler will collect the number of
  // finished nodes and determine whether the finished nodes are equal to total_node_num_.
  bool CheckFinishNum();
  bool IsAllNodesFinished();

  std::unordered_map<std::string, NodeInfo> &nodes_info();

  void set_total_node_num(const int32_t &node_num);
  const int32_t &total_node_num();
  void set_worker_num(const int32_t &worker_num);
  void set_server_num(const int32_t &server_num);
  int32_t worker_num();
  int32_t server_num();

  void UpdateNodeState(const NodeState &state);
  void UpdateClusterState(const ClusterState &state);
  NodeState GetNodeState();
  ClusterState GetClusterState();

  // When the scheduler receives the scale out or scale in message, the metadata needs to be reset, because all nodes
  // will re-register.
  void ResetMetadata();

 private:
  std::mutex node_mutex_;
  std::mutex cluster_mutex_;
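Taken together, ResetMetadata() and the worker/server-number setters let the scheduler prepare the cluster metadata before the nodes re-register. A minimal sketch of that call sequence on the scheduler side, assuming a scale-out request arrives with the new worker and server counts (the request handling itself and the function name are illustrative, not part of this commit):

// Hypothetical scheduler-side handling of a scale-out request; names are illustrative.
void OnScaleOutRequest(NodeManager &node_manager, int32_t new_worker_num, int32_t new_server_num) {
  node_manager.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);  // enter the scale-out state
  node_manager.set_worker_num(new_worker_num);                       // record the post-scale-out numbers
  node_manager.set_server_num(new_server_num);
  node_manager.set_total_node_num(new_worker_num + new_server_num);
  node_manager.ResetMetadata();  // all nodes will re-register, so clear the old node info and rank ids
  // The scheduler would then broadcast a SCALE_OUT command carrying a ScaleOutMessage to every node.
}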
@@ -107,6 +115,7 @@ class NodeManager {
  std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
  std::unordered_set<std::string> finish_nodes_id_;

  // Cluster metadata information can be dynamically changed.
  std::unique_ptr<ClusterMetadata> meta_data_;

  NodeState node_state_;
@@ -29,6 +29,8 @@ enum NodeCommand {
  COLLECTIVE_SEND_DATA = 6;
  // The scheduler actively sends metadata to the worker and server
  SEND_METADATA = 7;
  SCALE_OUT = 8;
  SCALE_IN = 9;
}

enum NodeRole {
@@ -120,3 +122,11 @@ message CommMessage {
  MessageMeta pb_meta = 1;
  bytes data = 2;
}

// The scheduler will broadcast the worker/server numbers after scale out/in to all nodes.
message ScaleOutMessage {
  // the worker number after scale out/in
  int32 worker_num = 1;
  // the server number after scale out/in
  int32 server_num = 2;
}
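On the C++ side this message is exchanged through the generated protobuf class. A minimal sketch of building and parsing it with the standard protobuf C++ API (the concrete numbers and the surrounding send/receive plumbing are illustrative, not taken from this commit):

// Scheduler side: fill in the post-scale-out numbers and serialize the payload.
ScaleOutMessage scale_out_message;
scale_out_message.set_worker_num(4);  // illustrative values
scale_out_message.set_server_num(2);
std::string payload = scale_out_message.SerializeAsString();
// The payload would accompany the SCALE_OUT command, e.g. sent with Protos::PROTOBUF.

// Node side: recover the new numbers from the received buffer.
ScaleOutMessage received;
received.ParseFromArray(payload.data(), static_cast<int>(payload.size()));
int32_t new_worker_num = received.worker_num();
int32_t new_server_num = received.server_num();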
@@ -120,7 +120,7 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
  server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
                      register_resp_message.ByteSizeLong());

  if (node_manager_.CheckRegisterNum()) {
  if (node_manager_.IsAllNodesRegistered()) {
    is_ready_ = true;
    auto node_infos = node_manager_.nodes_info();
    for (const auto &kvs : node_infos) {
@@ -142,7 +142,7 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
  node_manager_.AddFinishNode(*finish_message);
  MS_LOG(INFO) << "Process finish message from node id:" << *finish_message;
  server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
  if (node_manager_.CheckFinishNum()) {
  if (node_manager_.IsAllNodesFinished()) {
    auto node_infos = node_manager_.nodes_info();
    for (const auto &kvs : node_infos) {
      auto client = GetOrCreateClient(kvs.second);
@@ -177,8 +177,6 @@ void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client) {
  SendMetadataMessage send_metadata_message;
  std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();

  MS_LOG(ERROR) << "the list size:" << servers_meta_list.size();

  *send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};

  if (!SendMessageAsync(client, message_meta, Protos::PROTOBUF, send_metadata_message.SerializeAsString().data(),
@@ -238,7 +236,6 @@ const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInf
  std::string ip = node_info.ip_;
  uint16_t port = node_info.port_;
  auto client = std::make_shared<TcpClient>(ip, port);
  MS_LOG(ERROR) << "the ip:" << node_info.ip_ << ", the port:" << node_info.port_;
  client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
                                 size_t size) { NotifyMessageArrival(meta); });
  client->Init();
@@ -89,6 +89,7 @@ class SchedulerNode : public Node {

  NodeManager node_manager_;

  // This thread will start an http server.
  std::unique_ptr<std::thread> restful_thread_;
  std::shared_ptr<HttpServer> http_server_;
@@ -83,7 +83,8 @@ void ServerNode::Initialize() {
  InitServerHandler();
  CreateTcpServer();
  is_already_stopped_ = false;
  InitNode(NodeRole::SERVER);
  InitNodeInfo(NodeRole::SERVER);
  InitNodeNum();
  InitCommandHandler();
  if (!InitClientToScheduler()) {
    MS_LOG(EXCEPTION) << "Server node init client timeout!";
@@ -41,7 +41,8 @@ void WorkerNode::Initialize() {
  is_already_stopped_ = false;
  InitServerHandler();
  CreateTcpServer();
  InitNode(NodeRole::WORKER);
  InitNodeInfo(NodeRole::WORKER);
  InitNodeNum();
  InitCommandHandler();
  if (!InitClientToScheduler()) {
    MS_LOG(EXCEPTION) << "Worker node init client timeout!";
@@ -52,6 +52,7 @@ void PSContext::SetPSEnable(bool enabled) {
    server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
    scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
    scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10);
    scheduler_manage_port_ = std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, 10);
    cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
  } else {
    MS_LOG(INFO) << "PS mode is disabled.";
@@ -312,5 +313,9 @@ core::ClusterConfig &PSContext::cluster_config() {
  }
  return *cluster_config_;
}

void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }

uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
} // namespace ps
} // namespace mindspore
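The manage port can therefore be supplied in two ways: through the MS_SCHED_MANAGE_PORT environment variable read in SetPSEnable, or through the setter exposed to Python. A minimal sketch of the setter path (the port value is illustrative):

// Hypothetical use of the new accessors; PSContext::instance() is the singleton used elsewhere in this commit.
PSContext::instance()->set_scheduler_manage_port(11202);  // illustrative port, e.g. forwarded from Python
uint16_t manage_port = PSContext::instance()->scheduler_manage_port();
// The scheduler would listen on manage_port for the http requests that trigger scale out/in.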
@@ -145,6 +145,9 @@ class PSContext {

  core::ClusterConfig &cluster_config();

  void set_scheduler_manage_port(uint16_t sched_port);
  uint16_t scheduler_manage_port() const;

 private:
  PSContext()
      : ps_enabled_(false),
@@ -172,7 +175,8 @@ class PSContext {
        client_batch_size_(32),
        client_learning_rate_(0.001),
        secure_aggregation_(false),
        cluster_config_(nullptr) {}
        cluster_config_(nullptr),
        scheduler_manage_port_(0) {}
  bool ps_enabled_;
  bool is_worker_;
  bool is_pserver_;
@@ -231,6 +235,9 @@ class PSContext {

  // The cluster config read through environment variables; the value does not change.
  std::unique_ptr<core::ClusterConfig> cluster_config_;

  // The port used by the scheduler to receive http requests for scale out or scale in.
  uint16_t scheduler_manage_port_;
};
} // namespace ps
} // namespace mindspore
@@ -52,7 +52,8 @@ _set_ps_context_func_map = {
    "client_epoch_num": ps_context().set_client_epoch_num,
    "client_batch_size": ps_context().set_client_batch_size,
    "client_learning_rate": ps_context().set_client_learning_rate,
    "enable_ps_ssl": ps_context().set_enable_ssl
    "enable_ps_ssl": ps_context().set_enable_ssl,
    "scheduler_manage_port": ps_context().set_scheduler_manage_port
}

_get_ps_context_func_map = {