Fix code review.

This commit is contained in:
ZPaC 2021-07-31 16:25:06 +08:00
parent 79af42153d
commit a2b8198024
13 changed files with 170 additions and 104 deletions

View File

@ -121,7 +121,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name; MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
if (local_rank_ == counting_server_rank_) { if (local_rank_ == counting_server_rank_) {
if (global_threshold_count_.count(name) == 0) { if (global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "Counter for " << name << " is not set."; MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
return false; return false;
} }
@ -138,6 +138,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
return false; return false;
} }
MS_ERROR_IF_NULL_W_RET_VAL(query_cnt_enough_rsp_msg, false);
CountReachThresholdResponse count_reach_threshold_rsp; CountReachThresholdResponse count_reach_threshold_rsp;
(void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), (void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
SizeToInt(query_cnt_enough_rsp_msg->size())); SizeToInt(query_cnt_enough_rsp_msg->size()));
@ -173,11 +174,7 @@ bool DistributedCountService::ReInitForScaling() {
} }
void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
CountRequest report_count_req; CountRequest report_count_req;
(void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len())); (void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = report_count_req.name(); const std::string &name = report_count_req.name();
@ -235,11 +232,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
void DistributedCountService::HandleCountReachThresholdRequest( void DistributedCountService::HandleCountReachThresholdRequest(
const std::shared_ptr<ps::core::MessageHandler> &message) { const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
CountReachThresholdRequest count_reach_threshold_req; CountReachThresholdRequest count_reach_threshold_req;
(void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len())); (void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = count_reach_threshold_req.name(); const std::string &name = count_reach_threshold_req.name();
@ -261,11 +254,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(
} }
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
// Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
// callbacks. // callbacks.
std::string couter_event_rsp_msg = "success"; std::string couter_event_rsp_msg = "success";
@ -279,6 +268,10 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
const auto &type = counter_event.type(); const auto &type = counter_event.type();
const auto &name = counter_event.name(); const auto &name = counter_event.name();
if (counter_handlers_.count(name) == 0) {
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
return;
}
MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name; MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
if (type == CounterEventType::FIRST_CNT) { if (type == CounterEventType::FIRST_CNT) {
counter_handlers_[name].first_count_handler(message); counter_handlers_[name].first_count_handler(message);
@ -292,6 +285,11 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
} }
bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) { bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) {
if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "The counter of " << name << " is not registered.";
return false;
}
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size() MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
<< ", threshold count is " << global_threshold_count_[name]; << ", threshold count is " << global_threshold_count_[name];
// The threshold count may be 1 so the first and last count event should be both activated. // The threshold count may be 1 so the first and last count event should be both activated.
@ -324,6 +322,11 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
return false; return false;
} }
} }
if (counter_handlers_.count(name) == 0) {
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
return false;
}
// Leader server directly calls the callback. // Leader server directly calls the callback.
counter_handlers_[name].first_count_handler(nullptr); counter_handlers_[name].first_count_handler(nullptr);
return true; return true;
@ -345,6 +348,11 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
return false; return false;
} }
} }
if (counter_handlers_.count(name) == 0) {
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
return false;
}
// Leader server directly calls the callback. // Leader server directly calls the callback.
counter_handlers_[name].last_count_handler(nullptr); counter_handlers_[name].last_count_handler(nullptr);
return true; return true;

View File

