forked from mindspore-Ecosystem/mindspore
!31466 code sync and fix bug
Merge pull request !31466 from tan-wei-cheng-3260/develop-twc-master
This commit is contained in:
commit
a05c9816c3
|
@ -23,10 +23,10 @@ namespace mindspore {
|
|||
namespace fl {
|
||||
namespace server {
|
||||
namespace {
|
||||
const char *kCollectivePhaseRing = "ring";
|
||||
const char *kCollectivePhaseGather = "gather";
|
||||
const char *kCollectivePhaseReduce = "reduce";
|
||||
const char *kCollectivePhaseBroadcast = "broadcast";
|
||||
const char kCollectivePhaseRing[] = "ring";
|
||||
const char kCollectivePhaseGather[] = "gather";
|
||||
const char kCollectivePhaseReduce[] = "reduce";
|
||||
const char kCollectivePhaseBroadcast[] = "broadcast";
|
||||
} // namespace
|
||||
|
||||
void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
|
||||
|
@ -292,7 +292,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const std::string &data_name, c
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff, size_t send_count) {
|
||||
bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(node_, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
|
||||
|
@ -411,7 +411,7 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool CollectiveOpsImpl::AllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count) {
|
||||
bool CollectiveOpsImpl::AllReduce(const std::string &data_name, void *sendbuff, void *recvbuff, size_t count) {
|
||||
// The collective communication API does not support calling Send and Recv concurrently with multiple threads;
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
|
||||
|
@ -441,7 +441,7 @@ bool CollectiveOpsImpl::AllReduce(const std::string &data_name, const void *send
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *const recvbuff, size_t send_count,
|
||||
bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t send_count,
|
||||
const std::shared_ptr<ps::core::AbstractNode> &node) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(node, false);
|
||||
|
@ -476,7 +476,7 @@ bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *const recvbuff, si
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *const recvbuff, size_t count, uint32_t root,
|
||||
bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
|
||||
const std::shared_ptr<ps::core::AbstractNode> &node,
|
||||
const CommunicationGroupInfo &group_info) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
|
@ -515,11 +515,11 @@ bool CollectiveOpsImpl::ReInitForScaling() {
|
|||
return true;
|
||||
}
|
||||
|
||||
template bool CollectiveOpsImpl::AllReduce<float>(const std::string &data_name, const void *sendbuff, void *recvbuff,
|
||||
template bool CollectiveOpsImpl::AllReduce<float>(const std::string &data_name, void *sendbuff, void *recvbuff,
|
||||
size_t count);
|
||||
template bool CollectiveOpsImpl::AllReduce<size_t>(const std::string &data_name, const void *sendbuff, void *recvbuff,
|
||||
template bool CollectiveOpsImpl::AllReduce<size_t>(const std::string &data_name, void *sendbuff, void *recvbuff,
|
||||
size_t count);
|
||||
template bool CollectiveOpsImpl::AllReduce<int>(const std::string &data_name, const void *sendbuff, void *recvbuff,
|
||||
template bool CollectiveOpsImpl::AllReduce<int>(const std::string &data_name, void *sendbuff, void *recvbuff,
|
||||
size_t count);
|
||||
|
||||
template bool CollectiveOpsImpl::AllGather<float>(const void *sendbuff, void *recvbuff, size_t send_count,
|
||||
|
|
|
@ -64,7 +64,7 @@ class CollectiveOpsImpl {
|
|||
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);
|
||||
|
||||
template <typename T>
|
||||
bool AllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);
|
||||
bool AllReduce(const std::string &data_name, void *sendbuff, void *recvbuff, size_t count);
|
||||
|
||||
template <typename T>
|
||||
bool AllGather(const void *sendbuff, void *recvbuff, size_t send_count,
|
||||
|
|
|
@ -242,6 +242,7 @@ constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
|
|||
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
|
||||
constexpr auto kCtxCipherPrimer = "cipher_primer";
|
||||
constexpr auto kCurrentIteration = "current_iteration";
|
||||
constexpr auto kInstanceState = "instance_state";
|
||||
const char PYTHON_MOD_SERIALIZE_MODULE[] = "mindspore.train.serialization";
|
||||
const char PYTHON_MOD_SAFE_WEIGHT[] = "_save_weight";
|
||||
|
||||
|
@ -295,6 +296,31 @@ enum class ResultCode {
|
|||
// If there's error happened, return kFail.
|
||||
kFail
|
||||
};
|
||||
|
||||
inline std::string GetInstanceStateStr(const InstanceState &instance_state) {
|
||||
switch (instance_state) {
|
||||
case InstanceState::kRunning:
|
||||
return "kRunning";
|
||||
case InstanceState::kFinish:
|
||||
return "kFinish";
|
||||
case InstanceState::kDisable:
|
||||
return "kDisable";
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "InstanceState " << instance_state << " is not supported.";
|
||||
}
|
||||
}
|
||||
|
||||
inline InstanceState GetInstanceState(const std::string &instance_state) {
|
||||
if (instance_state == "kRunning") {
|
||||
return InstanceState::kRunning;
|
||||
} else if (instance_state == "kFinish") {
|
||||
return InstanceState::kFinish;
|
||||
} else if (instance_state == "kDisable") {
|
||||
return InstanceState::kDisable;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "InstanceState " << instance_state << " is not supported.";
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -165,7 +165,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name, const
|
|||
std::string reason = "Sending querying whether count reaches " + name +
|
||||
" threshold message to leader server failed" + (fl_id.empty() ? "" : " for fl id " + fl_id);
|
||||
MS_LOG(WARNING) << reason;
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(query_cnt_enough_rsp_msg, false);
|
||||
|
|
|
@ -151,7 +151,7 @@ void Iteration::SetIterationRunning() {
|
|||
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::UserDefineEvent::kIterationRunning));
|
||||
if (server_recovery_ != nullptr) {
|
||||
// Save data to the persistent storage in case the recovery happens at the beginning.
|
||||
if (!server_recovery_->Save(iteration_num_)) {
|
||||
if (!server_recovery_->Save(iteration_num_, instance_state_)) {
|
||||
MS_LOG(WARNING) << "Save recovery data failed.";
|
||||
}
|
||||
}
|
||||
|
@ -207,7 +207,8 @@ bool Iteration::ReInitForUpdatingHyperParams(const std::vector<RoundConfig> &upd
|
|||
for (const auto &round : rounds_) {
|
||||
if (updated_round.name == round->name()) {
|
||||
MS_LOG(INFO) << "Reinitialize for round " << round->name();
|
||||
if (!round->ReInitForUpdatingHyperParams(updated_round.threshold_count, updated_round.time_window)) {
|
||||
if (!round->ReInitForUpdatingHyperParams(updated_round.threshold_count, updated_round.time_window,
|
||||
server_node_->server_num())) {
|
||||
MS_LOG(ERROR) << "Reinitializing for round " << round->name() << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -735,7 +736,7 @@ void Iteration::EndLastIter() {
|
|||
if (server_node_->rank_id() == kLeaderServerRank) {
|
||||
// Save current iteration number for recovery.
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(server_recovery_);
|
||||
if (!server_recovery_->Save(iteration_num_)) {
|
||||
if (!server_recovery_->Save(iteration_num_, instance_state_)) {
|
||||
MS_LOG(WARNING) << "Can't save current iteration number into persistent storage.";
|
||||
}
|
||||
}
|
||||
|
@ -748,7 +749,11 @@ void Iteration::EndLastIter() {
|
|||
set_loss(0.0f);
|
||||
Server::GetInstance().CancelSafeMode();
|
||||
iteration_state_cv_.notify_all();
|
||||
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
||||
if (iteration_num_ > ps::PSContext::instance()->fl_iteration_num()) {
|
||||
MS_LOG(WARNING) << "The server's training job is finished.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
bool Iteration::ForciblyMoveToNextIteration() {
|
||||
|
@ -914,6 +919,11 @@ void Iteration::UpdateRoundClientUploadLoss(const std::shared_ptr<std::vector<un
|
|||
|
||||
set_loss(loss_ + end_last_iter_rsp.upload_loss());
|
||||
}
|
||||
|
||||
void Iteration::set_instance_state(InstanceState state) {
|
||||
instance_state_ = state;
|
||||
MS_LOG(INFO) << "Server instance state is " << GetInstanceStateStr(instance_state_);
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -132,6 +132,8 @@ class Iteration {
|
|||
// The round kernels whose Launch method has not returned yet.
|
||||
std::atomic_uint32_t running_round_num_;
|
||||
|
||||
void set_instance_state(InstanceState staet);
|
||||
|
||||
private:
|
||||
Iteration()
|
||||
: running_round_num_(0),
|
||||
|
|
|
@ -28,9 +28,8 @@ bool IterationMetrics::Initialize() {
|
|||
config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(WARNING) << "Initializing for metrics failed. Config file path " << config_file_path_
|
||||
<< " may be invalid or not exist.";
|
||||
return false;
|
||||
MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_
|
||||
<< " may be invalid or not exist.";
|
||||
}
|
||||
|
||||
// Read the metrics file path. If file is not set or not exits, create one.
|
||||
|
|
|
@ -57,17 +57,7 @@ void RoundKernel::set_stop_timer_cb(const StopTimerCb &timer_stopper) { stop_tim
|
|||
|
||||
void RoundKernel::SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
|
||||
size_t len) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(WARNING) << "The message handler is nullptr.";
|
||||
return;
|
||||
}
|
||||
if (data == nullptr || len == 0) {
|
||||
std::string reason = "The output of the round " + name_ + " is empty.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
if (!message->SendResponse(reason.c_str(), reason.size())) {
|
||||
MS_LOG(WARNING) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
if (!verifyResponse(message, data, len)) {
|
||||
return;
|
||||
}
|
||||
IncreaseTotalClientNum();
|
||||
|
@ -79,17 +69,7 @@ void RoundKernel::SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler
|
|||
|
||||
void RoundKernel::SendResponseMsgInference(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
|
||||
size_t len, ps::core::RefBufferRelCallback cb) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(WARNING) << "The message handler is nullptr.";
|
||||
return;
|
||||
}
|
||||
if (data == nullptr || len == 0) {
|
||||
std::string reason = "The output of the round " + name_ + " is empty.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
if (!message->SendResponse(reason.c_str(), reason.size())) {
|
||||
MS_LOG(WARNING) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
if (!verifyResponse(message, data, len)) {
|
||||
return;
|
||||
}
|
||||
IncreaseTotalClientNum();
|
||||
|
@ -99,6 +79,23 @@ void RoundKernel::SendResponseMsgInference(const std::shared_ptr<ps::core::Messa
|
|||
}
|
||||
}
|
||||
|
||||
bool RoundKernel::verifyResponse(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
|
||||
size_t len) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(WARNING) << "The message handler is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (data == nullptr || len == 0) {
|
||||
std::string reason = "The output of the round " + name_ + " is empty.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
if (!message->SendResponse(reason.c_str(), reason.size())) {
|
||||
MS_LOG(WARNING) << "Sending response failed.";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void RoundKernel::IncreaseTotalClientNum() { total_client_num_ += 1; }
|
||||
|
||||
void RoundKernel::IncreaseAcceptClientNum() { accept_client_num_ += 1; }
|
||||
|
|
|
@ -100,6 +100,8 @@ class RoundKernel {
|
|||
|
||||
float upload_loss() const;
|
||||
|
||||
bool verifyResponse(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
|
||||
|
||||
protected:
|
||||
// Send response to client, and the data can be released after the call.
|
||||
void SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
|
||||
|
|
|
@ -41,6 +41,10 @@ const size_t LocalMetaStore::curr_iter_num() {
|
|||
return curr_iter_num_;
|
||||
}
|
||||
|
||||
void LocalMetaStore::set_curr_instance_state(InstanceState instance_state) { instance_state_ = instance_state; }
|
||||
|
||||
const InstanceState LocalMetaStore::curr_instance_state() { return instance_state_; }
|
||||
|
||||
const void LocalMetaStore::put_aggregation_feature_map(const std::string &name, const Feature &feature) {
|
||||
if (aggregation_feature_map_.count(name) > 0) {
|
||||
MS_LOG(WARNING) << "Put feature " << name << " failed.";
|
||||
|
|
|
@ -79,6 +79,10 @@ class LocalMetaStore {
|
|||
void set_curr_iter_num(size_t num);
|
||||
const size_t curr_iter_num();
|
||||
|
||||
void set_curr_instance_state(InstanceState instance_state);
|
||||
|
||||
const InstanceState curr_instance_state();
|
||||
|
||||
const void put_aggregation_feature_map(const std::string &name, const Feature &feature);
|
||||
|
||||
std::unordered_map<std::string, Feature> &aggregation_feature_map();
|
||||
|
@ -96,7 +100,7 @@ class LocalMetaStore {
|
|||
// This mutex makes sure that the operations on key_to_meta_ is threadsafe.
|
||||
std::mutex mtx_;
|
||||
size_t curr_iter_num_{0};
|
||||
|
||||
InstanceState instance_state_;
|
||||
// aggregation_feature_map_ stores model meta data with weight name and size which will be Aggregated.
|
||||
std::unordered_map<std::string, Feature> aggregation_feature_map_;
|
||||
};
|
||||
|
|
|
@ -94,7 +94,8 @@ bool Round::ReInitForScaling(uint32_t server_num) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window) {
|
||||
bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window,
|
||||
uint32_t server_num) {
|
||||
time_window_ = updated_time_window;
|
||||
threshold_count_ = updated_threshold_count;
|
||||
if (check_count_) {
|
||||
|
@ -105,7 +106,11 @@ bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t
|
|||
}
|
||||
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
|
||||
kernel_->InitKernel(threshold_count_);
|
||||
if (name_ == "reconstructSecrets") {
|
||||
kernel_->InitKernel(server_num);
|
||||
} else {
|
||||
kernel_->InitKernel(threshold_count_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -189,7 +194,11 @@ bool Round::IsServerAvailable(std::string *reason) {
|
|||
// If the server state is Disable or Finish, refuse the request.
|
||||
if (Iteration::GetInstance().instance_state() == InstanceState::kDisable ||
|
||||
Iteration::GetInstance().instance_state() == InstanceState::kFinish) {
|
||||
MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
|
||||
if (kPrintTimes % kPrintTimesThreshold == 0) {
|
||||
MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
|
||||
kPrintTimes = 0;
|
||||
}
|
||||
kPrintTimes += 1;
|
||||
*reason = ps::kJobNotAvailable;
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ class Round {
|
|||
bool ReInitForScaling(uint32_t server_num);
|
||||
|
||||
// After hyper-parameters are updated, some rounds and kernels should be reinitialized.
|
||||
bool ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window);
|
||||
bool ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window, uint32_t server_num);
|
||||
|
||||
// Bind a round kernel to this Round. This method should be called after Initialize.
|
||||
void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);
|
||||
|
|
|
@ -515,6 +515,7 @@ void Server::Recover() {
|
|||
// Set the recovery handler to Iteration.
|
||||
MS_EXCEPTION_IF_NULL(iteration_);
|
||||
iteration_->set_recovery_handler(server_recovery_);
|
||||
iteration_->set_instance_state(LocalMetaStore::GetInstance().curr_instance_state());
|
||||
}
|
||||
|
||||
void Server::ProcessBeforeScalingOut() {
|
||||
|
|
|
@ -75,13 +75,18 @@ bool ServerRecovery::Recover() {
|
|||
return false;
|
||||
}
|
||||
uint64_t current_iter = JsonGetKeyWithException<uint64_t>(server_recovery_json, kCurrentIteration);
|
||||
std::string instance_state = JsonGetKeyWithException<std::string>(server_recovery_json, kInstanceState);
|
||||
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(current_iter);
|
||||
MS_LOG(INFO) << "Recover from persistent storage: current iteration number is " << current_iter;
|
||||
LocalMetaStore::GetInstance().set_curr_instance_state(GetInstanceState(instance_state));
|
||||
|
||||
MS_LOG(INFO) << "Recover from persistent storage: current iteration number is " << current_iter
|
||||
<< ", instance state is " << instance_state;
|
||||
server_recovery_file_.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ServerRecovery::Save(uint64_t current_iter) {
|
||||
bool ServerRecovery::Save(uint64_t current_iter, InstanceState instance_state) {
|
||||
std::unique_lock<std::mutex> lock(server_recovery_file_mtx_);
|
||||
server_recovery_file_.open(server_recovery_file_path_, std::ios::out | std::ios::ate);
|
||||
if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) {
|
||||
|
@ -92,6 +97,7 @@ bool ServerRecovery::Save(uint64_t current_iter) {
|
|||
|
||||
nlohmann::json server_metadata_json;
|
||||
server_metadata_json[kCurrentIteration] = current_iter;
|
||||
server_metadata_json[kInstanceState] = GetInstanceStateStr(instance_state);
|
||||
server_recovery_file_ << server_metadata_json;
|
||||
server_recovery_file_.close();
|
||||
return true;
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include "fl/server/common.h"
|
||||
#include "ps/core/recovery_base.h"
|
||||
#include "ps/core/file_configuration.h"
|
||||
#include "ps/core/communicator/tcp_communicator.h"
|
||||
|
@ -45,7 +46,7 @@ class ServerRecovery : public ps::core::RecoveryBase {
|
|||
bool Recover() override;
|
||||
|
||||
// Save server's metadata to persistent storage.
|
||||
bool Save(uint64_t current_iter);
|
||||
bool Save(uint64_t current_iter, InstanceState instance_state);
|
||||
|
||||
// If this server recovers, need to notify cluster to reach consistency.
|
||||
bool SyncAfterRecovery(const std::shared_ptr<ps::core::TcpCommunicator> &communicator, uint32_t rank_id);
|
||||
|
|
|
@ -693,12 +693,7 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
|
|||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));
|
||||
|
||||
if (heartbeat_resp_message.cluster_state() != current_cluster_state_) {
|
||||
MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to "
|
||||
<< CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state());
|
||||
}
|
||||
|
||||
current_cluster_state_ = heartbeat_resp_message.cluster_state();
|
||||
UpdateClusterState(heartbeat_resp_message.cluster_state());
|
||||
MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
|
||||
<< CommUtil::ClusterStateToString(current_cluster_state_);
|
||||
|
||||
|
@ -1338,11 +1333,10 @@ bool AbstractNode::Recover() {
|
|||
MS_LOG(INFO) << "The node is support recovery.";
|
||||
node_recovery_ = std::make_unique<NodeRecovery>(this);
|
||||
MS_EXCEPTION_IF_NULL(node_recovery_);
|
||||
if (!node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) {
|
||||
MS_LOG(ERROR) << "Initializing node recovery failed.";
|
||||
return false;
|
||||
if (node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) {
|
||||
MS_LOG(INFO) << "Initializing node recovery success.";
|
||||
return node_recovery_->Recover();
|
||||
}
|
||||
return node_recovery_->Recover();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -1419,8 +1413,16 @@ void AbstractNode::CreateTcpServer() {
|
|||
|
||||
void AbstractNode::UpdateClusterState(const ClusterState &state) {
|
||||
std::lock_guard<std::mutex> lock(cluster_state_mutex_);
|
||||
std::string state_str = CommUtil::ClusterStateToString(state);
|
||||
if (state_str.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (state == current_cluster_state_) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "[state]: Cluster state change from:" << CommUtil::ClusterStateToString(current_cluster_state_)
|
||||
<< " to " << CommUtil::ClusterStateToString(state);
|
||||
<< " to " << state_str;
|
||||
current_cluster_state_ = state;
|
||||
}
|
||||
|
||||
|
|
|
@ -218,15 +218,9 @@ void CommUtil::LogCallback(int severity, const char *msg) {
|
|||
}
|
||||
}
|
||||
|
||||
bool CommUtil::IsFileExists(const std::string &file) {
|
||||
std::ifstream f(file.c_str());
|
||||
if (!f.good()) {
|
||||
return false;
|
||||
} else {
|
||||
f.close();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
bool CommUtil::IsFileExists(const std::string &file) { return access(file.c_str(), F_OK) != -1; }
|
||||
|
||||
bool CommUtil::IsFileReadable(const std::string &file) { return access(file.c_str(), R_OK) != -1; }
|
||||
|
||||
bool CommUtil::IsFileEmpty(const std::string &file) {
|
||||
if (!IsFileExists(file)) {
|
||||
|
|
|
@ -136,6 +136,7 @@ class CommUtil {
|
|||
static bool checkCRLTime(const std::string &crlPath);
|
||||
static bool CreateDirectory(const std::string &directoryPath);
|
||||
static bool CheckHttpUrl(const std::string &http_url);
|
||||
static bool IsFileReadable(const std::string &file);
|
||||
|
||||
private:
|
||||
static std::random_device rd;
|
||||
|
|
|
@ -69,6 +69,8 @@ class Configuration {
|
|||
|
||||
// storage meta data without nodes
|
||||
virtual void PersistNodes(const core::ClusterConfig &clusterConfig) const = 0;
|
||||
|
||||
virtual std::string file_path() const = 0;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -29,9 +29,12 @@ bool FileConfiguration::Initialize() {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!CommUtil::IsFileReadable(file_path_)) {
|
||||
MS_LOG(EXCEPTION) << "The file path: " << file_path_ << " is not readable.";
|
||||
}
|
||||
|
||||
if (CommUtil::IsFileEmpty(file_path_)) {
|
||||
MS_LOG(INFO) << "The file: " << file_path_ << " is empty.";
|
||||
return true;
|
||||
MS_LOG(EXCEPTION) << "The file path: " << file_path_ << " content is empty.";
|
||||
}
|
||||
|
||||
std::ifstream json_file(file_path_);
|
||||
|
@ -142,6 +145,8 @@ void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) co
|
|||
output_file.close();
|
||||
MS_LOG(INFO) << "The meta data persist to " << file_path_;
|
||||
}
|
||||
|
||||
std::string FileConfiguration::file_path() const { return file_path_; }
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -89,6 +89,8 @@ class BACKEND_EXPORT FileConfiguration : public Configuration {
|
|||
|
||||
void PersistNodes(const core::ClusterConfig &clusterConfig) const override;
|
||||
|
||||
std::string file_path() const override;
|
||||
|
||||
private:
|
||||
// The path of the configuration file.
|
||||
std::string file_path_;
|
||||
|
|
|
@ -45,8 +45,30 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message
|
|||
<< ", the node_role:" << CommUtil::NodeRoleToString(registered_nodes_info_[node_id].node_role_);
|
||||
return rank_id;
|
||||
}
|
||||
|
||||
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
|
||||
std::unordered_map<std::string, NodeInfo> recovery_node_infos = clusterConfig.initial_registered_nodes_infos;
|
||||
|
||||
// This is for scheduler recovery
|
||||
(void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id);
|
||||
if (PSContext::instance()->server_mode() == kServerModeFL ||
|
||||
PSContext::instance()->server_mode() == kServerModeHybrid) {
|
||||
if (recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
|
||||
const std::string &new_ip = register_message.ip();
|
||||
uint32_t new_port = register_message.port();
|
||||
rank_id = recovery_node_infos[node_id].rank_id_;
|
||||
recovery_node_infos[node_id].is_alive = true;
|
||||
recovery_node_infos[node_id].ip_ = new_ip;
|
||||
recovery_node_infos[node_id].port_ = static_cast<uint16_t>(new_port);
|
||||
registered_nodes_info_[node_id] = recovery_node_infos[node_id];
|
||||
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
|
||||
<< ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_
|
||||
<< ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive
|
||||
<< ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_);
|
||||
return rank_id;
|
||||
}
|
||||
} else {
|
||||
ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port());
|
||||
}
|
||||
return rank_id;
|
||||
}
|
||||
|
||||
|
@ -90,24 +112,12 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const
|
|||
const std::string &ip = register_message.ip();
|
||||
uint32_t port = register_message.port();
|
||||
|
||||
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
|
||||
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::SERVER;
|
||||
if (res) {
|
||||
MS_LOG(INFO) << "The server node id:" << item.first << " rank id:" << item.second.rank_id_ << " is not alive.";
|
||||
rank_id = item.second.rank_id_;
|
||||
}
|
||||
return res;
|
||||
});
|
||||
if (rank_it == registered_nodes_info_.end()) {
|
||||
if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_server_rank_id_) {
|
||||
rank_id = meta->rank_id();
|
||||
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
|
||||
} else {
|
||||
rank_id = next_server_rank_id_;
|
||||
next_server_rank_id_ += 1;
|
||||
}
|
||||
if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_server_rank_id_) {
|
||||
rank_id = meta->rank_id();
|
||||
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
|
||||
} else {
|
||||
registered_nodes_info_.erase((*rank_it).first);
|
||||
rank_id = next_server_rank_id_;
|
||||
next_server_rank_id_ += 1;
|
||||
}
|
||||
|
||||
if (rank_id >= meta_data_->server_num) {
|
||||
|
@ -131,25 +141,12 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const
|
|||
const std::string &ip = register_message.ip();
|
||||
uint32_t port = register_message.port();
|
||||
|
||||
auto worker_rank_it =
|
||||
std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
|
||||
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
|
||||
if (res) {
|
||||
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
|
||||
rank_id = item.second.rank_id_;
|
||||
}
|
||||
return res;
|
||||
});
|
||||
if (worker_rank_it == registered_nodes_info_.end()) {
|
||||
if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_worker_rank_id_) {
|
||||
rank_id = meta->rank_id();
|
||||
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
|
||||
} else {
|
||||
rank_id = next_worker_rank_id_;
|
||||
next_worker_rank_id_ += 1;
|
||||
}
|
||||
if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_worker_rank_id_) {
|
||||
rank_id = meta->rank_id();
|
||||
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
|
||||
} else {
|
||||
registered_nodes_info_.erase((*worker_rank_it).first);
|
||||
rank_id = next_worker_rank_id_;
|
||||
next_worker_rank_id_ += 1;
|
||||
}
|
||||
|
||||
if (rank_id >= meta_data_->worker_num) {
|
||||
|
@ -265,12 +262,19 @@ void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_do
|
|||
|
||||
void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
|
||||
|
||||
bool NodeManager::IsAllNodesRegistered() const {
|
||||
uint32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
|
||||
[](auto item) { return item.second.is_alive == true; });
|
||||
bool NodeManager::IsAllNodesAlive() const {
|
||||
uint32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [](auto item) {
|
||||
if (!item.second.is_alive) {
|
||||
MS_LOG(ERROR) << item.second.node_id_ << " is not alive.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return num == total_node_num_;
|
||||
}
|
||||
|
||||
bool NodeManager::IsAllNodesRegistered() const { return SizeToUint(registered_nodes_info_.size()) == total_node_num_; }
|
||||
|
||||
bool NodeManager::IsAllNodesFinished() const { return SizeToUint(finish_nodes_id_.size()) == total_node_num_; }
|
||||
|
||||
bool NodeManager::IsAllNodesScaleOutDone() const {
|
||||
|
|
|
@ -137,6 +137,8 @@ class NodeManager {
|
|||
// Determine whether all nodes that need to be persisted are in persistence.
|
||||
bool IsAllNodeInPersisting();
|
||||
|
||||
bool IsAllNodesAlive() const;
|
||||
|
||||
private:
|
||||
std::mutex node_mutex_;
|
||||
std::mutex cluster_mutex_;
|
||||
|
|
|
@ -31,7 +31,7 @@ bool NodeRecovery::Recover() {
|
|||
int32_t worker_num = std::strtol(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, kBase);
|
||||
node_->set_worker_num(worker_num);
|
||||
} else {
|
||||
node_->set_worker_num(PSContext::instance()->cluster_config().initial_worker_num);
|
||||
MS_LOG(EXCEPTION) << kRecoveryWorkerNum << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 2. recover server num
|
||||
|
@ -39,14 +39,14 @@ bool NodeRecovery::Recover() {
|
|||
int32_t server_num = std::strtol(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, kBase);
|
||||
node_->set_server_num(server_num);
|
||||
} else {
|
||||
node_->set_server_num(PSContext::instance()->cluster_config().initial_server_num);
|
||||
MS_LOG(EXCEPTION) << kRecoveryServerNum << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 3. recover scheduler ip
|
||||
if (recovery_storage_->Exists(kRecoverySchedulerIp)) {
|
||||
node_->set_scheduler_ip(recovery_storage_->GetString(kRecoverySchedulerIp, ""));
|
||||
} else {
|
||||
node_->set_scheduler_ip(PSContext::instance()->cluster_config().scheduler_host);
|
||||
MS_LOG(EXCEPTION) << kRecoverySchedulerIp << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 4. recover scheduler port
|
||||
|
@ -54,7 +54,7 @@ bool NodeRecovery::Recover() {
|
|||
uint16_t scheduler_port = std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, kBase);
|
||||
node_->set_scheduler_port(scheduler_port);
|
||||
} else {
|
||||
node_->set_scheduler_port(PSContext::instance()->cluster_config().scheduler_port);
|
||||
MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
MS_LOG(INFO) << "The worker num:" << node_->worker_num() << ", the server num:" << node_->server_num()
|
||||
<< ", the scheduler ip:" << node_->scheduler_ip() << ", the scheduler port:" << node_->scheduler_port();
|
||||
|
|
|
@ -50,8 +50,10 @@ bool RecoveryBase::Initialize(const std::string &config_json) {
|
|||
}
|
||||
recovery_storage_ = std::make_unique<FileConfiguration>(storage_file_path);
|
||||
MS_EXCEPTION_IF_NULL(recovery_storage_);
|
||||
if (!recovery_storage_->Initialize()) {
|
||||
MS_LOG(WARNING) << "The storage file path " << storage_file_path << " is empty.";
|
||||
if (recovery_storage_->Initialize()) {
|
||||
MS_LOG(INFO) << "The storage file path " << storage_file_path << " initialize success.";
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,8 +82,10 @@ bool RecoveryBase::InitializeNodes(const std::string &config_json) {
|
|||
}
|
||||
scheduler_recovery_storage_ = std::make_unique<FileConfiguration>(scheduler_storage_file_path);
|
||||
MS_EXCEPTION_IF_NULL(scheduler_recovery_storage_);
|
||||
if (!scheduler_recovery_storage_->Initialize()) {
|
||||
MS_LOG(WARNING) << "The scheduler storage file path " << scheduler_storage_file_path << " is empty.";
|
||||
if (scheduler_recovery_storage_->Initialize()) {
|
||||
MS_LOG(INFO) << "The scheduler storage file path " << scheduler_storage_file_path << " initialize success.";
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "the scheduler storage file path is:" << scheduler_storage_file_path;
|
||||
|
|
|
@ -31,12 +31,12 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
|
|||
MS_LOG(INFO) << "[Scheduler start]: 1. Begin to start scheduler node!";
|
||||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
InitNodeMetaData();
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(INFO) << "The config file is empty, then init node by context.";
|
||||
InitNodeMetaData();
|
||||
MS_LOG(WARNING) << "The config file is empty.";
|
||||
} else {
|
||||
if (!RecoverScheduler()) {
|
||||
MS_LOG(WARNING) << "Recover the server node is failed.";
|
||||
MS_LOG(DEBUG) << "Recover the server node is failed.";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -338,6 +338,10 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
|
|||
SetRegisterConnectionFd(conn, node_id);
|
||||
|
||||
if (node_manager_.IsAllNodesRegistered()) {
|
||||
if (!node_manager_.IsAllNodesAlive()) {
|
||||
MS_LOG(ERROR) << "Do not broadcast nodes info because some server nodes are not alive.";
|
||||
return;
|
||||
}
|
||||
is_ready_ = true;
|
||||
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
|
||||
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
|
||||
|
@ -1422,17 +1426,22 @@ bool SchedulerNode::RecoverScheduler() {
|
|||
MS_LOG(INFO) << "The scheduler node is support recovery.";
|
||||
scheduler_recovery_ = std::make_unique<SchedulerRecovery>();
|
||||
MS_EXCEPTION_IF_NULL(scheduler_recovery_);
|
||||
(void)scheduler_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
|
||||
(void)scheduler_recovery_->InitializeNodes(config_->Get(kKeyRecovery, ""));
|
||||
|
||||
return scheduler_recovery_->Recover();
|
||||
bool ret = scheduler_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
|
||||
bool ret_node = scheduler_recovery_->InitializeNodes(config_->Get(kKeyRecovery, ""));
|
||||
if (ret && ret_node) {
|
||||
MS_LOG(INFO) << "Scheduler recovery initialize successful.";
|
||||
return scheduler_recovery_->Recover();
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void SchedulerNode::PersistMetaData() {
|
||||
if (scheduler_recovery_ == nullptr) {
|
||||
MS_LOG(WARNING) << "scheduler recovery is null, so don't persist meta data";
|
||||
MS_LOG(WARNING) << "scheduler recovery is null, do not persist meta data";
|
||||
return;
|
||||
}
|
||||
if (!is_ready_) {
|
||||
return;
|
||||
}
|
||||
if (config_->Exists(kKeyRecovery)) {
|
||||
|
|
|
@ -27,9 +27,7 @@ std::string SchedulerRecovery::GetMetadata(const std::string &key) {
|
|||
|
||||
bool SchedulerRecovery::Recover() {
|
||||
std::unique_lock<std::mutex> lock(recovery_mtx_);
|
||||
if (recovery_storage_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(recovery_storage_, false);
|
||||
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
|
||||
|
||||
// 1. recover worker num
|
||||
|
@ -38,7 +36,7 @@ bool SchedulerRecovery::Recover() {
|
|||
UlongToUint(std::strtoul(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, kBase));
|
||||
clusterConfig.initial_worker_num = initial_worker_num;
|
||||
} else {
|
||||
clusterConfig.initial_worker_num = PSContext::instance()->initial_worker_num();
|
||||
MS_LOG(EXCEPTION) << kRecoveryWorkerNum << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 2. recover server num
|
||||
|
@ -47,14 +45,14 @@ bool SchedulerRecovery::Recover() {
|
|||
UlongToUint(std::strtoul(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, kBase));
|
||||
clusterConfig.initial_server_num = initial_server_num;
|
||||
} else {
|
||||
clusterConfig.initial_server_num = PSContext::instance()->initial_server_num();
|
||||
MS_LOG(EXCEPTION) << kRecoveryServerNum << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 3. recover scheduler ip
|
||||
if (recovery_storage_->Exists(kRecoverySchedulerIp)) {
|
||||
clusterConfig.scheduler_host = recovery_storage_->GetString(kRecoverySchedulerIp, "");
|
||||
} else {
|
||||
clusterConfig.scheduler_host = PSContext::instance()->scheduler_host();
|
||||
MS_LOG(EXCEPTION) << kRecoverySchedulerIp << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 4. recover scheduler port
|
||||
|
@ -62,7 +60,7 @@ bool SchedulerRecovery::Recover() {
|
|||
uint16_t scheduler_port = std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, kBase);
|
||||
clusterConfig.scheduler_port = scheduler_port;
|
||||
} else {
|
||||
clusterConfig.scheduler_port = PSContext::instance()->scheduler_port();
|
||||
MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num
|
||||
|
@ -70,15 +68,14 @@ bool SchedulerRecovery::Recover() {
|
|||
<< ", the scheduler ip:" << clusterConfig.scheduler_host
|
||||
<< ", the scheduler port:" << clusterConfig.scheduler_port;
|
||||
|
||||
if (scheduler_recovery_storage_ == nullptr) {
|
||||
MS_LOG(WARNING) << "scheduler recovery storage is null. return false";
|
||||
return false;
|
||||
}
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(scheduler_recovery_storage_, false);
|
||||
// 5. recover total node num
|
||||
if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) {
|
||||
uint32_t initial_total_node_num =
|
||||
UlongToUint(std::strtoul(scheduler_recovery_storage_->Get(kRecoveryTotalNodeNum, "").c_str(), nullptr, kBase));
|
||||
clusterConfig.initial_total_node_num = initial_total_node_num;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << kRecoveryTotalNodeNum << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 6. recover next worker rank id
|
||||
|
@ -86,6 +83,8 @@ bool SchedulerRecovery::Recover() {
|
|||
uint32_t initial_next_worker_rank_id = UlongToUint(
|
||||
std::strtoul(scheduler_recovery_storage_->Get(kRecoveryNextWorkerRankId, "").c_str(), nullptr, kBase));
|
||||
clusterConfig.initial_next_worker_rank_id = initial_next_worker_rank_id;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << kRecoveryNextWorkerRankId << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 7. recover next server rank id
|
||||
|
@ -93,6 +92,8 @@ bool SchedulerRecovery::Recover() {
|
|||
uint32_t initial_next_server_rank_id = UlongToUint(
|
||||
std::strtoul(scheduler_recovery_storage_->Get(kRecoveryNextServerRankId, "").c_str(), nullptr, kBase));
|
||||
clusterConfig.initial_next_server_rank_id = initial_next_server_rank_id;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << kRecoveryNextServerRankId << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
// 8. recover register nodes info
|
||||
|
@ -122,13 +123,11 @@ bool SchedulerRecovery::Recover() {
|
|||
<< recovery_server_num << " initial server num is:" << clusterConfig.initial_server_num;
|
||||
}
|
||||
clusterConfig.initial_registered_nodes_infos = nodes_infos;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << kRecoveryRegisteredNodesInfos << " is not contained in " << recovery_storage_->file_path();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num
|
||||
<< ", the server num:" << clusterConfig.initial_server_num
|
||||
<< ", the scheduler ip:" << clusterConfig.scheduler_host
|
||||
<< ", the scheduler port:" << clusterConfig.scheduler_port
|
||||
<< ", the initial total node num:" << clusterConfig.initial_total_node_num
|
||||
MS_LOG(INFO) << ", the initial total node num:" << clusterConfig.initial_total_node_num
|
||||
<< ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id
|
||||
<< ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id;
|
||||
|
||||
|
|
|
@ -42,12 +42,12 @@ bool ServerNode::Start(const uint32_t &timeout) {
|
|||
void ServerNode::Initialize() {
|
||||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
InitNodeNum();
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(INFO) << "The config file is empty, then init node by context.";
|
||||
InitNodeNum();
|
||||
MS_LOG(WARNING) << "The config file is empty.";
|
||||
} else {
|
||||
if (!Recover()) {
|
||||
MS_LOG(WARNING) << "Recover the server node is failed.";
|
||||
MS_LOG(DEBUG) << "Recover the server node is failed.";
|
||||
}
|
||||
}
|
||||
InitServerHandler();
|
||||
|
|
|
@ -42,12 +42,12 @@ bool WorkerNode::Start(const uint32_t &timeout) {
|
|||
void WorkerNode::Initialize() {
|
||||
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
InitNodeNum();
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(INFO) << "The config file is empty, then init node by context.";
|
||||
InitNodeNum();
|
||||
MS_LOG(WARNING) << "The config file is empty.";
|
||||
} else {
|
||||
if (!Recover()) {
|
||||
MS_LOG(WARNING) << "Recover the worker node is failed.";
|
||||
MS_LOG(DEBUG) << "Recover the worker node is failed.";
|
||||
}
|
||||
}
|
||||
InitServerHandler();
|
||||
|
|
|
@ -209,9 +209,8 @@ void PSContext::set_rank_id(uint32_t rank_id) const {
|
|||
}
|
||||
|
||||
void PSContext::set_server_mode(const std::string &server_mode) {
|
||||
if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
|
||||
MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
|
||||
<< " or " << kServerModeHybrid;
|
||||
if (server_mode != kServerModePS && server_mode != kServerModeFL) {
|
||||
MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL;
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";
|
||||
|
|
Loading…
Reference in New Issue