diff --git a/mindspore/ccsrc/fl/server/distributed_count_service.cc b/mindspore/ccsrc/fl/server/distributed_count_service.cc
index 766e5a4013f..e3ceb8ae7a9 100644
--- a/mindspore/ccsrc/fl/server/distributed_count_service.cc
+++ b/mindspore/ccsrc/fl/server/distributed_count_service.cc
@@ -121,7 +121,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
   MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
   if (local_rank_ == counting_server_rank_) {
     if (global_threshold_count_.count(name) == 0) {
-      MS_LOG(ERROR) << "Counter for " << name << " is not set.";
+      MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
       return false;
     }
@@ -138,6 +138,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
       return false;
     }
 
+    MS_ERROR_IF_NULL_W_RET_VAL(query_cnt_enough_rsp_msg, false);
     CountReachThresholdResponse count_reach_threshold_rsp;
     (void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
                                                    SizeToInt(query_cnt_enough_rsp_msg->size()));
@@ -173,11 +174,7 @@ bool DistributedCountService::ReInitForScaling() {
 }
 
 void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
-  if (message == nullptr) {
-    MS_LOG(ERROR) << "Message is nullptr.";
-    return;
-  }
-
+  MS_ERROR_IF_NULL_WO_RET_VAL(message);
   CountRequest report_count_req;
   (void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   const std::string &name = report_count_req.name();
@@ -235,11 +232,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core
 }
 
 void DistributedCountService::HandleCountReachThresholdRequest(
   const std::shared_ptr<ps::core::MessageHandler> &message) {
-  if (message == nullptr) {
-    MS_LOG(ERROR) << "Message is nullptr.";
-    return;
-  }
-
+  MS_ERROR_IF_NULL_WO_RET_VAL(message);
   CountReachThresholdRequest count_reach_threshold_req;
   (void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   const std::string &name = count_reach_threshold_req.name();
@@ -261,11 +254,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(
 }
 
 void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
-  if (message == nullptr) {
-    MS_LOG(ERROR) << "Message is nullptr.";
-    return;
-  }
-
+  MS_ERROR_IF_NULL_WO_RET_VAL(message);
   // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
   // callbacks.
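The recurring change in this file, and throughout the patch, swaps hand-rolled `if (ptr == nullptr)` blocks for the `MS_ERROR_IF_NULL_*` macros. As a rough guide only, here is a minimal sketch of what such macros presumably expand to; the real definitions live in MindSpore's utility headers and may differ in wording:

// Hypothetical expansions, for illustration -- not copied from MindSpore.
#define MS_ERROR_IF_NULL_WO_RET_VAL(ptr)                          \
  do {                                                            \
    if ((ptr) == nullptr) {                                       \
      MS_LOG(ERROR) << "The pointer " << #ptr << " is null.";     \
      return;                                                     \
    }                                                             \
  } while (false)

#define MS_ERROR_IF_NULL_W_RET_VAL(ptr, val)                      \
  do {                                                            \
    if ((ptr) == nullptr) {                                       \
      MS_LOG(ERROR) << "The pointer " << #ptr << " is null.";     \
      return (val);                                               \
    }                                                             \
  } while (false)

`_WO_RET_VAL` suits void functions, `_W_RET_VAL` returns the supplied value, and `MS_EXCEPTION_IF_NULL` (used elsewhere in this patch) throws instead of returning.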
   std::string couter_event_rsp_msg = "success";
diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc
--- a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc
+++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc
 void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
-  if (message == nullptr) {
-    MS_LOG(ERROR) << "Message is nullptr.";
-    return;
-  }
-
+  MS_ERROR_IF_NULL_WO_RET_VAL(message);
   PBMetadataWithName meta_with_name;
   (void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
   const std::string &name = meta_with_name.name();
@@ -206,17 +202,17 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
 }
 
 void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
-  if (message == nullptr) {
-    MS_LOG(ERROR) << "Message is nullptr.";
-    return;
-  }
-
+  MS_ERROR_IF_NULL_WO_RET_VAL(message);
   GetMetadataRequest get_metadata_req;
   (void)get_metadata_req.ParseFromArray(message->data(), message->len());
   const std::string &name = get_metadata_req.name();
   MS_LOG(INFO) << "Getting metadata for " << name;
 
   std::unique_lock<std::mutex> lock(mutex_[name]);
+  if (metadata_.count(name) == 0) {
+    MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
+    return;
+  }
   PBMetadata stored_meta = metadata_[name];
   std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
   if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) {
@@ -228,6 +224,10 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
   std::unique_lock<std::mutex> lock(mutex_[name]);
+  if (metadata_.count(name) == 0) {
+    MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
+    return false;
+  }
   if (meta.has_device_meta()) {
     auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
     auto &device_meta_fl_id = meta.device_meta().fl_id();
diff --git a/mindspore/ccsrc/fl/server/executor.cc b/mindspore/ccsrc/fl/server/executor.cc
index d5233a6e612..460b8dba502 100644
--- a/mindspore/ccsrc/fl/server/executor.cc
+++ b/mindspore/ccsrc/fl/server/executor.cc
@@ -158,17 +158,10 @@ bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_ma
   auto &param_aggr = param_aggrs_[param_name];
   MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
   AddressPtr old_weight = param_aggr->GetWeight();
-  if (old_weight == nullptr) {
-    MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
-    return false;
-  }
-
   const Address &new_weight = trainable_param.second;
-  if (new_weight.addr == nullptr) {
-    MS_LOG(ERROR) << "The new weight is nullptr.";
-    return false;
-  }
-
+  MS_ERROR_IF_NULL_W_RET_VAL(old_weight, false);
+  MS_ERROR_IF_NULL_W_RET_VAL(old_weight->addr, false);
+  MS_ERROR_IF_NULL_W_RET_VAL(new_weight.addr, false);
   int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
   if (ret != 0) {
     MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@@ -295,6 +288,7 @@ std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(weight_node);
   if (!weight_node->isa<Parameter>()) {
     MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
+    return "";
   }
   return weight_node->fullname_with_scope();
 }
@@ -309,7 +303,7 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
       continue;
     }
     if (param_aggrs_.count(param_name) != 0) {
-      MS_LOG(WARNING) << param_name << " already has its control flow.";
+      MS_LOG(WARNING) << param_name << " already has a parameter aggregator registered.";
       continue;
     }
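Both files guard `std::map` lookups with a `count()` check before touching `metadata_[name]` or `param_aggrs_[param_name]`. The reason is that `std::map::operator[]` silently default-constructs a value for a missing key, so an unregistered name would be materialized as empty metadata instead of being reported. A self-contained illustration of the pitfall:

#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, std::string> metadata;
  // operator[] inserts a default-constructed value for a missing key:
  std::string v = metadata["not_registered"];   // metadata now holds 1 entry
  std::cout << metadata.size() << std::endl;    // prints 1

  // Guarding with count() first, as the patch does, avoids the phantom entry:
  if (metadata.count("other") == 0) {
    std::cerr << "The metadata of other is not registered." << std::endl;
  }
  std::cout << metadata.size() << std::endl;    // still 1
  return 0;
}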
diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc
index 6d2b6fbb007..21e4ad014d6 100644
--- a/mindspore/ccsrc/fl/server/iteration.cc
+++ b/mindspore/ccsrc/fl/server/iteration.cc
@@ -66,9 +66,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communic
                       [&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
                         for (auto &round : rounds_) {
-                          if (round == nullptr) {
-                            continue;
-                          }
+                          MS_EXCEPTION_IF_NULL(round);
                           round->Initialize(communicator, timeout_cb, finish_iteration_cb);
                         }
                       });
@@ -76,6 +74,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communic
                     [](size_t total, const std::shared_ptr<Round> &round) {
+                      MS_EXCEPTION_IF_NULL(round);
                       return round->check_timeout() ? total + round->time_window() : total;
                     });
   LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
@@ -172,11 +171,8 @@ bool Iteration::SyncIteration(uint32_t rank) {
     MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
     return false;
   }
-  if (sync_iter_rsp_msg == nullptr) {
-    MS_LOG(ERROR) << "Response from server 0 is empty.";
-    return false;
-  }
+  MS_ERROR_IF_NULL_W_RET_VAL(sync_iter_rsp_msg, false);
   SyncIterationResponse sync_iter_rsp;
   (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
   iteration_num_ = sync_iter_rsp.iteration();
@@ -372,16 +368,17 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
   if (is_iteration_valid) {
     // Store the model which is successfully aggregated for this iteration.
     const auto &model = Executor::GetInstance().GetModel();
-    (void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
     MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
   } else {
     // Store last iteration's model because this iteration is considered as invalid.
    const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
-    (void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
     MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
   }
 
   for (auto &round : rounds_) {
+    MS_ERROR_IF_NULL_WO_RET_VAL(round);
     round->Reset();
   }
 }
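The `MS_EXCEPTION_IF_NULL(round)` added inside the `std::accumulate` lambda runs once per element, so a single null round aborts the fold instead of being silently skipped (the old code skipped nulls with `continue`). A minimal, self-contained version of the same fold, with a stand-in type rather than the real `Round` class:

#include <memory>
#include <numeric>
#include <stdexcept>
#include <vector>

struct RoundLike {  // stand-in for fl::server::Round
  bool check_timeout;
  size_t time_window;
};

size_t TotalTimeWindow(const std::vector<std::shared_ptr<RoundLike>> &rounds) {
  return std::accumulate(rounds.begin(), rounds.end(), size_t{0},
                         [](size_t total, const std::shared_ptr<RoundLike> &round) {
                           if (round == nullptr) {  // mirrors MS_EXCEPTION_IF_NULL(round)
                             throw std::runtime_error("round is null");
                           }
                           return round->check_timeout ? total + round->time_window : total;
                         });
}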
Reason: " << reason; } for (auto &round : rounds_) { + MS_ERROR_IF_NULL_WO_RET_VAL(round); round->Reset(); } } diff --git a/mindspore/ccsrc/fl/server/iteration_timer.cc b/mindspore/ccsrc/fl/server/iteration_timer.cc index f805aefafaf..27a98c4191a 100644 --- a/mindspore/ccsrc/fl/server/iteration_timer.cc +++ b/mindspore/ccsrc/fl/server/iteration_timer.cc @@ -36,10 +36,12 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); } }); - monitor_thread_.detach(); } -void IterationTimer::Stop() { running_ = false; } +void IterationTimer::Stop() { + running_ = false; + monitor_thread_.join(); +} void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) { timeout_callback_ = timeout_cb; diff --git a/mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h b/mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h index 49fd0db563d..fa7b4abc172 100644 --- a/mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h @@ -82,6 +82,10 @@ class FedAvgKernel : public AggregationKernel { } }; last_cnt_handler_ = [&](std::shared_ptr) { + MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_); + MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_); + MS_ERROR_IF_NULL_WO_RET_VAL(weight_addr_->addr); + MS_ERROR_IF_NULL_WO_RET_VAL(data_size_addr_->addr); T *weight_addr = reinterpret_cast(weight_addr_->addr); size_t weight_size = weight_addr_->size; S *data_size_addr = reinterpret_cast(data_size_addr_->addr); diff --git a/mindspore/ccsrc/fl/server/model_store.cc b/mindspore/ccsrc/fl/server/model_store.cc index 832468d9ddb..8cbab89a9cc 100644 --- a/mindspore/ccsrc/fl/server/model_store.cc +++ b/mindspore/ccsrc/fl/server/model_store.cc @@ -35,33 +35,27 @@ void ModelStore::Initialize(uint32_t max_count) { model_size_ = ComputeModelSize(); } -bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map &new_model) { +void ModelStore::StoreModelByIterNum(size_t iteration, const std::map &new_model) { std::unique_lock lock(model_mtx_); if (iteration_to_model_.count(iteration) != 0) { MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored"; - return false; + return; } if (new_model.empty()) { MS_LOG(ERROR) << "Model feature map is empty."; - return false; + return; } std::shared_ptr memory_register = nullptr; if (iteration_to_model_.size() < max_model_count_) { // If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model. memory_register = AssignNewModelMemory(); - if (memory_register == nullptr) { - MS_LOG(ERROR) << "Memory for the new model is nullptr."; - return false; - } + MS_ERROR_IF_NULL_WO_RET_VAL(memory_register); iteration_to_model_[iteration] = memory_register; } else { // If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model. 
diff --git a/mindspore/ccsrc/fl/server/model_store.cc b/mindspore/ccsrc/fl/server/model_store.cc
index 832468d9ddb..8cbab89a9cc 100644
--- a/mindspore/ccsrc/fl/server/model_store.cc
+++ b/mindspore/ccsrc/fl/server/model_store.cc
@@ -35,33 +35,27 @@ void ModelStore::Initialize(uint32_t max_count) {
   model_size_ = ComputeModelSize();
 }
 
-bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
+void ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
   std::unique_lock<std::mutex> lock(model_mtx_);
   if (iteration_to_model_.count(iteration) != 0) {
     MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
-    return false;
+    return;
   }
   if (new_model.empty()) {
     MS_LOG(ERROR) << "Model feature map is empty.";
-    return false;
+    return;
   }
 
   std::shared_ptr<MemoryRegister> memory_register = nullptr;
   if (iteration_to_model_.size() < max_model_count_) {
     // If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model.
     memory_register = AssignNewModelMemory();
-    if (memory_register == nullptr) {
-      MS_LOG(ERROR) << "Memory for the new model is nullptr.";
-      return false;
-    }
+    MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
     iteration_to_model_[iteration] = memory_register;
   } else {
     // If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
     memory_register = iteration_to_model_.begin()->second;
-    if (memory_register == nullptr) {
-      MS_LOG(ERROR) << "Earliest model is nullptr.";
-      return false;
-    }
+    MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
     (void)iteration_to_model_.erase(iteration_to_model_.begin());
   }
@@ -74,6 +68,10 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
+    MS_ERROR_IF_NULL_WO_RET_VAL(stored_model[weight_name]);
+    MS_ERROR_IF_NULL_WO_RET_VAL(stored_model[weight_name]->addr);
+    MS_ERROR_IF_NULL_WO_RET_VAL(weight.second);
+    MS_ERROR_IF_NULL_WO_RET_VAL(weight.second->addr);
     void *dst_addr = stored_model[weight_name]->addr;
     size_t dst_size = stored_model[weight_name]->size;
     void *src_addr = weight.second->addr;
@@ -81,11 +79,11 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
     int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size);
     if (ret != 0) {
       MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
-      return false;
+      return;
     }
   }
-  return true;
+  return;
 }
 
 std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
diff --git a/mindspore/ccsrc/fl/server/model_store.h b/mindspore/ccsrc/fl/server/model_store.h
index ea704bd3998..b8d38829ef8 100644
--- a/mindspore/ccsrc/fl/server/model_store.h
+++ b/mindspore/ccsrc/fl/server/model_store.h
@@ -47,7 +47,7 @@ class ModelStore {
   // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
   // max_model_count_, the earliest model will be replaced.
-  bool StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model);
+  void StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model);
 
   // Get model of the given iteration.
   std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);
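StoreModelByIterNum implements a bounded model history: fresh memory is assigned until max_model_count_ models exist, after which the memory of the oldest entry is reused. A self-contained sketch of that policy, with illustrative names rather than the real MemoryRegister machinery:

#include <cstddef>
#include <map>
#include <memory>

struct Memory {};  // stands in for MemoryRegister

class BoundedModelStore {
 public:
  explicit BoundedModelStore(size_t max_count) : max_count_(max_count) {}

  void Store(size_t iteration) {
    if (models_.count(iteration) != 0) {
      return;  // already stored, nothing to do
    }
    std::shared_ptr<Memory> mem;
    if (models_.size() < max_count_) {
      mem = std::make_shared<Memory>();  // assign fresh memory for the model
    } else {
      mem = models_.begin()->second;     // reuse the earliest model's memory
      models_.erase(models_.begin());    // std::map keeps keys sorted, so begin() is the oldest iteration
    }
    models_[iteration] = mem;
  }

 private:
  size_t max_count_;
  std::map<size_t, std::shared_ptr<Memory>> models_;  // iteration -> model memory
};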
Source size: " << data.second.size; int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size); @@ -228,9 +234,10 @@ bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { template bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, - std::shared_ptr memory_register) { + const std::shared_ptr &memory_register) { MS_EXCEPTION_IF_NULL(server_kernel); MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(memory_register); const std::vector &input_names = server_kernel->input_names(); const std::vector &input_size_list = server_kernel->GetInputSizeList(); @@ -272,8 +279,8 @@ bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode, return true; } -bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr aggr_kernel, - const std::shared_ptr memory_register) { +bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr &aggr_kernel, + const std::shared_ptr &memory_register) { MS_ERROR_IF_NULL_W_RET_VAL(aggr_kernel, false); MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false); KernelParams aggr_params = {}; @@ -295,8 +302,9 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr< return true; } -bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr optimizer_kernel, - const std::shared_ptr memory_register) { +bool ParameterAggregator::GenerateOptimizerKernelParams( + const std::shared_ptr &optimizer_kernel, + const std::shared_ptr &memory_register) { MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false); MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false); KernelParams optimizer_params = {}; @@ -335,12 +343,12 @@ std::vector ParameterAggregator::SelectAggregationAlgorithm(const C template bool ParameterAggregator::AssignMemory(std::shared_ptr server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, - std::shared_ptr memory_register); + const std::shared_ptr &memory_register); template bool ParameterAggregator::AssignMemory(std::shared_ptr server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, - std::shared_ptr memory_register); + const std::shared_ptr &memory_register); } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/parameter_aggregator.h b/mindspore/ccsrc/fl/server/parameter_aggregator.h index 0121f0d97bc..f7f02f7ea07 100644 --- a/mindspore/ccsrc/fl/server/parameter_aggregator.h +++ b/mindspore/ccsrc/fl/server/parameter_aggregator.h @@ -105,14 +105,14 @@ class ParameterAggregator { // momentum, etc. template bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, - std::shared_ptr memory_register); + const std::shared_ptr &memory_register); // Generate kernel parameters for aggregation/optimizer kernels. All the parameters is registered and stored in // memory_register. 
diff --git a/mindspore/ccsrc/fl/server/parameter_aggregator.h b/mindspore/ccsrc/fl/server/parameter_aggregator.h
index 0121f0d97bc..f7f02f7ea07 100644
--- a/mindspore/ccsrc/fl/server/parameter_aggregator.h
+++ b/mindspore/ccsrc/fl/server/parameter_aggregator.h
@@ -105,14 +105,14 @@ class ParameterAggregator {
   // momentum, etc.
   template <typename K>
   bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
-                    std::shared_ptr<MemoryRegister> memory_register);
+                    const std::shared_ptr<MemoryRegister> &memory_register);
 
   // Generate kernel parameters for aggregation/optimizer kernels. All the parameters is registered and stored in
   // memory_register.
-  bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
-                                       const std::shared_ptr<MemoryRegister> memory_register);
-  bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optim_kernel,
-                                     const std::shared_ptr<MemoryRegister> memory_register);
+  bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
+                                       const std::shared_ptr<MemoryRegister> &memory_register);
+  bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> &optim_kernel,
+                                     const std::shared_ptr<MemoryRegister> &memory_register);
 
   // The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
   // configuration, etc.
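The signature changes in this header (and the matching definitions above) are a standard optimization: a std::shared_ptr passed by value copies the handle and performs an atomic reference-count increment and decrement on every call, while a const reference does not, and the null checks at the top of each function still apply. An illustrative sketch:

#include <memory>

struct MemoryRegisterLike {};  // stand-in for MemoryRegister

// Passing by value copies the shared_ptr: an atomic refcount increment on
// entry and a decrement on exit, for every call.
void ByValue(std::shared_ptr<MemoryRegisterLike> reg) { (void)reg; }

// Passing by const reference shares the caller's pointer without touching
// the refcount; a callee that only reads through the pointer needs no ownership.
void ByConstRef(const std::shared_ptr<MemoryRegisterLike> &reg) { (void)reg; }

int main() {
  auto reg = std::make_shared<MemoryRegisterLike>();
  ByValue(reg);
  ByConstRef(reg);
  return 0;
}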
diff --git a/mindspore/ccsrc/fl/server/round.cc b/mindspore/ccsrc/fl/server/round.cc
index 3ece5ed728d..2805d27a880 100644
--- a/mindspore/ccsrc/fl/server/round.cc
+++ b/mindspore/ccsrc/fl/server/round.cc
@@ -40,8 +40,10 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
   communicator_ = communicator;
 
   // Register callback for round kernel.
-  communicator_->RegisterMsgCallBack(
-    name_, [&](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); });
+  communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
+    MS_ERROR_IF_NULL_WO_RET_VAL(message);
+    LaunchRoundKernel(message);
+  });
 
   // Callback when the iteration is finished.
   finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
@@ -50,10 +52,14 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
   };
 
   // Callback for finalizing the server. This can only be called once.
-  finalize_cb_ = [&](void) -> void { (void)communicator_->Stop(); };
+  finalize_cb_ = [&](void) -> void {
+    MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
+    (void)communicator_->Stop();
+  };
 
   if (check_timeout_) {
     iter_timer_ = std::make_shared<IterationTimer>();
+    MS_EXCEPTION_IF_NULL(iter_timer_);
 
     // 1.Set the timeout callback for the timer.
     iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
@@ -63,6 +69,7 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
 
     // 2.Stopping timer callback which will be set to the round kernel.
     stop_timer_cb_ = [&](void) -> void {
+      MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
      MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
       iter_timer_->Stop();
     };
@@ -90,10 +97,7 @@ bool Round::ReInitForScaling(uint32_t server_num) {
                                               {first_count_handler, last_count_handler});
   }
 
-  if (kernel_ == nullptr) {
-    MS_LOG(WARNING) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
-    return false;
-  }
+  MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
   kernel_->InitKernel(threshold_count_);
   return true;
 }
@@ -108,6 +112,8 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
 
 void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
   MS_ERROR_IF_NULL_WO_RET_VAL(message);
+  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
+  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
 
   // If the server is still in the process of scaling, refuse the request.
   if (Server::GetInstance().IsSafeMode()) {
     MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
@@ -149,7 +155,10 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
   return;
 }
 
-void Round::Reset() { (void)kernel_->Reset(); }
+void Round::Reset() {
+  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
+  (void)kernel_->Reset();
+}
 
 const std::string &Round::name() const { return name_; }
 
@@ -160,6 +169,9 @@ bool Round::check_timeout() const { return check_timeout_; }
 size_t Round::time_window() const { return time_window_; }
 
 void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
+  MS_ERROR_IF_NULL_WO_RET_VAL(message);
+  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
+  MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
   MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
   // The timer starts only after the first count event is triggered by DistributedCountService.
   if (check_timeout_) {
@@ -172,6 +184,9 @@ void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &m
 }
 
 void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
+  MS_ERROR_IF_NULL_WO_RET_VAL(message);
+  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
+  MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
   MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
   // Same as the first count event, the timer must be stopped by DistributedCountService.
   if (check_timeout_) {
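The Round changes harden every callback boundary: inputs are validated where the communicator hands control back, instead of assuming the framework never passes null. A runnable sketch of that pattern, with stand-in types rather than the real ps::core API:

#include <functional>
#include <iostream>
#include <memory>
#include <string>

struct MessageLike {
  std::string payload;
};

using MsgCallback = std::function<void(std::shared_ptr<MessageLike>)>;

// Simulates the communicator invoking a registered callback.
void Deliver(const MsgCallback &cb, std::shared_ptr<MessageLike> incoming) { cb(std::move(incoming)); }

int main() {
  Deliver(
    [](std::shared_ptr<MessageLike> message) {
      if (message == nullptr) {  // mirrors MS_ERROR_IF_NULL_WO_RET_VAL(message)
        std::cerr << "null message dropped" << std::endl;
        return;
      }
      std::cout << "handling: " << message->payload << std::endl;
    },
    nullptr);  // a null delivery is rejected instead of crashing the handler
  return 0;
}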
diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc
index 4d5511ef780..69ad3fe52f1 100644
--- a/mindspore/ccsrc/fl/server/server.cc
+++ b/mindspore/ccsrc/fl/server/server.cc
@@ -86,7 +86,11 @@ void Server::Run() {
   // Wait communicators to stop so the main thread is blocked.
   (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                      [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Join(); });
+                      [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+                        MS_EXCEPTION_IF_NULL(communicator);
+                        communicator->Join();
+                      });
+  MS_EXCEPTION_IF_NULL(communicator_with_server_);
   communicator_with_server_->Join();
   MsException::Instance().CheckException();
   return;
@@ -151,6 +155,7 @@ bool Server::InitCommunicatorWithWorker() {
     return false;
   }
   if (use_tcp_) {
+    MS_EXCEPTION_IF_NULL(communicator_with_server_);
     auto tcp_comm = communicator_with_server_;
     MS_EXCEPTION_IF_NULL(tcp_comm);
     communicators_with_worker_.push_back(tcp_comm);
@@ -201,21 +206,32 @@ void Server::InitIteration() {
     std::shared_ptr<Round> exchange_keys_round =
       std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
+    MS_EXCEPTION_IF_NULL(exchange_keys_round);
     iteration_->AddRound(exchange_keys_round);
+
     std::shared_ptr<Round> get_keys_round =
       std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
+    MS_EXCEPTION_IF_NULL(get_keys_round);
     iteration_->AddRound(get_keys_round);
+
     std::shared_ptr<Round> share_secrets_round =
       std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
+    MS_EXCEPTION_IF_NULL(share_secrets_round);
     iteration_->AddRound(share_secrets_round);
+
     std::shared_ptr<Round> get_secrets_round =
       std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
+    MS_EXCEPTION_IF_NULL(get_secrets_round);
     iteration_->AddRound(get_secrets_round);
+
     std::shared_ptr<Round> get_clientlist_round =
       std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
+    MS_EXCEPTION_IF_NULL(get_clientlist_round);
     iteration_->AddRound(get_clientlist_round);
+
     std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
       "reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
+    MS_EXCEPTION_IF_NULL(reconstruct_secrets_round);
     iteration_->AddRound(reconstruct_secrets_round);
     MS_LOG(INFO) << "Cipher rounds has been added.";
   }
@@ -253,8 +269,16 @@ void Server::InitCipher() {
   mindspore::armour::CipherPublicPara param;
   param.g = cipher_g;
   param.t = cipher_t;
-  memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
-  memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
+  int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+    return;
+  }
+  ret = memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+    return;
+  }
   param.dp_delta = dp_delta;
   param.dp_eps = dp_eps;
   param.dp_norm_clip = dp_norm_clip;
@@ -268,6 +292,8 @@ void Server::InitCipher() {
 void Server::RegisterCommCallbacks() {
   // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
   // rounds' callbacks.
+  MS_EXCEPTION_IF_NULL(server_node_);
+  MS_EXCEPTION_IF_NULL(iteration_);
   auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
   MS_EXCEPTION_IF_NULL(tcp_comm);
@@ -302,9 +328,13 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::Tcp
   communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
     MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
     safemode_ = true;
-    (void)std::for_each(
-      communicators_with_worker_.begin(), communicators_with_worker_.end(),
-      [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
+    (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
+                        [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+                          MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
+                          (void)communicator->Stop();
+                        });
+
+    MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
     (void)communicator_with_server_->Stop();
   });
 
@@ -313,14 +343,19 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::Tcp
-    (void)std::for_each(
-      communicators_with_worker_.begin(), communicators_with_worker_.end(),
-      [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
+    (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
+                        [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+                          MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
+                          (void)communicator->Stop();
+                        });
+
+    MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
     (void)communicator_with_server_->Stop();
   });
 }
 
 void Server::InitExecutor() {
+  MS_EXCEPTION_IF_NULL(func_graph_);
   if (executor_threshold_ == 0) {
     MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
     return;
   }
@@ -342,6 +377,7 @@ void Server::RegisterRoundKernel() {
   }
 
   for (auto &round : rounds) {
+    MS_EXCEPTION_IF_NULL(round);
     const std::string &name = round->name();
     std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
     if (round_kernel == nullptr) {
@@ -357,12 +393,13 @@ void Server::RegisterRoundKernel() {
 }
 
 void Server::StartCommunicator() {
-  MS_EXCEPTION_IF_NULL(communicator_with_server_);
   if (communicators_with_worker_.empty()) {
     MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty.";
     return;
   }
 
+  MS_EXCEPTION_IF_NULL(server_node_);
+  MS_EXCEPTION_IF_NULL(communicator_with_server_);
   MS_LOG(INFO) << "Start communicator with server.";
   if (!communicator_with_server_->Start()) {
     MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
@@ -376,6 +413,7 @@ void Server::StartCommunicator() {
   MS_LOG(INFO) << "Start communicator with worker.";
   (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                       [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+                        MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
                         if (!communicator->Start()) {
                           MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
                         }
                       });
 }
@@ -383,11 +421,13 @@ void Server::StartCommunicator() {
 
 void Server::ProcessBeforeScalingOut() {
+  MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
   iteration_->ScalingBarrier();
   safemode_ = true;
 }
 
 void Server::ProcessBeforeScalingIn() {
+  MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
   iteration_->ScalingBarrier();
   safemode_ = true;
 }
@@ -419,9 +459,13 @@ void Server::ProcessAfterScalingIn() {
   MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
   if (server_node_->rank_id() == UINT32_MAX) {
     MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
-    (void)std::for_each(
-      communicators_with_worker_.begin(), communicators_with_worker_.end(),
-      [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
+    (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
+                        [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+                          MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
+                          (void)communicator->Stop();
+                        });
+
+    MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
    (void)communicator_with_server_->Stop();
     return;
   }
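Server::InitCipher above now checks memcpy_s's return value instead of discarding it. memcpy_s itself comes from the securec library MindSpore bundles; as a rough standard-C++ stand-in, the checked-copy idea looks like this:

#include <cstring>

// Returns 0 on success, non-zero if an argument is null or the destination
// buffer is too small -- the condition memcpy_s reports via its error code.
int CheckedCopy(void *dst, size_t dst_max, const void *src, size_t count) {
  if (dst == nullptr || src == nullptr || count > dst_max) {
    return -1;
  }
  std::memcpy(dst, src, count);
  return 0;
}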
diff --git a/mindspore/ccsrc/fl/worker/fl_worker.cc b/mindspore/ccsrc/fl/worker/fl_worker.cc
index 77da81ed564..a004ba74042 100644
--- a/mindspore/ccsrc/fl/worker/fl_worker.cc
+++ b/mindspore/ccsrc/fl/worker/fl_worker.cc
@@ -91,6 +91,7 @@ void FLWorker::Finalize() {
 
 bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
                             std::shared_ptr<std::vector<unsigned char>> *output) {
+  MS_EXCEPTION_IF_NULL(data);
   // If the worker is in safemode, do not communicate with server.
   while (safemode_.load()) {
     std::this_thread::yield();
   }
@@ -169,6 +170,7 @@ std::string FLWorker::fl_name() const { return ps::kServerModeFL; }
 std::string FLWorker::fl_id() const { return std::to_string(rank_id_); }
 
 void FLWorker::InitializeFollowerScaler() {
+  MS_EXCEPTION_IF_NULL(worker_node_);
   if (!worker_node_->InitFollowerScaler()) {
     MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
     return;
@@ -225,10 +227,7 @@ void FLWorker::ProcessBeforeScalingIn() {
 }
 
 void FLWorker::ProcessAfterScalingOut() {
-  if (worker_node_ == nullptr) {
-    return;
-  }
-
+  MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
   MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
   server_num_ = IntToUint(worker_node_->server_num());
   worker_num_ = IntToUint(worker_node_->worker_num());
@@ -239,10 +238,7 @@ void FLWorker::ProcessAfterScalingOut() {
 }
 
 void FLWorker::ProcessAfterScalingIn() {
-  if (worker_node_ == nullptr) {
-    return;
-  }
-
+  MS_ERROR_IF_NULL_WO_RET_VAL(worker_node_);
   MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
   server_num_ = IntToUint(worker_node_->server_num());
   worker_num_ = IntToUint(worker_node_->worker_num());
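FLWorker::SendToServer gates all communication on the safemode_ flag with a yield loop, so requests issued during scaling simply wait until the cluster is stable again. A minimal stand-alone version of that gate (illustrative; the real flag is a member cleared by the scaling callbacks):

#include <atomic>
#include <thread>

std::atomic<bool> safemode{true};

void WaitUntilSafe() {
  while (safemode.load()) {
    std::this_thread::yield();  // let the thread that clears safemode run
  }
}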