@ -181,11 +181,7 @@ void DistributedMetadataStore::InitHashRing() {
} }
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
PBMetadataWithName meta_with_name; PBMetadataWithName meta_with_name;
(void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len())); (void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = meta_with_name.name(); const std::string &name = meta_with_name.name();
@ -206,17 +202,17 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
} }
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
GetMetadataRequest get_metadata_req; GetMetadataRequest get_metadata_req;
(void)get_metadata_req.ParseFromArray(message->data(), message->len()); (void)get_metadata_req.ParseFromArray(message->data(), message->len());
const std::string &name = get_metadata_req.name(); const std::string &name = get_metadata_req.name();
MS_LOG(INFO) << "Getting metadata for " << name; MS_LOG(INFO) << "Getting metadata for " << name;
std::unique_lock<std::mutex> lock(mutex_[name]); std::unique_lock<std::mutex> lock(mutex_[name]);
if (metadata_.count(name) == 0) {
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
return;
}
PBMetadata stored_meta = metadata_[name]; PBMetadata stored_meta = metadata_[name];
std::string getting_meta_rsp_msg = stored_meta.SerializeAsString(); std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) { if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) {
@ -228,6 +224,10 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) { bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
std::unique_lock<std::mutex> lock(mutex_[name]); std::unique_lock<std::mutex> lock(mutex_[name]);
if (metadata_.count(name) == 0) {
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
return false;
}
if (meta.has_device_meta()) { if (meta.has_device_meta()) {
auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta(); auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
auto &device_meta_fl_id = meta.device_meta().fl_id(); auto &device_meta_fl_id = meta.device_meta().fl_id();

View File

@ -158,17 +158,10 @@ bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_ma
auto &param_aggr = param_aggrs_[param_name]; auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
AddressPtr old_weight = param_aggr->GetWeight(); AddressPtr old_weight = param_aggr->GetWeight();
if (old_weight == nullptr) {
MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
return false;
}
const Address &new_weight = trainable_param.second; const Address &new_weight = trainable_param.second;
if (new_weight.addr == nullptr) { MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
MS_LOG(ERROR) << "The new weight is nullptr."; MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
return false; MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
}
int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size); int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@ -295,6 +288,7 @@ std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(weight_node); MS_EXCEPTION_IF_NULL(weight_node);
if (!weight_node->isa<Parameter>()) { if (!weight_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter."; MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
return "";
} }
return weight_node->fullname_with_scope(); return weight_node->fullname_with_scope();
} }
@ -309,7 +303,7 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
continue; continue;
} }
if (param_aggrs_.count(param_name) != 0) { if (param_aggrs_.count(param_name) != 0) {
MS_LOG(WARNING) << param_name << " already has its control flow."; MS_LOG(WARNING) << param_name << " already has parameter aggregator registered.";
continue; continue;
} }

View File

@ -66,9 +66,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
(void)std::for_each(communicators.begin(), communicators.end(), (void)std::for_each(communicators.begin(), communicators.end(),
[&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { [&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
for (auto &round : rounds_) { for (auto &round : rounds_) {
if (round == nullptr) { MS_EXCEPTION_IF_NULL(round);
continue;
}
round->Initialize(communicator, timeout_cb, finish_iteration_cb); round->Initialize(communicator, timeout_cb, finish_iteration_cb);
} }
}); });
@ -76,6 +74,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
// The time window for one iteration, which will be used in some round kernels. // The time window for one iteration, which will be used in some round kernels.
size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0), size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0),
[](size_t total, const std::shared_ptr<Round> &round) { [](size_t total, const std::shared_ptr<Round> &round) {
MS_EXCEPTION_IF_NULL(round);
return round->check_timeout() ? total + round->time_window() : total; return round->check_timeout() ? total + round->time_window() : total;
}); });
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window); LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
@ -172,11 +171,8 @@ bool Iteration::SyncIteration(uint32_t rank) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed."; MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
return false; return false;
} }
if (sync_iter_rsp_msg == nullptr) {
MS_LOG(ERROR) << "Response from server 0 is empty.";
return false;
}
MS_ERROR_IF_NULL_W_RET_VAL(sync_iter_rsp_msg, false);
SyncIterationResponse sync_iter_rsp; SyncIterationResponse sync_iter_rsp;
(void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size())); (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
iteration_num_ = sync_iter_rsp.iteration(); iteration_num_ = sync_iter_rsp.iteration();
@ -372,16 +368,17 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
if (is_iteration_valid) { if (is_iteration_valid) {
// Store the model which is successfully aggregated for this iteration. // Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel(); const auto &model = Executor::GetInstance().GetModel();
(void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished."; MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else { } else {
// Store last iteration's model because this iteration is considered as invalid. // Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1); const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
(void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason; MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
} }
for (auto &round : rounds_) { for (auto &round : rounds_) {
MS_ERROR_IF_NULL_WO_RET_VAL(round);
round->Reset(); round->Reset();
} }
} }

View File

@ -36,10 +36,12 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
std::this_thread::sleep_for(std::chrono::milliseconds(1)); std::this_thread::sleep_for(std::chrono::milliseconds(1));
} }
}); });
monitor_thread_.detach();
} }
void IterationTimer::Stop() { running_ = false; } void IterationTimer::Stop() {
running_ = false;
monitor_thread_.join();
}
void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) { void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {
timeout_callback_ = timeout_cb; timeout_callback_ = timeout_cb;

View File

@ -82,6 +82,10 @@ class FedAvgKernel : public AggregationKernel {
} }
}; };
last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) { last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_);
MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_);
MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_->addr);
MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_->addr);
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr); T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
size_t weight_size = weight_addr_->size; size_t weight_size = weight_addr_->size;
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr); S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);

