forked from mindspore-Ecosystem/mindspore
Fix code review.
This commit is contained in:
parent
79af42153d
commit
a2b8198024
|
@ -121,7 +121,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
|
|||
MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
|
||||
if (local_rank_ == counting_server_rank_) {
|
||||
if (global_threshold_count_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "Counter for " << name << " is not set.";
|
||||
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -138,6 +138,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
|
|||
return false;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(query_cnt_enough_rsp_msg, false);
|
||||
CountReachThresholdResponse count_reach_threshold_rsp;
|
||||
(void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
|
||||
SizeToInt(query_cnt_enough_rsp_msg->size()));
|
||||
|
@ -173,11 +174,7 @@ bool DistributedCountService::ReInitForScaling() {
|
|||
}
|
||||
|
||||
void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
CountRequest report_count_req;
|
||||
(void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const std::string &name = report_count_req.name();
|
||||
|
@ -235,11 +232,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
|
|||
|
||||
void DistributedCountService::HandleCountReachThresholdRequest(
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
CountReachThresholdRequest count_reach_threshold_req;
|
||||
(void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const std::string &name = count_reach_threshold_req.name();
|
||||
|
@ -261,11 +254,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(
|
|||
}
|
||||
|
||||
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
// Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
|
||||
// callbacks.
|
||||
std::string couter_event_rsp_msg = "success";
|
||||
|
@ -279,6 +268,10 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
|
|||
const auto &type = counter_event.type();
|
||||
const auto &name = counter_event.name();
|
||||
|
||||
if (counter_handlers_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
|
||||
if (type == CounterEventType::FIRST_CNT) {
|
||||
counter_handlers_[name].first_count_handler(message);
|
||||
|
@ -292,6 +285,11 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
|
|||
}
|
||||
|
||||
bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) {
|
||||
if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "The counter of " << name << " is not registered.";
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
|
||||
<< ", threshold count is " << global_threshold_count_[name];
|
||||
// The threshold count may be 1 so the first and last count event should be both activated.
|
||||
|
@ -324,6 +322,11 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
|
|||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (counter_handlers_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
|
||||
return false;
|
||||
}
|
||||
// Leader server directly calls the callback.
|
||||
counter_handlers_[name].first_count_handler(nullptr);
|
||||
return true;
|
||||
|
@ -345,6 +348,11 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
|
|||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (counter_handlers_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
|
||||
return false;
|
||||
}
|
||||
// Leader server directly calls the callback.
|
||||
counter_handlers_[name].last_count_handler(nullptr);
|
||||
return true;
|
||||
|
|
|
@ -181,11 +181,7 @@ void DistributedMetadataStore::InitHashRing() {
|
|||
}
|
||||
|
||||
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
PBMetadataWithName meta_with_name;
|
||||
(void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const std::string &name = meta_with_name.name();
|
||||
|
@ -206,17 +202,17 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
|
|||
}
|
||||
|
||||
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
GetMetadataRequest get_metadata_req;
|
||||
(void)get_metadata_req.ParseFromArray(message->data(), message->len());
|
||||
const std::string &name = get_metadata_req.name();
|
||||
MS_LOG(INFO) << "Getting metadata for " << name;
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
if (metadata_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
|
||||
return;
|
||||
}
|
||||
PBMetadata stored_meta = metadata_[name];
|
||||
std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
|
||||
if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) {
|
||||
|
@ -228,6 +224,10 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
|
|||
|
||||
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
if (metadata_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
|
||||
return false;
|
||||
}
|
||||
if (meta.has_device_meta()) {
|
||||
auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
|
||||
auto &device_meta_fl_id = meta.device_meta().fl_id();
|
||||
|
|
|
@ -158,17 +158,10 @@ bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_ma
|
|||
auto ¶m_aggr = param_aggrs_[param_name];
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
|
||||
AddressPtr old_weight = param_aggr->GetWeight();
|
||||
if (old_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const Address &new_weight = trainable_param.second;
|
||||
if (new_weight.addr == nullptr) {
|
||||
MS_LOG(ERROR) << "The new weight is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
|
||||
int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
|
@ -295,6 +288,7 @@ std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
|
|||
MS_EXCEPTION_IF_NULL(weight_node);
|
||||
if (!weight_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
|
||||
return "";
|
||||
}
|
||||
return weight_node->fullname_with_scope();
|
||||
}
|
||||
|
@ -309,7 +303,7 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
|
|||
continue;
|
||||
}
|
||||
if (param_aggrs_.count(param_name) != 0) {
|
||||
MS_LOG(WARNING) << param_name << " already has its control flow.";
|
||||
MS_LOG(WARNING) << param_name << " already has parameter aggregator registered.";
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -66,9 +66,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
|
|||
(void)std::for_each(communicators.begin(), communicators.end(),
|
||||
[&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||
for (auto &round : rounds_) {
|
||||
if (round == nullptr) {
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(round);
|
||||
round->Initialize(communicator, timeout_cb, finish_iteration_cb);
|
||||
}
|
||||
});
|
||||
|
@ -76,6 +74,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
|
|||
// The time window for one iteration, which will be used in some round kernels.
|
||||
size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0),
|
||||
[](size_t total, const std::shared_ptr<Round> &round) {
|
||||
MS_EXCEPTION_IF_NULL(round);
|
||||
return round->check_timeout() ? total + round->time_window() : total;
|
||||
});
|
||||
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
|
||||
|
@ -172,11 +171,8 @@ bool Iteration::SyncIteration(uint32_t rank) {
|
|||
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
|
||||
return false;
|
||||
}
|
||||
if (sync_iter_rsp_msg == nullptr) {
|
||||
MS_LOG(ERROR) << "Response from server 0 is empty.";
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(sync_iter_rsp_msg, false);
|
||||
SyncIterationResponse sync_iter_rsp;
|
||||
(void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
|
||||
iteration_num_ = sync_iter_rsp.iteration();
|
||||
|
@ -372,16 +368,17 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
|||
if (is_iteration_valid) {
|
||||
// Store the model which is successfully aggregated for this iteration.
|
||||
const auto &model = Executor::GetInstance().GetModel();
|
||||
(void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
|
||||
} else {
|
||||
// Store last iteration's model because this iteration is considered as invalid.
|
||||
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
|
||||
(void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
|
||||
}
|
||||
|
||||
for (auto &round : rounds_) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(round);
|
||||
round->Reset();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,10 +36,12 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
|
|||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
}
|
||||
});
|
||||
monitor_thread_.detach();
|
||||
}
|
||||
|
||||
void IterationTimer::Stop() { running_ = false; }
|
||||
void IterationTimer::Stop() {
|
||||
running_ = false;
|
||||
monitor_thread_.join();
|
||||
}
|
||||
|
||||
void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {
|
||||
timeout_callback_ = timeout_cb;
|
||||
|
|
|
@ -82,6 +82,10 @@ class FedAvgKernel : public AggregationKernel {
|
|||
}
|
||||
};
|
||||
last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_->addr);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_->addr);
|
||||
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
|
||||
size_t weight_size = weight_addr_->size;
|
||||
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
|
||||
|
|
|
@ -35,33 +35,27 @@ void ModelStore::Initialize(uint32_t max_count) {
|
|||
model_size_ = ComputeModelSize();
|
||||
}
|
||||
|
||||
bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
|
||||
void ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
if (iteration_to_model_.count(iteration) != 0) {
|
||||
MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
|
||||
return false;
|
||||
return;
|
||||
}
|
||||
if (new_model.empty()) {
|
||||
MS_LOG(ERROR) << "Model feature map is empty.";
|
||||
return false;
|
||||
return;
|
||||
}
|
||||
|
||||
std::shared_ptr<MemoryRegister> memory_register = nullptr;
|
||||
if (iteration_to_model_.size() < max_model_count_) {
|
||||
// If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model.
|
||||
memory_register = AssignNewModelMemory();
|
||||
if (memory_register == nullptr) {
|
||||
MS_LOG(ERROR) << "Memory for the new model is nullptr.";
|
||||
return false;
|
||||
}
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
|
||||
iteration_to_model_[iteration] = memory_register;
|
||||
} else {
|
||||
// If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
|
||||
memory_register = iteration_to_model_.begin()->second;
|
||||
if (memory_register == nullptr) {
|
||||
MS_LOG(ERROR) << "Earliest model is nullptr.";
|
||||
return false;
|
||||
}
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
|
||||
(void)iteration_to_model_.erase(iteration_to_model_.begin());
|
||||
}
|
||||
|
||||
|
@ -74,6 +68,10 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
|
|||
continue;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(stored_model[weight_name]);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(stored_model[weight_name]->addr);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(weight.second);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(weight.second->addr);
|
||||
void *dst_addr = stored_model[weight_name]->addr;
|
||||
size_t dst_size = stored_model[weight_name]->size;
|
||||
void *src_addr = weight.second->addr;
|
||||
|
@ -81,11 +79,11 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
|
|||
int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
iteration_to_model_[iteration] = memory_register;
|
||||
return true;
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
|
||||
|
|
|
@ -47,7 +47,7 @@ class ModelStore {
|
|||
|
||||
// Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
|
||||
// max_model_count_, the earliest model will be replaced.
|
||||
bool StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model);
|
||||
void StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model);
|
||||
|
||||
// Get model of the given iteration.
|
||||
std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);
|
||||
|
|
|
@ -49,7 +49,10 @@ bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
|
|||
|
||||
bool ParameterAggregator::ReInitForScaling() {
|
||||
auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
|
||||
[](auto aggregation_kernel) { return !aggregation_kernel.first->ReInitForScaling(); });
|
||||
[](auto aggregation_kernel) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(aggregation_kernel.first, true);
|
||||
return !aggregation_kernel.first->ReInitForScaling();
|
||||
});
|
||||
if (result != aggregation_kernel_parameters_.end()) {
|
||||
MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed";
|
||||
return false;
|
||||
|
@ -65,6 +68,9 @@ bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_d
|
|||
continue;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name], false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(name_to_addr[name]->addr, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(data.second.addr, false);
|
||||
MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size
|
||||
<< ". Source size: " << data.second.size;
|
||||
int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size);
|
||||
|
@ -228,9 +234,10 @@ bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
|
|||
template <typename K>
|
||||
bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode,
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
std::shared_ptr<MemoryRegister> memory_register) {
|
||||
const std::shared_ptr<MemoryRegister> &memory_register) {
|
||||
MS_EXCEPTION_IF_NULL(server_kernel);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(memory_register);
|
||||
|
||||
const std::vector<std::string> &input_names = server_kernel->input_names();
|
||||
const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList();
|
||||
|
@ -272,8 +279,8 @@ bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
|
||||
const std::shared_ptr<MemoryRegister> memory_register) {
|
||||
bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
|
||||
const std::shared_ptr<MemoryRegister> &memory_register) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
|
||||
KernelParams aggr_params = {};
|
||||
|
@ -295,8 +302,9 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel,
|
||||
const std::shared_ptr<MemoryRegister> memory_register) {
|
||||
bool ParameterAggregator::GenerateOptimizerKernelParams(
|
||||
const std::shared_ptr<kernel::OptimizerKernel> &optimizer_kernel,
|
||||
const std::shared_ptr<MemoryRegister> &memory_register) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
|
||||
KernelParams optimizer_params = {};
|
||||
|
@ -335,12 +343,12 @@ std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const C
|
|||
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
|
||||
const CNodePtr &cnode,
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
std::shared_ptr<MemoryRegister> memory_register);
|
||||
const std::shared_ptr<MemoryRegister> &memory_register);
|
||||
|
||||
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel,
|
||||
const CNodePtr &cnode,
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
std::shared_ptr<MemoryRegister> memory_register);
|
||||
const std::shared_ptr<MemoryRegister> &memory_register);
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -105,14 +105,14 @@ class ParameterAggregator {
|
|||
// momentum, etc.
|
||||
template <typename K>
|
||||
bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
std::shared_ptr<MemoryRegister> memory_register);
|
||||
const std::shared_ptr<MemoryRegister> &memory_register);
|
||||
|
||||
// Generate kernel parameters for aggregation/optimizer kernels. All the parameters is registered and stored in
|
||||
// memory_register.
|
||||
bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
|
||||
const std::shared_ptr<MemoryRegister> memory_register);
|
||||
bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optim_kernel,
|
||||
const std::shared_ptr<MemoryRegister> memory_register);
|
||||
bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
|
||||
const std::shared_ptr<MemoryRegister> &memory_register);
|
||||
bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> &optim_kernel,
|
||||
const std::shared_ptr<MemoryRegister> &memory_register);
|
||||
|
||||
// The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
|
||||
// configuration, etc.
|
||||
|
|
|
@ -40,8 +40,10 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
|
|||
communicator_ = communicator;
|
||||
|
||||
// Register callback for round kernel.
|
||||
communicator_->RegisterMsgCallBack(
|
||||
name_, [&](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); });
|
||||
communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
LaunchRoundKernel(message);
|
||||
});
|
||||
|
||||
// Callback when the iteration is finished.
|
||||
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
|
||||
|
@ -50,10 +52,14 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
|
|||
};
|
||||
|
||||
// Callback for finalizing the server. This can only be called once.
|
||||
finalize_cb_ = [&](void) -> void { (void)communicator_->Stop(); };
|
||||
finalize_cb_ = [&](void) -> void {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
|
||||
(void)communicator_->Stop();
|
||||
};
|
||||
|
||||
if (check_timeout_) {
|
||||
iter_timer_ = std::make_shared<IterationTimer>();
|
||||
MS_EXCEPTION_IF_NULL(iter_timer_);
|
||||
|
||||
// 1.Set the timeout callback for the timer.
|
||||
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
|
||||
|
@ -63,6 +69,7 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
|
|||
|
||||
// 2.Stopping timer callback which will be set to the round kernel.
|
||||
stop_timer_cb_ = [&](void) -> void {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
|
||||
MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
|
||||
iter_timer_->Stop();
|
||||
};
|
||||
|
@ -90,10 +97,7 @@ bool Round::ReInitForScaling(uint32_t server_num) {
|
|||
{first_count_handler, last_count_handler});
|
||||
}
|
||||
|
||||
if (kernel_ == nullptr) {
|
||||
MS_LOG(WARNING) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
|
||||
return false;
|
||||
}
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
|
||||
kernel_->InitKernel(threshold_count_);
|
||||
return true;
|
||||
}
|
||||
|
@ -108,6 +112,8 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
|
|||
|
||||
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
|
||||
// If the server is still in the process of scaling, refuse the request.
|
||||
if (Server::GetInstance().IsSafeMode()) {
|
||||
MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
|
||||
|
@ -149,7 +155,10 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
|
|||
return;
|
||||
}
|
||||
|
||||
void Round::Reset() { (void)kernel_->Reset(); }
|
||||
void Round::Reset() {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
|
||||
(void)kernel_->Reset();
|
||||
}
|
||||
|
||||
const std::string &Round::name() const { return name_; }
|
||||
|
||||
|
@ -160,6 +169,9 @@ bool Round::check_timeout() const { return check_timeout_; }
|
|||
size_t Round::time_window() const { return time_window_; }
|
||||
|
||||
void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
|
||||
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
|
||||
// The timer starts only after the first count event is triggered by DistributedCountService.
|
||||
if (check_timeout_) {
|
||||
|
@ -172,6 +184,9 @@ void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &m
|
|||
}
|
||||
|
||||
void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
|
||||
MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
|
||||
// Same as the first count event, the timer must be stopped by DistributedCountService.
|
||||
if (check_timeout_) {
|
||||
|
|
|
@ -86,7 +86,11 @@ void Server::Run() {
|
|||
|
||||
// Wait communicators to stop so the main thread is blocked.
|
||||
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Join(); });
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||
MS_EXCEPTION_IF_NULL(communicator);
|
||||
communicator->Join();
|
||||
});
|
||||
MS_EXCEPTION_IF_NULL(communicator_with_server_);
|
||||
communicator_with_server_->Join();
|
||||
MsException::Instance().CheckException();
|
||||
return;
|
||||
|
@ -151,6 +155,7 @@ bool Server::InitCommunicatorWithWorker() {
|
|||
return false;
|
||||
}
|
||||
if (use_tcp_) {
|
||||
MS_EXCEPTION_IF_NULL(communicator_with_server_);
|
||||
auto tcp_comm = communicator_with_server_;
|
||||
MS_EXCEPTION_IF_NULL(tcp_comm);
|
||||
communicators_with_worker_.push_back(tcp_comm);
|
||||
|
@ -201,21 +206,32 @@ void Server::InitIteration() {
|
|||
|
||||
std::shared_ptr<Round> exchange_keys_round =
|
||||
std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(exchange_keys_round);
|
||||
iteration_->AddRound(exchange_keys_round);
|
||||
|
||||
std::shared_ptr<Round> get_keys_round =
|
||||
std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(get_keys_round);
|
||||
iteration_->AddRound(get_keys_round);
|
||||
|
||||
std::shared_ptr<Round> share_secrets_round =
|
||||
std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(share_secrets_round);
|
||||
iteration_->AddRound(share_secrets_round);
|
||||
|
||||
std::shared_ptr<Round> get_secrets_round =
|
||||
std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(get_secrets_round);
|
||||
iteration_->AddRound(get_secrets_round);
|
||||
|
||||
std::shared_ptr<Round> get_clientlist_round =
|
||||
std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(get_clientlist_round);
|
||||
iteration_->AddRound(get_clientlist_round);
|
||||
|
||||
std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
|
||||
"reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(reconstruct_secrets_round);
|
||||
iteration_->AddRound(reconstruct_secrets_round);
|
||||
MS_LOG(INFO) << "Cipher rounds has been added.";
|
||||
}
|
||||
|
@ -253,8 +269,16 @@ void Server::InitCipher() {
|
|||
mindspore::armour::CipherPublicPara param;
|
||||
param.g = cipher_g;
|
||||
param.t = cipher_t;
|
||||
memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
|
||||
memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
|
||||
int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
ret = memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
param.dp_delta = dp_delta;
|
||||
param.dp_eps = dp_eps;
|
||||
param.dp_norm_clip = dp_norm_clip;
|
||||
|
@ -268,6 +292,8 @@ void Server::InitCipher() {
|
|||
void Server::RegisterCommCallbacks() {
|
||||
// The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
|
||||
// rounds' callbacks.
|
||||
MS_EXCEPTION_IF_NULL(server_node_);
|
||||
MS_EXCEPTION_IF_NULL(iteration_);
|
||||
|
||||
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
||||
MS_EXCEPTION_IF_NULL(tcp_comm);
|
||||
|
@ -302,9 +328,13 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC
|
|||
communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
|
||||
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
|
||||
safemode_ = true;
|
||||
(void)std::for_each(
|
||||
communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
|
||||
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
|
||||
(void)communicator->Stop();
|
||||
});
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
|
||||
(void)communicator_with_server_->Stop();
|
||||
});
|
||||
|
||||
|
@ -313,14 +343,19 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC
|
|||
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
|
||||
"network building phase.";
|
||||
safemode_ = true;
|
||||
(void)std::for_each(
|
||||
communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
|
||||
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
|
||||
(void)communicator->Stop();
|
||||
});
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
|
||||
(void)communicator_with_server_->Stop();
|
||||
});
|
||||
}
|
||||
|
||||
void Server::InitExecutor() {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
if (executor_threshold_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
|
||||
return;
|
||||
|
@ -342,6 +377,7 @@ void Server::RegisterRoundKernel() {
|
|||
}
|
||||
|
||||
for (auto &round : rounds) {
|
||||
MS_EXCEPTION_IF_NULL(round);
|
||||
const std::string &name = round->name();
|
||||
std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
|
||||
if (round_kernel == nullptr) {
|
||||
|
@ -357,12 +393,13 @@ void Server::RegisterRoundKernel() {
|
|||
}
|
||||
|
||||
void Server::StartCommunicator() {
|
||||
MS_EXCEPTION_IF_NULL(communicator_with_server_);
|
||||
if (communicators_with_worker_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(server_node_);
|
||||
MS_EXCEPTION_IF_NULL(communicator_with_server_);
|
||||
MS_LOG(INFO) << "Start communicator with server.";
|
||||
if (!communicator_with_server_->Start()) {
|
||||
MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
|
||||
|
@ -376,6 +413,7 @@ void Server::StartCommunicator() {
|
|||
MS_LOG(INFO) << "Start communicator with worker.";
|
||||
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
|
||||
if (!communicator->Start()) {
|
||||
MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
|
||||
}
|
||||
|
@ -383,11 +421,13 @@ void Server::StartCommunicator() {
|
|||
}
|
||||
|
||||
void Server::ProcessBeforeScalingOut() {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
||||
iteration_->ScalingBarrier();
|
||||
safemode_ = true;
|
||||
}
|
||||
|
||||
void Server::ProcessBeforeScalingIn() {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
||||
iteration_->ScalingBarrier();
|
||||
safemode_ = true;
|
||||
}
|
||||
|
@ -419,9 +459,13 @@ void Server::ProcessAfterScalingIn() {
|
|||
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
|
||||
if (server_node_->rank_id() == UINT32_MAX) {
|
||||
MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
|
||||
(void)std::for_each(
|
||||
communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
|
||||
(void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
|
||||
(void)communicator->Stop();
|
||||
});
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
|
||||
(void)communicator_with_server_->Stop();
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -91,6 +91,7 @@ void FLWorker::Finalize() {
|
|||
|
||||
bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
|
||||
std::shared_ptr<std::vector<unsigned char>> *output) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
// If the worker is in safemode, do not communicate with server.
|
||||
while (safemode_.load()) {
|
||||
std::this_thread::yield();
|
||||
|
@ -169,6 +170,7 @@ std::string FLWorker::fl_name() const { return ps::kServerModeFL; }
|
|||
std::string FLWorker::fl_id() const { return std::to_string(rank_id_); }
|
||||
|
||||
void FLWorker::InitializeFollowerScaler() {
|
||||
MS_EXCEPTION_IF_NULL(worker_node_);
|
||||
if (!worker_node_->InitFollowerScaler()) {
|
||||
MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
|
||||
return;
|
||||
|
@ -225,10 +227,7 @@ void FLWorker::ProcessBeforeScalingIn() {
|
|||
}
|
||||
|
||||
void FLWorker::ProcessAfterScalingOut() {
|
||||
if (worker_node_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
|
||||
MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
|
||||
server_num_ = IntToUint(worker_node_->server_num());
|
||||
worker_num_ = IntToUint(worker_node_->worker_num());
|
||||
|
@ -239,10 +238,7 @@ void FLWorker::ProcessAfterScalingOut() {
|
|||
}
|
||||
|
||||
void FLWorker::ProcessAfterScalingIn() {
|
||||
if (worker_node_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
|
||||
MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
|
||||
server_num_ = IntToUint(worker_node_->server_num());
|
||||
worker_num_ = IntToUint(worker_node_->worker_num());
|
||||
|
|
Loading…
Reference in New Issue