add node recovery
This commit is contained in:
parent
26c7d274c9
commit
bfaab72934
|
@ -362,7 +362,9 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
"Set federated learning client learning rate.")
|
||||
.def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
|
||||
"Set scheduler manage port used to scale out/in.")
|
||||
.def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");
|
||||
.def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.")
|
||||
.def("set_config_file_path", &PSContext::set_config_file_path,
|
||||
"Set configuration files required by the communication layer.");
|
||||
|
||||
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
|
||||
.def(py::init())
|
||||
|
|
|
@ -34,6 +34,9 @@ if(NOT ENABLE_CPU OR WIN32)
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "core/leader_scaler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/follower_scaler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/file_configuration.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/recovery_base.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node_recovery.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_recovery.cc")
|
||||
endif()
|
||||
|
||||
if(NOT ENABLE_D)
|
||||
|
|
|
@ -95,6 +95,19 @@ constexpr uint32_t kCheckRegisteredRetryCount = 30;
|
|||
// The timeout interval for judging whether all nodes are successfully registered.
|
||||
constexpr uint32_t kCheckRegisteredIntervalInMs = 1000;
|
||||
|
||||
// The type of persistent storage, currently only supports file storage.
|
||||
constexpr char kStoreType[] = "storge_type";
|
||||
// The file used to store metadata.
|
||||
constexpr char kStoreFilePath[] = "storge_file_path";
|
||||
// 1 indicates that the persistent storage type is file.
|
||||
constexpr char kFileStorage[] = "1";
|
||||
// The recovery key of json_config.
|
||||
constexpr char kKeyRecovery[] = "recovery";
|
||||
constexpr char kRecoveryWorkerNum[] = "worker_num";
|
||||
constexpr char kRecoveryServerNum[] = "server_num";
|
||||
constexpr char kRecoverySchedulerIp[] = "scheduler_ip";
|
||||
constexpr char kRecoverySchedulerPort[] = "scheduler_port";
|
||||
|
||||
using DataPtr = std::shared_ptr<unsigned char[]>;
|
||||
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
|
||||
using Key = uint64_t;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "ps/core/abstract_node.h"
|
||||
#include "ps/core/node_recovery.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -54,8 +55,7 @@ void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const
|
|||
// scheduler is alive
|
||||
UpdateSchedulerTime();
|
||||
|
||||
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_
|
||||
<< " registered scheduler success!";
|
||||
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << " registered scheduler success!";
|
||||
}
|
||||
|
||||
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
|
||||
|
@ -143,8 +143,9 @@ void AbstractNode::BroadcastEvent(const uint32_t &event) {
|
|||
|
||||
if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF, event_message.SerializeAsString().data(),
|
||||
event_message.ByteSizeLong())) {
|
||||
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " send event timeout!";
|
||||
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " send event timeout!";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
|
@ -377,6 +378,18 @@ int32_t AbstractNode::worker_num() const { return worker_num_; }
|
|||
|
||||
int32_t AbstractNode::server_num() const { return server_num_; }
|
||||
|
||||
void AbstractNode::set_worker_num(const int32_t &worker_num) { worker_num_ = worker_num; }
|
||||
|
||||
void AbstractNode::set_server_num(const int32_t &server_num) { server_num_ = server_num; }
|
||||
|
||||
std::string AbstractNode::scheduler_ip() const { return scheduler_ip_; }
|
||||
|
||||
void AbstractNode::set_scheduler_ip(const std::string &scheduler_ip) { scheduler_ip_ = scheduler_ip; }
|
||||
|
||||
uint16_t AbstractNode::scheduler_port() const { return scheduler_port_; }
|
||||
|
||||
void AbstractNode::set_scheduler_port(const uint16_t &scheduler_port) { scheduler_port_ = scheduler_port; }
|
||||
|
||||
ClusterState AbstractNode::cluster_state() const { return current_cluster_state_; }
|
||||
|
||||
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
|
||||
|
@ -450,10 +463,13 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
|
|||
wait_start_cond_.notify_all();
|
||||
}
|
||||
|
||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
|
||||
if (node_recovery_ == nullptr) {
|
||||
MS_LOG(INFO) << "The recovery is disable.";
|
||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -489,6 +505,7 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
|
|||
MS_EXCEPTION_IF_NULL(data);
|
||||
if (is_current_node_scale_in_) {
|
||||
MS_LOG(WARNING) << "Trigger cluster scale in done event.";
|
||||
node_info_.rank_id_ = UINT32_MAX;
|
||||
OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
|
||||
return;
|
||||
}
|
||||
|
@ -639,9 +656,7 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
|
|||
}
|
||||
|
||||
bool AbstractNode::InitClientToScheduler() {
|
||||
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
|
||||
uint16_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
|
||||
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
|
||||
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_);
|
||||
client_to_scheduler_->SetMessageCallback(
|
||||
[&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
|
||||
try {
|
||||
|
@ -851,6 +866,7 @@ void AbstractNode::InitNodeInfo(const NodeRole &role) {
|
|||
node_info_.node_role_ = role;
|
||||
node_info_.ip_ = server_->BoundIp();
|
||||
node_info_.port_ = server_->BoundPort();
|
||||
|
||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " is generate uuid is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
|
||||
<< ", the port:" << server_->BoundPort();
|
||||
|
@ -859,6 +875,20 @@ void AbstractNode::InitNodeInfo(const NodeRole &role) {
|
|||
// Seeds worker_num_, server_num_, scheduler_ip_ and scheduler_port_ from the cluster configuration
// exposed by PSContext, then logs the resulting values.
void AbstractNode::InitNodeNum() {
  const auto &cluster_config = PSContext::instance()->cluster_config();

  worker_num_ = cluster_config.initial_worker_num;
  server_num_ = cluster_config.initial_server_num;
  scheduler_ip_ = cluster_config.scheduler_host;
  scheduler_port_ = cluster_config.scheduler_port;

  MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
               << ", the scheduler ip:" << scheduler_ip_ << ", the scheduler port:" << scheduler_port_;
}
|
||||
|
||||
bool AbstractNode::Recover() {
|
||||
if (config_->Exists(kKeyRecovery)) {
|
||||
MS_LOG(INFO) << "The node is support recovery.";
|
||||
node_recovery_ = std::make_unique<NodeRecovery>(this);
|
||||
node_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
|
||||
return node_recovery_->Recover();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void AbstractNode::OnEventCallback(const ClusterEvent &event) {
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "ps/core/follower_scaler.h"
|
||||
#include "utils/ms_exception.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/core/recovery_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -44,7 +45,11 @@ class AbstractNode : public Node {
|
|||
server_thread_(nullptr),
|
||||
worker_num_(-1),
|
||||
server_num_(-1),
|
||||
is_current_node_scale_in_(false) {}
|
||||
is_current_node_scale_in_(false),
|
||||
follower_scaler_(nullptr),
|
||||
node_recovery_(nullptr),
|
||||
scheduler_ip_(""),
|
||||
scheduler_port_(0) {}
|
||||
~AbstractNode() override = default;
|
||||
|
||||
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
@ -105,6 +110,15 @@ class AbstractNode : public Node {
|
|||
int32_t worker_num() const;
|
||||
int32_t server_num() const;
|
||||
|
||||
void set_worker_num(const int32_t &worker_num);
|
||||
void set_server_num(const int32_t &server_num);
|
||||
|
||||
std::string scheduler_ip() const;
|
||||
void set_scheduler_ip(const std::string &scheduler_ip);
|
||||
|
||||
uint16_t scheduler_port() const;
|
||||
void set_scheduler_port(const uint16_t &scheduler_port);
|
||||
|
||||
ClusterState cluster_state() const;
|
||||
|
||||
protected:
|
||||
|
@ -159,6 +173,8 @@ class AbstractNode : public Node {
|
|||
void InitNodeInfo(const NodeRole &role);
|
||||
// Initialize worker num, server num, scheduler ip and scheduler port by cluster config.
|
||||
void InitNodeNum();
|
||||
// Recover the node's metadata when the configuration contains a recovery entry; returns false otherwise.
|
||||
bool Recover();
|
||||
|
||||
// Trigger the callback corresponding to the event.
|
||||
void OnEventCallback(const ClusterEvent &event);
|
||||
|
@ -221,6 +237,14 @@ class AbstractNode : public Node {
|
|||
|
||||
// Scaler for worker/server node.
|
||||
std::unique_ptr<FollowerScaler> follower_scaler_;
|
||||
|
||||
// Recovery for worker/server node.
|
||||
std::unique_ptr<RecoveryBase> node_recovery_;
|
||||
|
||||
// The ip of scheduler.
|
||||
std::string scheduler_ip_;
|
||||
// The port of scheduler.
|
||||
uint16_t scheduler_port_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -47,6 +47,9 @@ class Configuration {
|
|||
|
||||
// Put configuration data to database or config file.
|
||||
virtual void Put(const std::string &key, const std::string &defaultvalue) = 0;
|
||||
|
||||
// Determine whether the configuration item exists.
|
||||
virtual bool Exists(const std::string &key) = 0;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -39,7 +39,7 @@ bool FileConfiguration::Initialize() {
|
|||
|
||||
std::string FileConfiguration::Get(const std::string &key, const std::string &defaultvalue) const {
|
||||
if (!js.contains(key)) {
|
||||
MS_LOG(WARNING) << "The key:" << key << " is not exit.";
|
||||
MS_LOG(WARNING) << "The key:" << key << " is not exist.";
|
||||
return defaultvalue;
|
||||
}
|
||||
std::string res = js.at(key);
|
||||
|
@ -54,6 +54,12 @@ void FileConfiguration::Put(const std::string &key, const std::string &value) {
|
|||
output_file.close();
|
||||
}
|
||||
|
||||
bool FileConfiguration::Exists(const std::string &key) {
|
||||
if (!js.contains(key)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,6 +56,8 @@ class FileConfiguration : public Configuration {
|
|||
|
||||
void Put(const std::string &key, const std::string &value) override;
|
||||
|
||||
bool Exists(const std::string &key) override;
|
||||
|
||||
private:
|
||||
// The path of the configuration file.
|
||||
std::string file_path_;
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include "ps/core/node_info.h"
|
||||
#include "ps/core/communicator/tcp_client.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
#include "ps/core/file_configuration.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -67,7 +68,7 @@ class Node {
|
|||
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
|
||||
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
const void *, size_t size, const uint32_t &timeout = kTimeoutInSeconds);
|
||||
|
||||
protected:
|
||||
bool WaitForStart(const uint32_t &timeout);
|
||||
|
@ -105,6 +106,15 @@ class Node {
|
|||
// Worker and server receive the node state and cluster state from the scheduler.
|
||||
NodeState current_node_state_;
|
||||
ClusterState current_cluster_state_;
|
||||
|
||||
// Configuration file. The format is as follows:
|
||||
//{
|
||||
// "recovery": {
|
||||
// "storge_type": "1",
|
||||
// "storge_file_path": "/home/cds/config.json"
|
||||
// }
|
||||
// }
|
||||
std::unique_ptr<Configuration> config_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -36,7 +36,7 @@ enum class ClusterEvent {
|
|||
};
|
||||
|
||||
struct NodeInfo {
|
||||
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}
|
||||
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0), is_alive(false) {}
|
||||
// ip
|
||||
std::string ip_;
|
||||
// the port of this node
|
||||
|
@ -47,6 +47,9 @@ struct NodeInfo {
|
|||
NodeRole node_role_;
|
||||
// the current Node rank id, the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
|
||||
uint32_t rank_id_;
|
||||
|
||||
// After the node registration is successful, it is alive. If the node's heartbeat times out, then it is not alive.
|
||||
bool is_alive;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -42,7 +42,21 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
|||
const std::string &ip = register_message.ip();
|
||||
uint32_t port = register_message.port();
|
||||
|
||||
rank_id = ++next_server_rank_id_;
|
||||
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
|
||||
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::SERVER;
|
||||
if (res) {
|
||||
MS_LOG(INFO) << "The server node id:" << item.first << " rank id:" << rank_id << " is not alive.";
|
||||
rank_id = item.second.rank_id_;
|
||||
}
|
||||
return res;
|
||||
});
|
||||
|
||||
if (rank_it == registered_nodes_info_.end()) {
|
||||
rank_id = ++next_server_rank_id_;
|
||||
} else {
|
||||
registered_nodes_info_.erase((*rank_it).first);
|
||||
}
|
||||
|
||||
if (rank_id >= meta_data_->server_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num;
|
||||
rank_id = UINT_MAX;
|
||||
|
@ -54,13 +68,29 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
|||
node_info.rank_id_ = rank_id;
|
||||
node_info.ip_ = ip;
|
||||
node_info.port_ = port;
|
||||
node_info.is_alive = true;
|
||||
registered_nodes_info_[node_id] = node_info;
|
||||
MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port
|
||||
<< " assign rank id:" << rank_id;
|
||||
} else if (register_message.role() == NodeRole::WORKER) {
|
||||
const std::string &ip = register_message.ip();
|
||||
uint32_t port = register_message.port();
|
||||
rank_id = ++next_worker_rank_id_;
|
||||
|
||||
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
|
||||
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
|
||||
if (res) {
|
||||
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
|
||||
rank_id = item.second.rank_id_;
|
||||
}
|
||||
return res;
|
||||
});
|
||||
|
||||
if (rank_it == registered_nodes_info_.end()) {
|
||||
rank_id = ++next_worker_rank_id_;
|
||||
} else {
|
||||
registered_nodes_info_.erase((*rank_it).first);
|
||||
}
|
||||
|
||||
if (rank_id >= meta_data_->worker_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num;
|
||||
rank_id = UINT_MAX;
|
||||
|
@ -72,6 +102,7 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
|||
node_info.rank_id_ = rank_id;
|
||||
node_info.ip_ = ip;
|
||||
node_info.port_ = port;
|
||||
node_info.is_alive = true;
|
||||
registered_nodes_info_[node_id] = node_info;
|
||||
MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id;
|
||||
}
|
||||
|
@ -85,12 +116,6 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
|
|||
heartbeats_[node_id] = current_time;
|
||||
}
|
||||
|
||||
void NodeManager::UpdateNodeScaleInState(const std::string &node_id) { heartbeats_scale_in_nodes_.insert(node_id); }
|
||||
|
||||
bool NodeManager::CheckNodesScaluOutState() { return SizeToInt(heartbeats_scale_out_nodes_.size()) == total_node_num_; }
|
||||
|
||||
bool NodeManager::CheckNodesScaleInState() { return SizeToInt(heartbeats_scale_in_nodes_.size()) == total_node_num_; }
|
||||
|
||||
std::vector<ServersMeta> NodeManager::FetchServersMeta() {
|
||||
std::vector<ServersMeta> servers_meta_list;
|
||||
for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
|
||||
|
@ -115,12 +140,15 @@ void NodeManager::UpdateCluster() {
|
|||
if (registered_nodes_info_.count(it->first)) {
|
||||
MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
|
||||
timeout_nodes_info_[it->first] = registered_nodes_info_[it->first];
|
||||
registered_nodes_info_[it->first].is_alive = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!timeout_nodes_info_.empty()) {
|
||||
UpdateClusterState(ClusterState::NODE_TIMEOUT);
|
||||
for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) {
|
||||
heartbeats_.erase(it->first);
|
||||
finish_nodes_id_.insert(it->first);
|
||||
}
|
||||
}
|
||||
|
@ -149,7 +177,11 @@ void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_do
|
|||
|
||||
void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
|
||||
|
||||
bool NodeManager::IsAllNodesRegistered() { return SizeToInt(registered_nodes_info_.size()) == total_node_num_; }
|
||||
bool NodeManager::IsAllNodesRegistered() {
|
||||
int32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
|
||||
[](auto item) { return item.second.is_alive == true; });
|
||||
return num == total_node_num_;
|
||||
}
|
||||
|
||||
bool NodeManager::IsAllNodesFinished() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
|
||||
|
||||
|
@ -175,7 +207,7 @@ void NodeManager::UpdateNodeState(const NodeState &state) {
|
|||
void NodeManager::UpdateClusterState(const ClusterState &state) {
|
||||
std::lock_guard<std::mutex> lk(cluster_mutex_);
|
||||
MS_LOG(INFO) << "[state]: Scheduler change state from:" << CommUtil::ClusterStateToString(cluster_state_) << " to "
|
||||
<< state;
|
||||
<< CommUtil::ClusterStateToString(state);
|
||||
cluster_state_ = state;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,6 +30,8 @@
|
|||
#include <vector>
|
||||
#include <condition_variable>
|
||||
#include <unordered_set>
|
||||
#include <deque>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ps/core/node.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -57,9 +59,6 @@ class NodeManager {
|
|||
uint32_t NextRankId(const RegisterMessage ®ister_message);
|
||||
|
||||
void UpdateHeartbeat(const std::string &node_id);
|
||||
bool CheckNodesScaluOutState();
|
||||
void UpdateNodeScaleInState(const std::string &node_id);
|
||||
bool CheckNodesScaleInState();
|
||||
|
||||
std::vector<ServersMeta> FetchServersMeta();
|
||||
void UpdateCluster();
|
||||
|
@ -126,9 +125,6 @@ class NodeManager {
|
|||
std::mutex heartbeat_mutex_;
|
||||
|
||||
std::unordered_map<std::string, timeval> heartbeats_;
|
||||
std::unordered_set<std::string> heartbeats_finish_nodes_;
|
||||
std::unordered_set<std::string> heartbeats_scale_out_nodes_;
|
||||
std::unordered_set<std::string> heartbeats_scale_in_nodes_;
|
||||
// timeout nodes
|
||||
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
|
||||
std::unordered_set<std::string> finish_nodes_id_;
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/node_recovery.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
bool NodeRecovery::Recover() {
|
||||
if (recovery_storage_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
// 1. recover worker num
|
||||
if (recovery_storage_->Exists(kRecoveryWorkerNum)) {
|
||||
int32_t worker_num = std::strtol(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, 10);
|
||||
node_->set_worker_num(worker_num);
|
||||
} else {
|
||||
node_->set_worker_num(PSContext::instance()->cluster_config().initial_worker_num);
|
||||
}
|
||||
|
||||
// 2. recover server num
|
||||
if (recovery_storage_->Exists(kRecoveryServerNum)) {
|
||||
int32_t server_num = std::strtol(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, 10);
|
||||
node_->set_server_num(server_num);
|
||||
} else {
|
||||
node_->set_server_num(PSContext::instance()->cluster_config().initial_server_num);
|
||||
}
|
||||
|
||||
// 3. recover scheduler ip
|
||||
if (recovery_storage_->Exists(kRecoverySchedulerIp)) {
|
||||
std::string scheduler_ip = recovery_storage_->Get(kRecoverySchedulerIp, "");
|
||||
node_->set_scheduler_ip(scheduler_ip);
|
||||
} else {
|
||||
node_->set_scheduler_ip(PSContext::instance()->cluster_config().scheduler_host);
|
||||
}
|
||||
|
||||
// 4. recover scheduler port
|
||||
if (recovery_storage_->Exists(kRecoverySchedulerPort)) {
|
||||
uint16_t scheduler_port = std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, 10);
|
||||
node_->set_scheduler_port(scheduler_port);
|
||||
} else {
|
||||
node_->set_scheduler_port(PSContext::instance()->cluster_config().scheduler_port);
|
||||
}
|
||||
MS_LOG(INFO) << "The worker num:" << node_->worker_num() << ", the server num:" << node_->server_num()
|
||||
<< ", the scheduler ip:" << node_->scheduler_ip() << ", the scheduler port:" << node_->scheduler_port();
|
||||
return true;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_RECOVERY_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_NODE_RECOVERY_H_
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ps/core/recovery_base.h"
|
||||
#include "ps/constants.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/core/file_configuration.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/abstract_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
// The class helps worker/server node to do recovery operation for the cluster.
|
||||
class NodeRecovery : public RecoveryBase {
|
||||
public:
|
||||
explicit NodeRecovery(AbstractNode *const node) : node_(node) {}
|
||||
~NodeRecovery() override = default;
|
||||
|
||||
bool Recover() override;
|
||||
|
||||
private:
|
||||
// The node_ will only be instantiated with worker/server node.
|
||||
AbstractNode *const node_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_NODE_RECOVERY_H_
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/recovery_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
void RecoveryBase::Initialize(const std::string &config_json) {
|
||||
nlohmann::json recovery_config;
|
||||
try {
|
||||
recovery_config = nlohmann::json::parse(config_json);
|
||||
} catch (nlohmann::json::exception &e) {
|
||||
MS_LOG(ERROR) << "Parse the json:" << config_json;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The scheduelr is support recovery.";
|
||||
std::string storage_file_path = "";
|
||||
std::string type = recovery_config.at(kStoreType);
|
||||
if (type == kFileStorage) {
|
||||
storage_type_ = StorageType::kFileStorage;
|
||||
|
||||
storage_file_path = recovery_config.at(kStoreFilePath);
|
||||
|
||||
if (storage_file_path == "") {
|
||||
MS_LOG(EXCEPTION) << "If the scheduler support recovery, and if the persistent storage is a file, the path of "
|
||||
"the file must be configured";
|
||||
}
|
||||
|
||||
recovery_storage_ = std::make_unique<FileConfiguration>(storage_file_path);
|
||||
|
||||
if (!recovery_storage_->Initialize()) {
|
||||
MS_LOG(INFO) << "The storage file path " << storage_file_path << " is empty.";
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The storage type is:" << storage_type_ << ", the storage file path is:" << storage_file_path;
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_RECOVERY_BASE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_RECOVERY_BASE_H_
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ps/constants.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/core/file_configuration.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
enum class StorageType : int { kFileStorage = 1 };
|
||||
// RecoveryBase is used to parse configuration items related to recovery.
|
||||
// It is the base class of SchedulerRecovery and NodeRecovery.
|
||||
class RecoveryBase {
|
||||
public:
|
||||
RecoveryBase() : recovery_storage_(nullptr), storage_type_(StorageType::kFileStorage) {}
|
||||
|
||||
virtual ~RecoveryBase() = default;
|
||||
|
||||
// Initialize the recovery configuration item and get the storage type of recovery.
|
||||
void Initialize(const std::string &json_config);
|
||||
|
||||
// The node needs to recover metadata information when it starts.
|
||||
virtual bool Recover() = 0;
|
||||
|
||||
protected:
|
||||
// Persistent storage used to save metadata.
|
||||
std::unique_ptr<Configuration> recovery_storage_;
|
||||
|
||||
// Storage type for recovery,Currently only supports storage of file types
|
||||
StorageType storage_type_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_RECOVERY_BASE_H_
|
|
@ -133,17 +133,21 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
|
|||
|
||||
if (node_manager_.IsAllNodesRegistered()) {
|
||||
is_ready_ = true;
|
||||
MS_LOG(INFO) << "Scheduler send meta data to worker/server.";
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
if (node_manager_.GetClusterState() == ClusterState::ClUSTER_STARTING) {
|
||||
node_infos.clear();
|
||||
node_infos = node_manager_.registered_nodes_info();
|
||||
MS_LOG(INFO) << "All nodes is registered, scheduler send meta data to worker/server.";
|
||||
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
|
||||
auto nodes = node_manager_.nodes_info();
|
||||
for (const auto &id : scale_in_node_ids_) {
|
||||
MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id;
|
||||
auto scale_in_client = GetOrCreateClient(nodes[id]);
|
||||
SendMetadata(scale_in_client, nodes[id].rank_id_);
|
||||
}
|
||||
}
|
||||
node_manager_.UpdateNodesInfo();
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
for (const auto &kvs : node_infos) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
SendMetadata(client, kvs.second.rank_id_);
|
||||
}
|
||||
node_manager_.UpdateNodesInfo();
|
||||
wait_start_cond_.notify_all();
|
||||
}
|
||||
}
|
||||
|
@ -341,8 +345,9 @@ void SchedulerNode::SendEvent(const std::shared_ptr<TcpClient> &client, const ui
|
|||
|
||||
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
|
||||
event_resp_message.ByteSizeLong())) {
|
||||
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " send event resp timeout!";
|
||||
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " send event resp timeout!";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
|
@ -463,6 +468,7 @@ void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
|
|||
node_manager_.set_worker_num(total_worker_num);
|
||||
node_manager_.set_server_num(total_server_num);
|
||||
node_manager_.set_total_node_num(total_worker_num + total_server_num);
|
||||
|
||||
node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
node_manager_.ResetMetadata();
|
||||
|
@ -500,20 +506,20 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
|
|||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> scale_in_node_ids;
|
||||
status = resp->ParseNodeIdsFromKey(kNodesIds, &scale_in_node_ids);
|
||||
scale_in_node_ids_.clear();
|
||||
status = resp->ParseNodeIdsFromKey(kNodesIds, &scale_in_node_ids_);
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
status = CheckIfNodeIdLegal(scale_in_node_ids);
|
||||
status = CheckIfNodeIdLegal(scale_in_node_ids_);
|
||||
if (status != RequestProcessResultCode::kSuccess) {
|
||||
resp->ErrorResponse(HTTP_BADREQUEST, status);
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids;
|
||||
MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids_;
|
||||
|
||||
std::unordered_map<std::string, bool> scale_in_nodes;
|
||||
|
||||
|
@ -521,7 +527,7 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
|
|||
int32_t scale_server_num = 0;
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
node_manager_.ResetMetadata();
|
||||
for (auto const &val : scale_in_node_ids) {
|
||||
for (auto const &val : scale_in_node_ids_) {
|
||||
if (node_infos.count(val)) {
|
||||
scale_in_nodes[val] = true;
|
||||
NodeInfo info = node_infos[val];
|
||||
|
|
|
@ -39,6 +39,7 @@
|
|||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/communicator/http_server.h"
|
||||
#include "ps/core/leader_scaler.h"
|
||||
#include "ps/core/recovery_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -53,7 +54,8 @@ class SchedulerNode : public Node {
|
|||
http_server_(nullptr),
|
||||
client_thread_(nullptr),
|
||||
is_client_started_(false),
|
||||
leader_scaler_(nullptr) {}
|
||||
leader_scaler_(nullptr),
|
||||
scheduler_recovery_(nullptr) {}
|
||||
~SchedulerNode() override;
|
||||
|
||||
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
|
@ -65,6 +67,7 @@ class SchedulerNode : public Node {
|
|||
|
||||
private:
|
||||
void Initialize();
|
||||
|
||||
void InitCommandHandler();
|
||||
void CreateTcpServer();
|
||||
void StartUpdateClusterStateTimer();
|
||||
|
@ -144,6 +147,12 @@ class SchedulerNode : public Node {
|
|||
std::unique_ptr<LeaderScaler> leader_scaler_;
|
||||
|
||||
std::unordered_map<std::string, OnRequestReceive> callbacks_;
|
||||
|
||||
// Used to persist and obtain metadata information for scheduler.
|
||||
std::unique_ptr<RecoveryBase> scheduler_recovery_;
|
||||
|
||||
// The node id of scale in nodes.
|
||||
std::vector<std::string> scale_in_node_ids_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/scheduler_recovery.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
// Forwards the key/value pair to the storage handle held by this recovery object.
void SchedulerRecovery::Persist(const std::string &key, const std::string &value) {
  auto &storage = *recovery_storage_;
  storage.Put(key, value);
}
|
||||
|
||||
// Reads the value stored under `key`. An empty string is passed as Get's second argument
// (presumably the fallback for a missing key -- confirm against the storage's Get contract).
std::string SchedulerRecovery::GetMetadata(const std::string &key) {
  return (*recovery_storage_).Get(key, "");
}
|
||||
|
||||
// Scheduler recovery hook: performs no work at present and unconditionally reports success.
bool SchedulerRecovery::Recover() {
  const bool recovered = true;
  return recovered;
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_SCHEDULER_RECOVERY_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_SCHEDULER_RECOVERY_H_
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ps/constants.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/core/file_configuration.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/recovery_base.h"
|
||||
#include "ps/core/scheduler_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
// The class helps scheduler node to do recovery operation for the cluster.
|
||||
class SchedulerRecovery : public RecoveryBase {
|
||||
public:
|
||||
SchedulerRecovery() = default;
|
||||
~SchedulerRecovery() override = default;
|
||||
|
||||
// Persist metadata to storage.
|
||||
void Persist(const std::string &key, const std::string &value);
|
||||
|
||||
bool Recover() override;
|
||||
|
||||
// Get metadata from storage.
|
||||
std::string GetMetadata(const std::string &key);
|
||||
|
||||
private:
|
||||
// The node_ will only be instantiated with worker/server node.
|
||||
SchedulerNode *const node_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_SCHEDULER_RECOVERY_H_
|
|
@ -84,6 +84,15 @@ void ServerNode::CreateTcpServer() {
|
|||
}
|
||||
|
||||
void ServerNode::Initialize() {
|
||||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(INFO) << "The config file is empty, then init node by context.";
|
||||
InitNodeNum();
|
||||
} else {
|
||||
if (!Recover()) {
|
||||
MS_LOG(WARNING) << "Recover the server node is failed.";
|
||||
}
|
||||
}
|
||||
InitServerHandler();
|
||||
CreateTcpServer();
|
||||
is_already_stopped_ = false;
|
||||
|
@ -91,7 +100,6 @@ void ServerNode::Initialize() {
|
|||
|
||||
MS_LOG(INFO) << "[Server start]: 2. Server node create tcp server successful!";
|
||||
|
||||
InitNodeNum();
|
||||
InitCommandHandler();
|
||||
if (!InitClientToScheduler()) {
|
||||
MS_LOG(EXCEPTION) << "Server node connect to scheduler timedout!";
|
||||
|
|
|
@ -78,5 +78,4 @@ class ServerNode : public AbstractNode {
|
|||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_
|
||||
|
|
|
@ -42,13 +42,21 @@ bool WorkerNode::Start(const uint32_t &timeout) {
|
|||
|
||||
void WorkerNode::Initialize() {
|
||||
is_already_stopped_ = false;
|
||||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(INFO) << "The config file is empty, then init node by context.";
|
||||
InitNodeNum();
|
||||
} else {
|
||||
if (!Recover()) {
|
||||
MS_LOG(WARNING) << "Recover the worker node is failed.";
|
||||
}
|
||||
}
|
||||
InitServerHandler();
|
||||
CreateTcpServer();
|
||||
InitNodeInfo(NodeRole::WORKER);
|
||||
|
||||
MS_LOG(INFO) << "[Worker start]: 2. Worker node create tcp server successful!";
|
||||
|
||||
InitNodeNum();
|
||||
InitCommandHandler();
|
||||
if (!InitClientToScheduler()) {
|
||||
MS_LOG(EXCEPTION) << "Worker node connect to scheduler timeout!";
|
||||
|
|
|
@ -317,5 +317,9 @@ core::ClusterConfig &PSContext::cluster_config() {
|
|||
void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
|
||||
|
||||
uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
|
||||
|
||||
void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
|
||||
|
||||
std::string PSContext::config_file_path() const { return config_file_path_; }
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -148,6 +148,9 @@ class PSContext {
|
|||
void set_scheduler_manage_port(uint16_t sched_port);
|
||||
uint16_t scheduler_manage_port() const;
|
||||
|
||||
void set_config_file_path(const std::string &path);
|
||||
std::string config_file_path() const;
|
||||
|
||||
private:
|
||||
PSContext()
|
||||
: ps_enabled_(false),
|
||||
|
@ -176,7 +179,8 @@ class PSContext {
|
|||
client_learning_rate_(0.001),
|
||||
secure_aggregation_(false),
|
||||
cluster_config_(nullptr),
|
||||
scheduler_manage_port_(0) {}
|
||||
scheduler_manage_port_(0),
|
||||
config_file_path_("") {}
|
||||
bool ps_enabled_;
|
||||
bool is_worker_;
|
||||
bool is_pserver_;
|
||||
|
@ -238,6 +242,9 @@ class PSContext {
|
|||
|
||||
// The port used by scheduler to receive http requests for scale out or scale in.
|
||||
uint16_t scheduler_manage_port_;
|
||||
|
||||
// The path of the configuration file, used to configure the certification path and persistent storage type, etc.
|
||||
std::string config_file_path_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -53,7 +53,8 @@ _set_ps_context_func_map = {
|
|||
"client_batch_size": ps_context().set_client_batch_size,
|
||||
"client_learning_rate": ps_context().set_client_learning_rate,
|
||||
"enable_ps_ssl": ps_context().set_enable_ssl,
|
||||
"scheduler_manage_port": ps_context().set_scheduler_manage_port
|
||||
"scheduler_manage_port": ps_context().set_scheduler_manage_port,
|
||||
"config_file_path": ps_context().set_config_file_path
|
||||
}
|
||||
|
||||
_get_ps_context_func_map = {
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
{
  "recovery": {
    "storge_type": 1,
    "storge_file_path": "recovery.json"
  }
}
|
|
@ -163,8 +163,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
||||
"../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/worker.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/parameter_server.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")
|
||||
|
|
Loading…
Reference in New Issue