View File

@ -35,33 +35,27 @@ void ModelStore::Initialize(uint32_t max_count) {
model_size_ = ComputeModelSize(); model_size_ = ComputeModelSize();
} }
bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) { void ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
std::unique_lock<std::mutex> lock(model_mtx_); std::unique_lock<std::mutex> lock(model_mtx_);
if (iteration_to_model_.count(iteration) != 0) { if (iteration_to_model_.count(iteration) != 0) {
MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored"; MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
return false; return;
} }
if (new_model.empty()) { if (new_model.empty()) {
MS_LOG(ERROR) << "Model feature map is empty."; MS_LOG(ERROR) << "Model feature map is empty.";
return false; return;
} }
std::shared_ptr<MemoryRegister> memory_register = nullptr; std::shared_ptr<MemoryRegister> memory_register = nullptr;
if (iteration_to_model_.size() < max_model_count_) { if (iteration_to_model_.size() < max_model_count_) {
// If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model. // If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model.
memory_register = AssignNewModelMemory(); memory_register = AssignNewModelMemory();
if (memory_register == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
MS_LOG(ERROR) << "Memory for the new model is nullptr.";
return false;
}
iteration_to_model_[iteration] = memory_register; iteration_to_model_[iteration] = memory_register;
} else { } else {
// If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model. // If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
memory_register = iteration_to_model_.begin()->second; memory_register = iteration_to_model_.begin()->second;
if (memory_register == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
MS_LOG(ERROR) << "Earliest model is nullptr.";
return false;
}
(void)iteration_to_model_.erase(iteration_to_model_.begin()); (void)iteration_to_model_.erase(iteration_to_model_.begin());
} }
@ -74,6 +68,10 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
continue; continue;
} }
MS_ERROR_IF_NULL_WO_RET_VAL(stored_model[weight_name]);
MS_ERROR_IF_NULL_WO_RET_VAL(stored_model[weight_name]->addr);
MS_ERROR_IF_NULL_WO_RET_VAL(weight.second);
MS_ERROR_IF_NULL_WO_RET_VAL(weight.second->addr);
void *dst_addr = stored_model[weight_name]->addr; void *dst_addr = stored_model[weight_name]->addr;
size_t dst_size = stored_model[weight_name]->size; size_t dst_size = stored_model[weight_name]->size;
void *src_addr = weight.second->addr; void *src_addr = weight.second->addr;
@ -81,11 +79,11 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size); int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return false; return;
} }
} }
iteration_to_model_[iteration] = memory_register; iteration_to_model_[iteration] = memory_register;
return true; return;
} }
std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) { std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {

View File

@ -47,7 +47,7 @@ class ModelStore {
// Store the model of the given iteration. The model is acquired from Executor. If the current model count is already // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
// max_model_count_, the earliest model will be replaced. // max_model_count_, the earliest model will be replaced.
bool StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model); void StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model);
// Get model of the given iteration. // Get model of the given iteration.
std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration); std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);

