!22056 Add scheduler Restful interface.
Merge pull request !22056 from ZPaC/add-iteration-metrics
Commit b18ffcf221
@@ -61,11 +61,6 @@ class FusedPullWeightKernel : public CPUKernel {
     }
 
     fl_iteration_++;
-    if (fl_iteration_ > ps::PSContext::instance()->fl_iteration_num()) {
-      MS_LOG(INFO) << ps::PSContext::instance()->fl_iteration_num() << " iterations are completed.";
-      fl_iteration_ = 1;
-    }
-
     MS_LOG(INFO) << "Launching pulling weight for federated learning iteration " << fl_iteration_;
     if (!BuildPullWeightReq(fbb)) {
       MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed.";
@@ -76,6 +71,10 @@ class FusedPullWeightKernel : public CPUKernel {
     const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
     int retcode = schema::ResponseCode_SucNotReady;
     while (retcode == schema::ResponseCode_SucNotReady) {
+      if (!fl::worker::FLWorker::GetInstance().running()) {
+        MS_LOG(WARNING) << "Worker has finished.";
+        return true;
+      }
       if (!fl::worker::FLWorker::GetInstance().SendToServer(
             0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
         MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. Retry later.";
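Both fused kernels now bail out of their polling loops once the worker is finalized, instead of spinning until the response becomes ready. A minimal standalone sketch of the pattern, with `worker_running` and `send_to_server` as illustrative stand-ins for `FLWorker::running()` and `SendToServer()` (not the MindSpore API):

```cpp
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

// Illustrative stand-ins for fl::worker::FLWorker::running() and SendToServer().
std::atomic<bool> worker_running{true};
enum class ResponseCode { kSucNotReady, kSuccess };
ResponseCode send_to_server() {
  static int calls = 0;
  return (++calls < 3) ? ResponseCode::kSucNotReady : ResponseCode::kSuccess;
}

// Poll until the server's response is ready, but stop as soon as the worker
// is shut down so the loop cannot spin forever.
bool PullWeight() {
  ResponseCode retcode = ResponseCode::kSucNotReady;
  while (retcode == ResponseCode::kSucNotReady) {
    if (!worker_running.load()) {
      std::cout << "Worker has finished.\n";
      return true;  // Shutdown is treated as a clean exit, as in the kernels.
    }
    retcode = send_to_server();
    if (retcode == ResponseCode::kSucNotReady) {
      std::this_thread::sleep_for(std::chrono::milliseconds(200));
    }
  }
  return true;
}

int main() { return PullWeight() ? 0 : 1; }
```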
@@ -59,11 +59,6 @@ class FusedPushWeightKernel : public CPUKernel {
     }
 
     fl_iteration_++;
-    if (fl_iteration_ > ps::PSContext::instance()->fl_iteration_num()) {
-      MS_LOG(INFO) << ps::PSContext::instance()->fl_iteration_num() << " iterations are completed.";
-      fl_iteration_ = 1;
-    }
-
     MS_LOG(INFO) << "Launching pushing weight for federated learning iteration " << fl_iteration_;
     if (!BuildPushWeightReq(fbb, inputs)) {
       MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
@@ -76,6 +71,10 @@ class FusedPushWeightKernel : public CPUKernel {
     const schema::ResponsePushWeight *push_weight_rsp = nullptr;
     int retcode = schema::ResponseCode_SucNotReady;
     while (retcode == schema::ResponseCode_SucNotReady) {
+      if (!fl::worker::FLWorker::GetInstance().running()) {
+        MS_LOG(WARNING) << "Worker has finished.";
+        return true;
+      }
       if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
                                                             ps::core::TcpUserCommand::kPushWeight,
                                                             &push_weight_rsp_msg)) {
@@ -56,6 +56,10 @@ class PushMetricsKernel : public CPUKernel {
     uint32_t retry_time = 0;
     std::shared_ptr<std::vector<unsigned char>> push_metrics_rsp_msg = nullptr;
     do {
+      if (!fl::worker::FLWorker::GetInstance().running()) {
+        MS_LOG(WARNING) << "Worker has finished.";
+        return true;
+      }
       retry_time++;
       if (!fl::worker::FLWorker::GetInstance().SendToServer(fl::kLeaderServerRank, fbb_->GetBufferPointer(),
                                                             fbb_->GetSize(), ps::core::TcpUserCommand::kPushMetrics,
@@ -66,7 +70,7 @@ class PushMetricsKernel : public CPUKernel {
       } else {
         break;
       }
-    } while (retry_time > kMaxRetryTime);
+    } while (retry_time < kMaxRetryTime);
 
     flatbuffers::Verifier verifier(push_metrics_rsp_msg->data(), push_metrics_rsp_msg->size());
     if (!verifier.VerifyBuffer<schema::ResponsePushMetrics>()) {
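The loop-condition fix matters because `retry_time` starts at 0 and is incremented before the check: with `retry_time > kMaxRetryTime`, the do/while exits after the first pass (1 > kMaxRetryTime is false for any kMaxRetryTime >= 1), so a failed send was never actually retried. With `<`, the kernel retries up to kMaxRetryTime times. A self-contained sketch; the kMaxRetryTime value and the TrySend stub are illustrative, not from the source:

```cpp
#include <cstdint>
#include <iostream>

constexpr uint32_t kMaxRetryTime = 10;  // Illustrative value.

// Stub that fails the first three attempts.
bool TrySend() {
  static int calls = 0;
  return ++calls >= 4;
}

int main() {
  uint32_t retry_time = 0;
  bool sent = false;
  do {
    retry_time++;
    if (TrySend()) {
      sent = true;
      break;
    }
    std::cout << "Attempt " << retry_time << " failed. Retry later.\n";
  } while (retry_time < kMaxRetryTime);  // With '>', the body runs once and gives up.
  std::cout << (sent ? "sent\n" : "gave up\n");
  return sent ? 0 : 1;
}
```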
@@ -66,6 +66,20 @@ void DistributedCountService::RegisterCounter(const std::string &name, size_t gl
   return;
 }
 
+bool DistributedCountService::ReInitCounter(const std::string &name, size_t global_threshold_count) {
+  MS_LOG(INFO) << "Rank " << local_rank_ << " reinitialize counter for " << name << " count:" << global_threshold_count;
+  if (local_rank_ == counting_server_rank_) {
+    std::unique_lock<std::mutex> lock(mutex_[name]);
+    if (global_threshold_count_.count(name) == 0) {
+      MS_LOG(INFO) << "Counter for " << name << " is not set.";
+      return false;
+    }
+    global_current_count_[name] = {};
+    global_threshold_count_[name] = global_threshold_count;
+  }
+  return true;
+}
+
 bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) {
   MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
   if (local_rank_ == counting_server_rank_) {
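`ReInitCounter` clears the accumulated count and installs the new threshold only on the rank that owns the counters, and reports failure if the counter was never registered. A simplified single-rank sketch of the same reset semantics; it uses one coarse mutex instead of the per-counter `mutex_[name]` in the source, and the surrounding class is illustrative:

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <mutex>
#include <set>
#include <string>

// Single-rank model of the counter reset: clear recorded ids, set the new
// threshold, and fail if the counter was never registered.
class CounterService {
 public:
  void RegisterCounter(const std::string &name, size_t threshold) {
    std::unique_lock<std::mutex> lock(mutex_);
    global_threshold_count_[name] = threshold;
    global_current_count_[name] = {};
  }
  bool ReInitCounter(const std::string &name, size_t threshold) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (global_threshold_count_.count(name) == 0) {
      return false;  // Counter for 'name' is not set.
    }
    global_current_count_[name] = {};           // Drop every recorded id.
    global_threshold_count_[name] = threshold;  // Install the new threshold.
    return true;
  }

 private:
  std::mutex mutex_;
  std::map<std::string, std::set<std::string>> global_current_count_;
  std::map<std::string, size_t> global_threshold_count_;
};

int main() {
  CounterService service;
  service.RegisterCounter("updateModel", 100);
  std::cout << service.ReInitCounter("updateModel", 50) << "\n";  // 1
  std::cout << service.ReInitCounter("unknown", 50) << "\n";      // 0
}
```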
@@ -63,6 +63,9 @@ class DistributedCountService {
   // first/last count event callbacks.
   void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);
 
+  // Reinitialize counter due to the change of threshold count.
+  bool ReInitCounter(const std::string &name, size_t global_threshold_count);
+
   // Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the
   // reason why counting failed.
   bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr);
@@ -50,7 +50,7 @@ void DistributedMetadataStore::RegisterMetadata(const std::string &name, const P
   uint32_t stored_rank = router_->Find(name);
   if (local_rank_ == stored_rank) {
     if (metadata_.count(name) != 0) {
-      MS_LOG(ERROR) << "The metadata for " << name << " is already registered.";
+      MS_LOG(WARNING) << "The metadata for " << name << " is already registered.";
       return;
     }
 
@@ -340,10 +340,10 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
     param_aggrs_[param_name] = param_aggr;
     parameter_mutex_[param_name];
     if (!param_aggr->Init(cnode, aggregation_count_)) {
-      MS_LOG(EXCEPTION) << "Initializing parameter aggregator failed for " << param_name;
+      MS_LOG(EXCEPTION) << "Initializing parameter aggregator for " << param_name << " failed.";
       return false;
     }
-    MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
+    MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
   }
   return true;
 }
@@ -29,6 +29,7 @@ class Server;
 
 Iteration::~Iteration() {
   move_to_next_thread_running_ = false;
+  next_iteration_cv_.notify_all();
   if (move_to_next_thread_.joinable()) {
     move_to_next_thread_.join();
   }
@@ -93,6 +94,9 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
     while (move_to_next_thread_running_.load()) {
       std::unique_lock<std::mutex> lock(next_iteration_mutex_);
       next_iteration_cv_.wait(lock);
+      if (!move_to_next_thread_running_.load()) {
+        break;
+      }
       MoveToNextIteration(is_last_iteration_valid_, move_to_next_reason_);
     }
   });
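The destructor-side changes (clear the flag, `notify_all()`, join if joinable) pair with this re-check after `wait()`: a thread woken by the destructor must exit rather than process a stale iteration event. Below is a standalone sketch of that shutdown handshake; note it adds a wait predicate, which hardens the pattern against lost and spurious wakeups beyond what the plain `wait()` in the diff does:

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

std::mutex mtx;
std::condition_variable cv;
bool running = true;     // Guarded by mtx.
int pending_events = 0;  // Guarded by mtx.

int main() {
  std::thread worker([] {
    std::unique_lock<std::mutex> lock(mtx);
    while (true) {
      // The predicate makes the wait immune to spurious and lost wakeups.
      cv.wait(lock, [] { return !running || pending_events > 0; });
      if (!running) {
        break;  // Woken by shutdown, not by a real iteration event.
      }
      --pending_events;
      std::cout << "moving to next iteration\n";
    }
  });

  {
    std::lock_guard<std::mutex> lk(mtx);
    ++pending_events;  // Simulate one iteration event.
  }
  cv.notify_all();

  {
    std::lock_guard<std::mutex> lk(mtx);
    running = false;  // Flip the flag first...
  }
  cv.notify_all();    // ...then wake the sleeper...
  if (worker.joinable()) {
    worker.join();    // ...then join, as ~Iteration() does.
  }
}
```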
@@ -222,12 +226,12 @@ bool Iteration::EnableServerInstance(std::string *result) {
   std::unique_lock<std::mutex> lock(instance_mtx_);
   if (is_instance_being_updated_) {
     *result = "The instance is being updated. Please retry enabling server later.";
-    MS_LOG(WARNING) << result;
+    MS_LOG(WARNING) << *result;
     return false;
   }
   if (instance_state_.load() == InstanceState::kFinish) {
     *result = "The instance is completed. Please do not enabling server now.";
-    MS_LOG(WARNING) << result;
+    MS_LOG(WARNING) << *result;
     return false;
   }
 
@@ -236,6 +240,7 @@ bool Iteration::EnableServerInstance(std::string *result) {
 
   instance_state_ = InstanceState::kRunning;
+  *result = "Enabling FL-Server succeeded.";
   MS_LOG(INFO) << *result;
 
   // End enabling server instance.
   is_instance_being_updated_ = false;
@@ -259,7 +264,7 @@ bool Iteration::DisableServerInstance(std::string *result) {
   if (instance_state_.load() == InstanceState::kDisable) {
     *result = "Disabling FL-Server succeeded.";
     MS_LOG(INFO) << *result;
-    return false;
+    return true;
   }
 
   // Start disabling server instance.
@@ -272,6 +277,8 @@ bool Iteration::DisableServerInstance(std::string *result) {
     MS_LOG(ERROR) << result;
     return false;
   }
+  *result = "Disabling FL-Server succeeded.";
+  MS_LOG(INFO) << *result;
 
   // End disabling server instance.
   is_instance_being_updated_ = false;
@@ -293,8 +300,9 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
 
   // Reset current instance.
   instance_state_ = InstanceState::kFinish;
-  MS_LOG(INFO) << "Proceed to a new instance.";
+  Server::GetInstance().WaitExitSafeMode();
   WaitAllRoundsFinish();
+  MS_LOG(INFO) << "Proceed to a new instance.";
   for (auto &round : rounds_) {
     MS_ERROR_IF_NULL_W_RET_VAL(round, false);
     round->Reset();
@@ -705,7 +713,7 @@ bool Iteration::ReInitRounds() {
   std::vector<RoundConfig> new_round_config = {
     {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
     {"updateModel", true, update_model_time_window, true, update_model_threshold}};
-  if (!Iteration::GetInstance().ReInitForUpdatingHyperParams(new_round_config)) {
+  if (!ReInitForUpdatingHyperParams(new_round_config)) {
     MS_LOG(ERROR) << "Reinitializing for updating hyper-parameters failed.";
     return false;
   }
@@ -727,7 +735,6 @@ bool Iteration::ReInitRounds() {
   }
   return true;
 }
 
 } // namespace server
 } // namespace fl
 } // namespace mindspore
@@ -63,8 +63,6 @@ bool IterationMetrics::Initialize() {
 
     metrics_file_.open(metrics_file_path, std::ios::ate | std::ios::out);
   }
 
   initialized_ = true;
   return true;
 }
 
@@ -89,8 +89,6 @@ class IterationMetrics {
   void set_iteration_time_cost(uint64_t iteration_time_cost);
 
  private:
   bool initialized_;
 
   // This is the main config file set by ps context.
   std::string config_file_path_;
   std::unique_ptr<ps::core::FileConfiguration> config_;
@@ -99,7 +97,7 @@ class IterationMetrics {
   std::fstream metrics_file_;
 
   // Json object of metrics data.
-  nlohmann::json js_;
+  nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> js_;
 
   // The federated learning job name. Set by ps_context.
   std::string fl_name_;
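`nlohmann::json` is an alias for `basic_json<>` whose floating-point type is `double`; spelling the template out as above keeps the default object, array, string, and integer types but swaps the float type to `float`, so metric values are stored and serialized with single precision. A small sketch, assuming the single-header nlohmann/json.hpp is available; the keys and values are illustrative:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>

// Same specialization as the diff: float instead of the default double.
using MetricsJson =
    nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float>;

int main() {
  MetricsJson js;
  js["flName"] = "demo_job";                // Illustrative key/value.
  js["iterationTimeCost"] = uint64_t{1234};
  js["loss"] = 0.125f;                      // Stored and dumped as float.
  std::cout << js.dump() << "\n";
}
```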
@@ -180,7 +180,10 @@ class FedAvgKernel : public AggregationKernel {
 
   bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override {
     done_count_ = aggr_threshold;
-    DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});
+    if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) {
+      MS_LOG(ERROR) << "Reinitializing counter for " << name_ << " failed.";
+      return false;
+    }
     return true;
   }
 
@@ -99,7 +99,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
   const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
   size_t latest_iter_num = iter_to_model.rbegin()->first;
   // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
-  if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
+  if ((current_iter == get_model_iter && latest_iter_num != current_iter)) {
     std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
                          ". Maybe this is because\n" + "1.Client doesn't send enough update model requests.\n" +
                          "2. Worker has not push all the weights to servers.";
@@ -32,7 +32,7 @@ namespace server {
 namespace kernel {
 class PushMetricsKernel : public RoundKernel {
  public:
-  PushMetricsKernel() = default;
+  PushMetricsKernel() : local_rank_(0) {}
   ~PushMetricsKernel() override = default;
 
   void InitKernel(size_t threshold_count);
@@ -106,10 +106,10 @@ bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t
   time_window_ = updated_time_window;
   threshold_count_ = updated_threshold_count;
   if (check_count_) {
-    auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
-    auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
-    DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
-                                                           {first_count_handler, last_count_handler});
+    if (!DistributedCountService::GetInstance().ReInitCounter(name_, threshold_count_)) {
+      MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
+      return false;
+    }
   }
 
   MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
@@ -32,6 +32,22 @@
 namespace mindspore {
 namespace fl {
 namespace server {
+// The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster managers like K8S.
+std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
+std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
+void SignalHandler(int signal) {
+  MS_LOG(WARNING) << "SIGTERM captured: " << signal;
+  (void)std::for_each(g_communicators_with_worker.begin(), g_communicators_with_worker.end(),
+                      [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+                        MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
+                        (void)communicator->Stop();
+                      });
+
+  MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
+  (void)g_communicator_with_server->Stop();
+  return;
+}
+
 void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
                         const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -48,6 +64,7 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
   use_http_ = use_http;
   http_port_ = http_port;
   executor_threshold_ = executor_threshold;
+  signal(SIGTERM, SignalHandler);
   return;
 }
 
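With `signal(SIGTERM, SignalHandler)` registered, a cluster manager's SIGTERM stops the server's communicators instead of killing the process abruptly. Strictly speaking, only async-signal-safe operations are allowed inside a handler, so a conservative variant of this pattern just flips a lock-free flag and lets the main loop do the actual shutdown. A sketch of that variant; the names and loop body are illustrative, not the server's code:

```cpp
#include <atomic>
#include <chrono>
#include <csignal>
#include <iostream>
#include <thread>

// The handler only flips a flag; a lock-free std::atomic<bool> keeps this
// async-signal-safe, and the main loop performs the real shutdown.
std::atomic<bool> g_terminate{false};

extern "C" void OnSigterm(int /*signal*/) { g_terminate = true; }

int main() {
  std::signal(SIGTERM, OnSigterm);
  while (!g_terminate.load()) {
    // ... serve requests ...
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
  std::cout << "SIGTERM received, stopping communicators.\n";
  // Here the real server would call Stop() on every communicator.
}
```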
@@ -109,6 +126,12 @@ void Server::CancelSafeMode() {
 
 bool Server::IsSafeMode() const { return safemode_.load(); }
 
+void Server::WaitExitSafeMode() const {
+  while (safemode_.load()) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
+  }
+}
+
 void Server::InitServerContext() {
   ps::PSContext::instance()->GenerateResetterRound();
   scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
@@ -145,6 +168,7 @@ bool Server::InitCommunicatorWithServer() {
   communicator_with_server_ =
     server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_);
   MS_EXCEPTION_IF_NULL(communicator_with_server_);
+  g_communicator_with_server = communicator_with_server_;
   return true;
 }
 
@@ -166,6 +190,7 @@ bool Server::InitCommunicatorWithWorker() {
     MS_EXCEPTION_IF_NULL(http_comm);
     communicators_with_worker_.push_back(http_comm);
   }
+  g_communicators_with_worker = communicators_with_worker_;
   return true;
 }
 
@@ -239,10 +264,9 @@ void Server::InitIteration() {
 #endif
 
   // 2.Initialize all the rounds.
-  TimeOutCb time_out_cb =
-    std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
+  TimeOutCb time_out_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
   FinishIterCb finish_iter_cb =
-    std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
+    std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
   iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
   return;
 }
@@ -487,15 +511,7 @@ void Server::ProcessAfterScalingIn() {
   std::unique_lock<std::mutex> lock(scaling_mtx_);
   MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
   if (server_node_->rank_id() == UINT32_MAX) {
-    MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
-    (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                        [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
-                          MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
-                          (void)communicator->Stop();
-                        });
-
-    MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
-    (void)communicator_with_server_->Stop();
+    MS_LOG(WARNING) << "This server the one to be scaled in. Server need to wait SIGTERM to exit.";
     return;
   }
 
@@ -588,7 +604,7 @@ void Server::HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHan
 
 void Server::HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
   MS_ERROR_IF_NULL_WO_RET_VAL(message);
-  nlohmann::json response;
+  nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> response;
   response["start_fl_job_threshold"] = ps::PSContext::instance()->start_fl_job_threshold();
   response["start_fl_job_time_window"] = ps::PSContext::instance()->start_fl_job_time_window();
   response["update_model_ratio"] = ps::PSContext::instance()->update_model_ratio();
@@ -56,6 +56,7 @@ class Server {
   void SwitchToSafeMode();
   void CancelSafeMode();
   bool IsSafeMode() const;
+  void WaitExitSafeMode() const;
 
   // Whether the training job of the server is enabled.
   InstanceState instance_state() const;
@@ -25,7 +25,7 @@ namespace mindspore {
 namespace fl {
 namespace worker {
 void FLWorker::Run() {
-  if (running_) {
+  if (running_.load()) {
     return;
   }
   running_ = true;
@@ -48,6 +48,7 @@ void FLWorker::Run() {
 
   worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
     Finalize();
+    running_ = false;
     try {
       MS_LOG(EXCEPTION)
         << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
@@ -57,6 +58,7 @@ void FLWorker::Run() {
   });
   worker_node_->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
     Finalize();
+    running_ = false;
     try {
       MS_LOG(EXCEPTION)
         << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
@@ -148,6 +150,8 @@ uint32_t FLWorker::rank_id() const { return rank_id_; }
 
 uint64_t FLWorker::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
 
+bool FLWorker::running() const { return running_.load(); }
+
 void FLWorker::SetIterationRunning() {
   MS_LOG(INFO) << "Worker iteration starts.";
   worker_iteration_state_ = IterationState::kRunning;
@@ -72,6 +72,9 @@ class FLWorker {
   uint32_t rank_id() const;
   uint64_t worker_step_num_per_iteration() const;
 
+  // Check whether worker has exited.
+  bool running() const;
+
   // These methods set the worker's iteration state.
   void SetIterationRunning();
   void SetIterationCompleted();
@@ -116,7 +119,7 @@ class FLWorker {
   void ProcessAfterScalingOut();
   void ProcessAfterScalingIn();
 
-  bool running_;
+  std::atomic_bool running_;
   uint32_t server_num_;
   uint32_t worker_num_;
   std::string scheduler_ip_;
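`running_` is written from cluster-event callbacks and read concurrently by kernels through `running()`, so the plain `bool` was a data race (undefined behavior); `std::atomic_bool` makes those cross-thread reads and writes well-defined. A minimal illustration, with names of my own choosing:

```cpp
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

// One thread flips the flag while another polls it. With a plain bool this
// is a data race; with std::atomic_bool every load sees the store.
std::atomic_bool running{true};

int main() {
  std::thread watcher([] {
    while (running.load()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
    std::cout << "watcher saw shutdown\n";
  });
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  running = false;  // e.g. a NODE_TIMEOUT callback finalizing the worker.
  watcher.join();
}
```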
@@ -78,10 +78,18 @@ FollowerScaler::~FollowerScaler() {
   running_ = false;
   scale_out_cv_.notify_all();
   scale_in_cv_.notify_all();
-  process_before_scale_out_thread_.join();
-  process_before_scale_in_thread_.join();
-  process_after_scale_out_thread_.join();
-  process_after_scale_in_thread_.join();
+  if (process_before_scale_out_thread_.joinable()) {
+    process_before_scale_out_thread_.join();
+  }
+  if (process_before_scale_in_thread_.joinable()) {
+    process_before_scale_in_thread_.join();
+  }
+  if (process_after_scale_out_thread_.joinable()) {
+    process_after_scale_out_thread_.join();
+  }
+  if (process_after_scale_in_thread_.joinable()) {
+    process_after_scale_in_thread_.join();
+  }
 }
 
 void FollowerScaler::RegisterScaleEventCallbacks() {
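The `joinable()` guards matter because `std::thread::join()` throws `std::system_error` when the thread is not joinable (default-constructed, already joined, or moved-from), which would terminate the process if it escaped the destructor. A small demonstration:

```cpp
#include <iostream>
#include <system_error>
#include <thread>

int main() {
  std::thread never_started;  // e.g. a scaler whose threads were never spawned.
  try {
    never_started.join();  // Throws: the thread is not joinable.
  } catch (const std::system_error &e) {
    std::cout << "join() failed: " << e.what() << "\n";
  }
  if (never_started.joinable()) {  // The guarded form from the diff never throws.
    never_started.join();
  }
}
```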