View File

@ -49,7 +49,10 @@ bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
bool ParameterAggregator::ReInitForScaling() { bool ParameterAggregator::ReInitForScaling() {
auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(), auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
[](auto aggregation_kernel) { return !aggregation_kernel.first->ReInitForScaling(); }); [](auto aggregation_kernel) {
MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
return !aggregation_kernel.first->ReInitForScaling();
});
if (result != aggregation_kernel_parameters_.end()) { if (result != aggregation_kernel_parameters_.end()) {
MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed"; MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed";
return false; return false;
@ -65,6 +68,9 @@ bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_d
continue; continue;
} }
MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name], false);
MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(data.second.addr, false);
MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size
<< ". Source size: " << data.second.size; << ". Source size: " << data.second.size;
int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size); int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size);
@ -228,9 +234,10 @@ bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
template <typename K> template <typename K>
bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode, bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode,
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
std::shared_ptr<MemoryRegister> memory_register) { const std::shared_ptr<MemoryRegister> &memory_register) {
MS_EXCEPTION_IF_NULL(server_kernel); MS_EXCEPTION_IF_NULL(server_kernel);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(memory_register);
const std::vector<std::string> &input_names = server_kernel->input_names(); const std::vector<std::string> &input_names = server_kernel->input_names();
const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList(); const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList();
@ -272,8 +279,8 @@ bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode,
return true; return true;
} }
bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel, bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
const std::shared_ptr<MemoryRegister> memory_register) { const std::shared_ptr<MemoryRegister> &memory_register) {
MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false); MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false); MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
KernelParams aggr_params = {}; KernelParams aggr_params = {};
@ -295,8 +302,9 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<
return true; return true;
} }
bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel, bool ParameterAggregator::GenerateOptimizerKernelParams(
const std::shared_ptr<MemoryRegister> memory_register) { const std::shared_ptr<kernel::OptimizerKernel> &optimizer_kernel,
const std::shared_ptr<MemoryRegister> &memory_register) {
MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false); MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false); MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
KernelParams optimizer_params = {}; KernelParams optimizer_params = {};
@ -335,12 +343,12 @@ std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const C
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel, template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
const CNodePtr &cnode, const CNodePtr &cnode,
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
std::shared_ptr<MemoryRegister> memory_register); const std::shared_ptr<MemoryRegister> &memory_register);
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel, template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel,
const CNodePtr &cnode, const CNodePtr &cnode,
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
std::shared_ptr<MemoryRegister> memory_register); const std::shared_ptr<MemoryRegister> &memory_register);
} // namespace server } // namespace server
} // namespace fl } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -105,14 +105,14 @@ class ParameterAggregator {
// momentum, etc. // momentum, etc.
template <typename K> template <typename K>
bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
std::shared_ptr<MemoryRegister> memory_register); const std::shared_ptr<MemoryRegister> &memory_register);
// Generate kernel parameters for aggregation/optimizer kernels. All the parameters is registered and stored in // Generate kernel parameters for aggregation/optimizer kernels. All the parameters is registered and stored in
// memory_register. // memory_register.
bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel, bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
const std::shared_ptr<MemoryRegister> memory_register); const std::shared_ptr<MemoryRegister> &memory_register);
bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optim_kernel, bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> &optim_kernel,
const std::shared_ptr<MemoryRegister> memory_register); const std::shared_ptr<MemoryRegister> &memory_register);
// The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user // The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
// configuration, etc. // configuration, etc.

View File

@ -40,8 +40,10 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
communicator_ = communicator; communicator_ = communicator;
// Register callback for round kernel. // Register callback for round kernel.
communicator_->RegisterMsgCallBack( communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
name_, [&](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); }); MS_ERROR_IF_NULL_WO_RET_VAL(message);
LaunchRoundKernel(message);
});
// Callback when the iteration is finished. // Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void { finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
@ -50,10 +52,14 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
}; };
// Callback for finalizing the server. This can only be called once. // Callback for finalizing the server. This can only be called once.
finalize_cb_ = [&](void) -> void { (void)communicator_->Stop(); }; finalize_cb_ = [&](void) -> void {
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
(void)communicator_->Stop();
};
if (check_timeout_) { if (check_timeout_) {
iter_timer_ = std::make_shared<IterationTimer>(); iter_timer_ = std::make_shared<IterationTimer>();
MS_EXCEPTION_IF_NULL(iter_timer_);
// 1.Set the timeout callback for the timer. // 1.Set the timeout callback for the timer.
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void { iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
@ -63,6 +69,7 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
// 2.Stopping timer callback which will be set to the round kernel. // 2.Stopping timer callback which will be set to the round kernel.
stop_timer_cb_ = [&](void) -> void { stop_timer_cb_ = [&](void) -> void {
MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer."; MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
iter_timer_->Stop(); iter_timer_->Stop();
}; };
@ -90,10 +97,7 @@ bool Round::ReInitForScaling(uint32_t server_num) {
{first_count_handler, last_count_handler}); {first_count_handler, last_count_handler});
} }
if (kernel_ == nullptr) { MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
MS_LOG(WARNING) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
return false;
}
kernel_->InitKernel(threshold_count_); kernel_->InitKernel(threshold_count_);
return true; return true;
} }
@ -108,6 +112,8 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) { void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message); MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
// If the server is still in the process of scaling, refuse the request. // If the server is still in the process of scaling, refuse the request.
if (Server::GetInstance().IsSafeMode()) { if (Server::GetInstance().IsSafeMode()) {
MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later."; MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
@ -149,7 +155,10 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
return; return;
} }
void Round::Reset() { (void)kernel_->Reset(); } void Round::Reset() {
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
(void)kernel_->Reset();
}
const std::string &Round::name() const { return name_; } const std::string &Round::name() const { return name_; }
@ -160,6 +169,9 @@ bool Round::check_timeout() const { return check_timeout_; }
size_t Round::time_window() const { return time_window_; } size_t Round::time_window() const { return time_window_; }
void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered."; MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
// The timer starts only after the first count event is triggered by DistributedCountService. // The timer starts only after the first count event is triggered by DistributedCountService.
if (check_timeout_) { if (check_timeout_) {
@ -172,6 +184,9 @@ void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &m
} }
void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
MS_LOG(INFO) << "Round " << name_ << " last count event is triggered."; MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
// Same as the first count event, the timer must be stopped by DistributedCountService. // Same as the first count event, the timer must be stopped by DistributedCountService.
if (check_timeout_) { if (check_timeout_) {

View File

@ -86,7 +86,11 @@ void Server::Run() {
// Wait communicators to stop so the main thread is blocked. // Wait communicators to stop so the main thread is blocked.
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Join(); }); [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator->Join();
});
MS_EXCEPTION_IF_NULL(communicator_with_server_);
communicator_with_server_->Join(); communicator_with_server_->Join();
MsException::Instance().CheckException(); MsException::Instance().CheckException();
return; return;
@ -151,6 +155,7 @@ bool Server::InitCommunicatorWithWorker() {
return false; return false;
} }
if (use_tcp_) { if (use_tcp_) {
MS_EXCEPTION_IF_NULL(communicator_with_server_);
auto tcp_comm = communicator_with_server_; auto tcp_comm = communicator_with_server_;
MS_EXCEPTION_IF_NULL(tcp_comm); MS_EXCEPTION_IF_NULL(tcp_comm);
communicators_with_worker_.push_back(tcp_comm); communicators_with_worker_.push_back(tcp_comm);
@ -201,21 +206,32 @@ void Server::InitIteration() {
std::shared_ptr<Round> exchange_keys_round = std::shared_ptr<Round> exchange_keys_round =
std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_); std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
MS_EXCEPTION_IF_NULL(exchange_keys_round);
iteration_->AddRound(exchange_keys_round); iteration_->AddRound(exchange_keys_round);
std::shared_ptr<Round> get_keys_round = std::shared_ptr<Round> get_keys_round =
std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_); std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
MS_EXCEPTION_IF_NULL(get_keys_round);
iteration_->AddRound(get_keys_round); iteration_->AddRound(get_keys_round);
std::shared_ptr<Round> share_secrets_round = std::shared_ptr<Round> share_secrets_round =
std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_); std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
MS_EXCEPTION_IF_NULL(share_secrets_round);
iteration_->AddRound(share_secrets_round); iteration_->AddRound(share_secrets_round);
std::shared_ptr<Round> get_secrets_round = std::shared_ptr<Round> get_secrets_round =
std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_); std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
MS_EXCEPTION_IF_NULL(get_secrets_round);
iteration_->AddRound(get_secrets_round); iteration_->AddRound(get_secrets_round);
std::shared_ptr<Round> get_clientlist_round = std::shared_ptr<Round> get_clientlist_round =
std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_); std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
MS_EXCEPTION_IF_NULL(get_clientlist_round);
iteration_->AddRound(get_clientlist_round); iteration_->AddRound(get_clientlist_round);
std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>( std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
"reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_); "reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
MS_EXCEPTION_IF_NULL(reconstruct_secrets_round);
iteration_->AddRound(reconstruct_secrets_round); iteration_->AddRound(reconstruct_secrets_round);
MS_LOG(INFO) << "Cipher rounds has been added."; MS_LOG(INFO) << "Cipher rounds has been added.";
} }
@ -253,8 +269,16 @@ void Server::InitCipher() {
mindspore::armour::CipherPublicPara param; mindspore::armour::CipherPublicPara param;
param.g = cipher_g; param.g = cipher_g;
param.t = cipher_t; param.t = cipher_t;
memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN); int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN); if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
ret = memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
param.dp_delta = dp_delta; param.dp_delta = dp_delta;
param.dp_eps = dp_eps; param.dp_eps = dp_eps;
param.dp_norm_clip = dp_norm_clip; param.dp_norm_clip = dp_norm_clip;
@ -268,6 +292,8 @@ void Server::InitCipher() {
void Server::RegisterCommCallbacks() { void Server::RegisterCommCallbacks() {
// The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
// rounds' callbacks. // rounds' callbacks.
MS_EXCEPTION_IF_NULL(server_node_);
MS_EXCEPTION_IF_NULL(iteration_);
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
MS_EXCEPTION_IF_NULL(tcp_comm); MS_EXCEPTION_IF_NULL(tcp_comm);
@ -302,9 +328,13 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC
communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() { communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed."; MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
safemode_ = true; safemode_ = true;
(void)std::for_each( (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
communicators_with_worker_.begin(), communicators_with_worker_.end(), [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); }); MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
(void)communicator->Stop();
});
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
(void)communicator_with_server_->Stop(); (void)communicator_with_server_->Stop();
}); });
@ -313,14 +343,19 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the " << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
"network building phase."; "network building phase.";
safemode_ = true; safemode_ = true;
(void)std::for_each( (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
communicators_with_worker_.begin(), communicators_with_worker_.end(), [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); }); MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
(void)communicator->Stop();
});
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
(void)communicator_with_server_->Stop(); (void)communicator_with_server_->Stop();
}); });
} }
void Server::InitExecutor() { void Server::InitExecutor() {
MS_EXCEPTION_IF_NULL(func_graph_);
if (executor_threshold_ == 0) { if (executor_threshold_ == 0) {
MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0."; MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
return; return;
@ -342,6 +377,7 @@ void Server::RegisterRoundKernel() {
} }
for (auto &round : rounds) { for (auto &round : rounds) {
MS_EXCEPTION_IF_NULL(round);
const std::string &name = round->name(); const std::string &name = round->name();
std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name); std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
if (round_kernel == nullptr) { if (round_kernel == nullptr) {
@ -357,12 +393,13 @@ void Server::RegisterRoundKernel() {
} }
void Server::StartCommunicator() { void Server::StartCommunicator() {
MS_EXCEPTION_IF_NULL(communicator_with_server_);
if (communicators_with_worker_.empty()) { if (communicators_with_worker_.empty()) {
MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty."; MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty.";
return; return;
} }
MS_EXCEPTION_IF_NULL(server_node_);
MS_EXCEPTION_IF_NULL(communicator_with_server_);
MS_LOG(INFO) << "Start communicator with server."; MS_LOG(INFO) << "Start communicator with server.";
if (!communicator_with_server_->Start()) { if (!communicator_with_server_->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with server failed."; MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
@ -376,6 +413,7 @@ void Server::StartCommunicator() {
MS_LOG(INFO) << "Start communicator with worker."; MS_LOG(INFO) << "Start communicator with worker.";
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
if (!communicator->Start()) { if (!communicator->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with worker failed."; MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
} }
@ -383,11 +421,13 @@ void Server::StartCommunicator() {
} }
void Server::ProcessBeforeScalingOut() { void Server::ProcessBeforeScalingOut() {
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
iteration_->ScalingBarrier(); iteration_->ScalingBarrier();
safemode_ = true; safemode_ = true;
} }
void Server::ProcessBeforeScalingIn() { void Server::ProcessBeforeScalingIn() {
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
iteration_->ScalingBarrier(); iteration_->ScalingBarrier();
safemode_ = true; safemode_ = true;
} }
@ -419,9 +459,13 @@ void Server::ProcessAfterScalingIn() {
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == UINT32_MAX) { if (server_node_->rank_id() == UINT32_MAX) {
MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting."; MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
(void)std::for_each( (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
communicators_with_worker_.begin(), communicators_with_worker_.end(), [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); }); MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
(void)communicator->Stop();
});
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
(void)communicator_with_server_->Stop(); (void)communicator_with_server_->Stop();
return; return;
} }

View File

@ -91,6 +91,7 @@ void FLWorker::Finalize() {
bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command, bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output) { std::shared_ptr<std::vector<unsigned char>> *output) {
MS_EXCEPTION_IF_NULL(data);
// If the worker is in safemode, do not communicate with server. // If the worker is in safemode, do not communicate with server.
while (safemode_.load()) { while (safemode_.load()) {
std::this_thread::yield(); std::this_thread::yield();
@ -169,6 +170,7 @@ std::string FLWorker::fl_name() const { return ps::kServerModeFL; }
std::string FLWorker::fl_id() const { return std::to_string(rank_id_); } std::string FLWorker::fl_id() const { return std::to_string(rank_id_); }
void FLWorker::InitializeFollowerScaler() { void FLWorker::InitializeFollowerScaler() {
MS_EXCEPTION_IF_NULL(worker_node_);
if (!worker_node_->InitFollowerScaler()) { if (!worker_node_->InitFollowerScaler()) {
MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed."; MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
return; return;
@ -225,10 +227,7 @@ void FLWorker::ProcessBeforeScalingIn() {
} }
void FLWorker::ProcessAfterScalingOut() { void FLWorker::ProcessAfterScalingOut() {
if (worker_node_ == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
return;
}
MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker."; MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
server_num_ = IntToUint(worker_node_->server_num()); server_num_ = IntToUint(worker_node_->server_num());
worker_num_ = IntToUint(worker_node_->worker_num()); worker_num_ = IntToUint(worker_node_->worker_num());
@ -239,10 +238,7 @@ void FLWorker::ProcessAfterScalingOut() {
} }
void FLWorker::ProcessAfterScalingIn() { void FLWorker::ProcessAfterScalingIn() {
if (worker_node_ == nullptr) { MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
return;
}
MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker."; MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
server_num_ = IntToUint(worker_node_->server_num()); server_num_ = IntToUint(worker_node_->server_num());
worker_num_ = IntToUint(worker_node_->worker_num()); worker_num_ = IntToUint(worker_node_->worker_num());