diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h index a53e6697203..775e624d41b 100644 --- a/mindspore/ccsrc/fl/server/common.h +++ b/mindspore/ccsrc/fl/server/common.h @@ -234,6 +234,32 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size // Definitions for Federated Learning. +constexpr auto kNetworkError = "Cluster networking failed."; + +// The result code used for round kernels. +enum class ResultCode { + // If the method is successfully called and round kernel's residual methods should be called, return kSuccess. + kSuccess = 0, + // If there's error happened in the method and residual methods should not be called but this iteration continues, + // return kSuccessAndReturn so that framework won't drop this iteration. + kSuccessAndReturn, + // If there's error happened and this iteration should be dropped, return kFail. + kFail +}; + +bool inline ConvertResultCode(ResultCode result_code) { + switch (result_code) { + case ResultCode::kSuccess: + return true; + case ResultCode::kSuccessAndReturn: + return true; + case ResultCode::kFail: + return false; + default: + return true; + } +} + // Definitions for Parameter Server. } // namespace server diff --git a/mindspore/ccsrc/fl/server/distributed_count_service.cc b/mindspore/ccsrc/fl/server/distributed_count_service.cc index 7a6ee018da2..02213123bf9 100644 --- a/mindspore/ccsrc/fl/server/distributed_count_service.cc +++ b/mindspore/ccsrc/fl/server/distributed_count_service.cc @@ -66,7 +66,7 @@ void DistributedCountService::RegisterCounter(const std::string &name, size_t gl return; } -bool DistributedCountService::Count(const std::string &name, const std::string &id) { +bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) { MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; if (local_rank_ == counting_server_rank_) { if (global_threshold_count_.count(name) == 0) { @@ -83,7 +83,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string & MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; global_current_count_[name].insert(id); - if (!TriggerCounterEvent(name)) { + if (!TriggerCounterEvent(name, reason)) { MS_LOG(ERROR) << "Leader server trigger count event failed."; return false; } @@ -97,6 +97,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string & if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount, &report_cnt_rsp_msg)) { MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name; + if (reason != nullptr) { + *reason = kNetworkError; + } return false; } @@ -104,6 +107,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string & count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size())); if (!count_rsp.result()) { MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason(); + if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) { + *reason = kNetworkError; + } return false; } } @@ -202,13 +208,13 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptrSendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); return; @@ -266,24 +272,24 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptrSendPbRequest(first_count_event, i, 
ps::core::TcpUserCommand::kCounterEvent)) { MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; + if (reason != nullptr) { + *reason = "Send to rank " + std::to_string(i) + " failed. " + kNetworkError; + } return false; } } @@ -301,7 +310,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) { return true; } -bool DistributedCountService::TriggerLastCountEvent(const std::string &name) { +bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) { MS_LOG(INFO) << "Activating last count event for " << name; CounterEvent last_count_event; last_count_event.set_type(CounterEventType::LAST_CNT); @@ -311,6 +320,9 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) { for (uint32_t i = 1; i < server_num_; i++) { if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; + if (reason != nullptr) { + *reason = "Send to rank " + std::to_string(i) + " failed. " + kNetworkError; + } return false; } } diff --git a/mindspore/ccsrc/fl/server/distributed_count_service.h b/mindspore/ccsrc/fl/server/distributed_count_service.h index 2e37bbd9b8f..cdb137c4958 100644 --- a/mindspore/ccsrc/fl/server/distributed_count_service.h +++ b/mindspore/ccsrc/fl/server/distributed_count_service.h @@ -63,8 +63,9 @@ class DistributedCountService { // first/last count event callbacks. void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers); - // Report a count to the counting server. Parameter 'id' is in case of repeated counting. - bool Count(const std::string &name, const std::string &id); + // Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the + // reason why counting failed. + bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr); // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, // this method returns true. @@ -98,9 +99,9 @@ class DistributedCountService { void HandleCounterEvent(const std::shared_ptr &message); // Call the callbacks when the first/last count event is triggered. - bool TriggerCounterEvent(const std::string &name); - bool TriggerFirstCountEvent(const std::string &name); - bool TriggerLastCountEvent(const std::string &name); + bool TriggerCounterEvent(const std::string &name, std::string *reason = nullptr); + bool TriggerFirstCountEvent(const std::string &name, std::string *reason = nullptr); + bool TriggerLastCountEvent(const std::string &name, std::string *reason = nullptr); // Members for the communication between counting server and other servers. 
std::shared_ptr server_node_; diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc index b15af8ec604..0e5f166feb8 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc @@ -82,7 +82,7 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) { return; } -bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) { +bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) { if (router_ == nullptr) { MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; return false; @@ -103,6 +103,9 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata, &update_meta_rsp_msg)) { MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed."; + if (reason != nullptr) { + *reason = "Send to rank " + std::to_string(stored_rank) + " failed. " + kNetworkError; + } return false; } diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.h b/mindspore/ccsrc/fl/server/distributed_metadata_store.h index bf13c13ea19..d74b2160a3c 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.h +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.h @@ -55,8 +55,8 @@ class DistributedMetadataStore { // Reset the metadata value for the name. void ResetMetadata(const std::string &name); - // Update the metadata for the name. - bool UpdateMetadata(const std::string &name, const PBMetadata &meta); + // Update the metadata for the name. Parameter 'reason' is the reason why updating meta data failed. + bool UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason = nullptr); // Get the metadata for the name. 
PBMetadata GetMetadata(const std::string &name); diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc index ad394ed723f..4097c22e14d 100644 --- a/mindspore/ccsrc/fl/server/iteration.cc +++ b/mindspore/ccsrc/fl/server/iteration.cc @@ -390,7 +390,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) { end_last_iter_req.set_last_iter_num(last_iter_num); for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) { if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) { - MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed."; + MS_LOG(WARNING) << "Sending ending last iteration request to server " << i << " failed."; continue; } } @@ -438,10 +438,10 @@ void Iteration::EndLastIter() { ModelStore::GetInstance().Reset(); } - Server::GetInstance().CancelSafeMode(); - SetIterationCompleted(); pinned_iter_num_ = 0; LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); + Server::GetInstance().CancelSafeMode(); + SetIterationCompleted(); MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n"; } } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc index 4aed671d753..a93335f6862 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc @@ -48,9 +48,9 @@ bool PushWeightKernel::Launch(const std::vector &inputs, const std:: return false; } - bool ret = PushWeight(fbb, push_weight_req); + ResultCode result_code = PushWeight(fbb, push_weight_req); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return ret; + return ConvertResultCode(result_code); } bool PushWeightKernel::Reset() { @@ -67,9 +67,10 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr fbb, const schema::RequestPushWeight *push_weight_req) { +ResultCode PushWeightKernel::PushWeight(std::shared_ptr fbb, + const schema::RequestPushWeight *push_weight_req) { if (fbb == nullptr || push_weight_req == nullptr) { - return false; + return ResultCode::kSuccessAndReturn; } size_t iteration = static_cast(push_weight_req->iteration()); size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); @@ -78,7 +79,7 @@ bool PushWeightKernel::PushWeight(std::shared_ptr fbb, const schema:: ", current iteration:" + std::to_string(current_iter); BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter); MS_LOG(WARNING) << reason; - return true; + return ResultCode::kSuccessAndReturn; } std::map upload_feature_map = ParseFeatureMap(push_weight_req); @@ -86,25 +87,26 @@ bool PushWeightKernel::PushWeight(std::shared_ptr fbb, const schema:: std::string reason = "PushWeight feature_map is empty."; BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter); MS_LOG(ERROR) << reason; - return false; + return ResultCode::kSuccessAndReturn; } if (!executor_->HandlePushWeight(upload_feature_map)) { std::string reason = "Pushing weight failed."; - BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); + BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter); MS_LOG(ERROR) << reason; - return false; + return ResultCode::kSuccessAndReturn; } MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds."; - if (!DistributedCountService::GetInstance().Count(name_, 
std::to_string(local_rank_))) { + std::string count_reason = ""; + if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) { std::string reason = "Count for push weight request failed."; BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter); MS_LOG(ERROR) << reason; - return false; + return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail; } BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter); - return true; + return ResultCode::kSuccess; } std::map PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) { diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.h index 17d26e99d56..7b09d3d8601 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.h @@ -42,7 +42,7 @@ class PushWeightKernel : public RoundKernel { void OnLastCountEvent(const std::shared_ptr &message) override; private: - bool PushWeight(std::shared_ptr fbb, const schema::RequestPushWeight *push_weight_req); + ResultCode PushWeight(std::shared_ptr fbb, const schema::RequestPushWeight *push_weight_req); std::map ParseFeatureMap(const schema::RequestPushWeight *push_weight_req); void BuildPushWeightRsp(std::shared_ptr fbb, const schema::ResponseCode retcode, const std::string &reason, size_t iteration); diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index 55d78471a80..324de765abc 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -72,33 +72,37 @@ bool StartFLJobKernel::Launch(const std::vector &inputs, const std:: return true; } - if (ReachThresholdForStartFLJob(fbb)) { + ResultCode result_code = ReachThresholdForStartFLJob(fbb); + if (result_code != ResultCode::kSuccess) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return true; + return ConvertResultCode(result_code); } const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot(req_data); DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req); - if (!ReadyForStartFLJob(fbb, device_meta)) { + result_code = ReadyForStartFLJob(fbb, device_meta); + if (result_code != ResultCode::kSuccess) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return false; + return ConvertResultCode(result_code); } PBMetadata metadata; *metadata.mutable_device_meta() = device_meta; - if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) { - std::string reason = "Updating device metadata failed."; + std::string update_reason = ""; + if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) { + std::string reason = "Updating device metadata failed. " + update_reason; BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp)), {}); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return true; + return update_reason == kNetworkError ? true : false; } StartFLJob(fbb, device_meta); // If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected. 
- if (!CountForStartFLJob(fbb, start_fl_job_req)) { + result_code = CountForStartFLJob(fbb, start_fl_job_req); + if (result_code != ResultCode::kSuccess) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return false; + return ConvertResultCode(result_code); } GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); @@ -120,16 +124,16 @@ void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr &fbb) { +ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr &fbb) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later."; BuildStartFLJobRsp( fbb, schema::ResponseCode_OutOfTime, reason, false, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); MS_LOG(WARNING) << reason; - return true; + return ResultCode::kSuccessAndReturn; } - return false; + return ResultCode::kSuccess; } DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) { @@ -146,34 +150,35 @@ DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *st return device_meta; } -bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta) { - bool ret = true; +ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta) { + ResultCode ret = ResultCode::kSuccess; std::string reason = ""; if (device_meta.data_size() < 1) { reason = "FL job data size is not enough."; - ret = false; + ret = ResultCode::kSuccessAndReturn; } - if (!ret) { + if (ret != ResultCode::kSuccess) { BuildStartFLJobRsp( - fbb, schema::ResponseCode_RequestError, reason, false, + fbb, schema::ResponseCode_OutOfTime, reason, false, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); - MS_LOG(ERROR) << reason; + MS_LOG(WARNING) << reason; } return ret; } -bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr &fbb, - const schema::RequestFLJob *start_fl_job_req) { - RETURN_IF_NULL(start_fl_job_req, false); - if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) { +ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr &fbb, + const schema::RequestFLJob *start_fl_job_req) { + RETURN_IF_NULL(start_fl_job_req, ResultCode::kSuccessAndReturn); + std::string count_reason = ""; + if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) { std::string reason = "Counting start fl job request failed. Please retry later."; BuildStartFLJobRsp( fbb, schema::ResponseCode_OutOfTime, reason, false, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); MS_LOG(ERROR) << reason; - return false; + return count_reason == kNetworkError ? 
ResultCode::kSuccessAndReturn : ResultCode::kFail; } - return true; + return ResultCode::kSuccess; } void StartFLJobKernel::StartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta) { diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h index 6af55537d2d..c6f5cd0ba08 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h @@ -44,17 +44,17 @@ class StartFLJobKernel : public RoundKernel { private: // Returns whether the startFLJob count of this iteration has reached the threshold. - bool ReachThresholdForStartFLJob(const std::shared_ptr &fbb); + ResultCode ReachThresholdForStartFLJob(const std::shared_ptr &fbb); // The metadata of device will be stored and queried in updateModel round. DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req); // Returns whether the request is valid for startFLJob.For now, the condition is simple. We will add more conditions // to device in later versions. - bool ReadyForStartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta); + ResultCode ReadyForStartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta); // Distributed count service counts for startFLJob. - bool CountForStartFLJob(const std::shared_ptr &fbb, const schema::RequestFLJob *start_fl_job_req); + ResultCode CountForStartFLJob(const std::shared_ptr &fbb, const schema::RequestFLJob *start_fl_job_req); void StartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta); diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc index af2aacbe04f..f5698d17471 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc @@ -56,21 +56,24 @@ bool UpdateModelKernel::Launch(const std::vector &inputs, const std: } MS_LOG(INFO) << "Launching UpdateModelKernel kernel."; - if (ReachThresholdForUpdateModel(fbb)) { + ResultCode result_code = ReachThresholdForUpdateModel(fbb); + if (result_code != ResultCode::kSuccess) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return true; + return ConvertResultCode(result_code); } const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot(req_data); - if (!UpdateModel(update_model_req, fbb)) { + result_code = UpdateModel(update_model_req, fbb); + if (result_code != ResultCode::kSuccess) { MS_LOG(ERROR) << "Updating model failed."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return false; + return ConvertResultCode(result_code); } - if (!CountForUpdateModel(fbb, update_model_req)) { + result_code = CountForUpdateModel(fbb, update_model_req); + if (result_code != ResultCode::kSuccess) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return false; + return ConvertResultCode(result_code); } GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; @@ -102,21 +105,21 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr &fbb) { +ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr &fbb) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { std::string reason = "Current amount for updateModel is enough. 
Please retry later."; BuildUpdateModelRsp( fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); MS_LOG(WARNING) << reason; - return true; + return ResultCode::kSuccessAndReturn; } - return false; + return ResultCode::kSuccess; } -bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, - const std::shared_ptr &fbb) { - RETURN_IF_NULL(update_model_req, false); +ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, + const std::shared_ptr &fbb) { + RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn); size_t iteration = static_cast(update_model_req->iteration()); if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + @@ -126,7 +129,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); MS_LOG(WARNING) << reason; - return true; + return ResultCode::kSuccessAndReturn; } PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas); @@ -139,7 +142,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); MS_LOG(ERROR) << reason; - return false; + return ResultCode::kSuccessAndReturn; } size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size(); @@ -150,7 +153,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod fbb, schema::ResponseCode_RequestError, reason, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); MS_LOG(ERROR) << reason; - return false; + return ResultCode::kSuccessAndReturn; } for (auto weight : feature_map) { @@ -163,18 +166,19 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod fl_id.set_fl_id(update_model_fl_id); PBMetadata comm_value; *comm_value.mutable_fl_id() = fl_id; - if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) { - std::string reason = "Updating metadata of UpdateModelClientList failed."; + std::string update_reason = ""; + if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) { + std::string reason = "Updating metadata of UpdateModelClientList failed. " + update_reason; BuildUpdateModelRsp( fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); MS_LOG(ERROR) << reason; - return false; + return update_reason == kNetworkError ? 
ResultCode::kSuccessAndReturn : ResultCode::kFail; } BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready", std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); - return true; + return ResultCode::kSuccess; } std::map UpdateModelKernel::ParseFeatureMap( @@ -195,18 +199,19 @@ std::map UpdateModelKernel::ParseFeatureMap( return feature_map; } -bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr &fbb, - const schema::RequestUpdateModel *update_model_req) { - RETURN_IF_NULL(update_model_req, false); - if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) { - std::string reason = "Counting for update model request failed. Please retry later."; +ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr &fbb, + const schema::RequestUpdateModel *update_model_req) { + RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn); + std::string count_reason = ""; + if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) { + std::string reason = "Counting for update model request failed. Please retry later. " + count_reason; BuildUpdateModelRsp( fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); MS_LOG(ERROR) << reason; - return false; + return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail; } - return true; + return ResultCode::kSuccess; } void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h index 258d70ebda7..341019a3736 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h @@ -47,10 +47,11 @@ class UpdateModelKernel : public RoundKernel { void OnLastCountEvent(const std::shared_ptr &message) override; private: - bool ReachThresholdForUpdateModel(const std::shared_ptr &fbb); - bool UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr &fbb); + ResultCode ReachThresholdForUpdateModel(const std::shared_ptr &fbb); + ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr &fbb); std::map ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); - bool CountForUpdateModel(const std::shared_ptr &fbb, const schema::RequestUpdateModel *update_model_req); + ResultCode CountForUpdateModel(const std::shared_ptr &fbb, + const schema::RequestUpdateModel *update_model_req); void BuildUpdateModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const std::string &next_req_time); diff --git a/mindspore/ccsrc/fl/server/model_store.cc b/mindspore/ccsrc/fl/server/model_store.cc index b762fa728dd..9657dfe44be 100644 --- a/mindspore/ccsrc/fl/server/model_store.cc +++ b/mindspore/ccsrc/fl/server/model_store.cc @@ -36,6 +36,7 @@ void ModelStore::Initialize(uint32_t max_count) { } bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map &new_model) { + std::unique_lock lock(model_mtx_); if (iteration_to_model_.count(iteration) != 0) { MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored"; return false; @@ -88,6 +89,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map 
ModelStore::GetModelByIterNum(size_t iteration) { + std::unique_lock lock(model_mtx_); std::map model = {}; if (iteration_to_model_.count(iteration) == 0) { MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored."; @@ -98,13 +100,15 @@ std::map ModelStore::GetModelByIterNum(size_t iteration } void ModelStore::Reset() { + std::unique_lock lock(model_mtx_); initial_model_ = iteration_to_model_.rbegin()->second; iteration_to_model_.clear(); iteration_to_model_[kInitIterationNum] = initial_model_; iteration_to_model_[kResetInitIterNum] = initial_model_; } -const std::map> &ModelStore::iteration_to_model() const { +const std::map> &ModelStore::iteration_to_model() { + std::unique_lock lock(model_mtx_); return iteration_to_model_; } @@ -142,6 +146,7 @@ std::shared_ptr ModelStore::AssignNewModelMemory() { } size_t ModelStore::ComputeModelSize() { + std::unique_lock lock(model_mtx_); if (iteration_to_model_.empty()) { MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. "; return 0; diff --git a/mindspore/ccsrc/fl/server/model_store.h b/mindspore/ccsrc/fl/server/model_store.h index 97f26ada0e0..ea704bd3998 100644 --- a/mindspore/ccsrc/fl/server/model_store.h +++ b/mindspore/ccsrc/fl/server/model_store.h @@ -56,7 +56,7 @@ class ModelStore { void Reset(); // Returns all models stored in ModelStore. - const std::map> &iteration_to_model() const; + const std::map> &iteration_to_model(); // Returns the model size, which could be calculated at the initializing phase. size_t model_size() const; @@ -80,7 +80,8 @@ class ModelStore { // Initial model which is the model of iteration 0. std::shared_ptr initial_model_; - // The number of all models stpred is max_model_count_. + // The number of all models stored is max_model_count_. + std::mutex model_mtx_; std::map> iteration_to_model_; }; } // namespace server diff --git a/tests/st/fl/hybrid_lenet/run_server_disaster_recovery.py b/tests/st/fl/hybrid_lenet/run_server_disaster_recovery.py index b203dc89e97..5aecb8f49b8 100644 --- a/tests/st/fl/hybrid_lenet/run_server_disaster_recovery.py +++ b/tests/st/fl/hybrid_lenet/run_server_disaster_recovery.py @@ -48,7 +48,7 @@ parser.add_argument("--sts_properties_path", type=str, default="") parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold parser.add_argument("--dp_norm_clip", type=float, default=1.0) -parser.add_argument("--encrypt_type", type=str, default="NotEncrypt") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") args, _ = parser.parse_known_args() diff --git a/tests/st/fl/mobile/run_server_disaster_recovery.py b/tests/st/fl/mobile/run_server_disaster_recovery.py index 8bc1c18d378..3f36004323f 100644 --- a/tests/st/fl/mobile/run_server_disaster_recovery.py +++ b/tests/st/fl/mobile/run_server_disaster_recovery.py @@ -47,7 +47,7 @@ parser.add_argument("--sts_properties_path", type=str, default="") parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold parser.add_argument("--dp_norm_clip", type=float, default=1.0) -parser.add_argument("--encrypt_type", type=str, default="NotEncrypt") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
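
Reviewer note on the new ResultCode contract in common.h: kSuccess and kSuccessAndReturn both convert to true (the iteration is kept alive), while only kFail converts to false (the iteration is dropped). Below is a minimal, self-contained sketch of how a round kernel's Launch is expected to chain its steps under this contract; MyStepA, MyStepB, and the simplified ConvertResultCode are illustrative stand-ins, not part of this patch.

#include <iostream>

enum class ResultCode { kSuccess = 0, kSuccessAndReturn, kFail };

// Same observable behavior as the switch in common.h: only kFail drops the iteration.
inline bool ConvertResultCode(ResultCode result_code) { return result_code != ResultCode::kFail; }

// Hypothetical step that hits a recoverable error: the remaining steps must be
// skipped, but the iteration should not be dropped.
ResultCode MyStepA() { return ResultCode::kSuccessAndReturn; }
ResultCode MyStepB() { return ResultCode::kSuccess; }

bool Launch() {
  ResultCode result_code = MyStepA();
  if (result_code != ResultCode::kSuccess) {
    // Early return: skip MyStepB and let ConvertResultCode decide whether the
    // framework keeps (true) or drops (false) this iteration.
    return ConvertResultCode(result_code);
  }
  result_code = MyStepB();
  return ConvertResultCode(result_code);
}

int main() {
  std::cout << std::boolalpha << Launch() << std::endl;  // prints "true": iteration kept, later steps skipped
  return 0;
}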
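
Reviewer note on the new 'reason' out-parameter of DistributedCountService::Count and DistributedMetadataStore::UpdateMetadata: it lets a round kernel tell a cluster networking failure (the kNetworkError sentinel) apart from an ordinary failure, mapping the former to kSuccessAndReturn and the latter to kFail, as the updated kernels do. A rough caller-side sketch with a hypothetical ReportCount stand-in for the real service follows.

#include <iostream>
#include <string>

// Redeclared here so the sketch stands alone; the real definitions live in fl/server/common.h.
constexpr auto kNetworkError = "Cluster networking failed.";
enum class ResultCode { kSuccess = 0, kSuccessAndReturn, kFail };

// Hypothetical stand-in for DistributedCountService::Count(name, id, &reason):
// on a send failure it fills 'reason' with kNetworkError and returns false.
bool ReportCount(const std::string &name, const std::string &id, std::string *reason) {
  if (reason != nullptr) {
    *reason = kNetworkError;
  }
  return false;
}

ResultCode CountForRound(const std::string &name, const std::string &fl_id) {
  std::string count_reason = "";
  if (!ReportCount(name, fl_id, &count_reason)) {
    // A networking failure is recoverable: keep the iteration but stop this round's
    // remaining steps. Any other counting failure drops the iteration.
    return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
  }
  return ResultCode::kSuccess;
}

int main() {
  // Prints 1 (kSuccessAndReturn): the networking failure keeps the iteration alive.
  std::cout << static_cast<int>(CountForRound("updateModel", "fl_id_0")) << std::endl;
  return 0;
}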
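
Reviewer note on the ModelStore changes: every accessor that touches iteration_to_model_ now takes model_mtx_, which is also why iteration_to_model() loses its const qualifier (a non-mutable member mutex cannot be locked from a const method). A simplified, stand-alone sketch of the same locking pattern; SimpleModelStore below is an illustration, not the real class.

#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <string>

// AddressPtr is reduced to std::shared_ptr<int> purely for brevity.
using FakeAddressPtr = std::shared_ptr<int>;

class SimpleModelStore {
 public:
  bool StoreModelByIterNum(size_t iteration, const std::map<std::string, FakeAddressPtr> &model) {
    std::unique_lock<std::mutex> lock(model_mtx_);  // serialize with readers and Reset()
    if (iteration_to_model_.count(iteration) != 0) {
      return false;  // model for this iteration is already stored
    }
    iteration_to_model_[iteration] = model;
    return true;
  }

  std::map<std::string, FakeAddressPtr> GetModelByIterNum(size_t iteration) {
    std::unique_lock<std::mutex> lock(model_mtx_);
    auto iter = iteration_to_model_.find(iteration);
    return iter == iteration_to_model_.end() ? std::map<std::string, FakeAddressPtr>{} : iter->second;
  }

 private:
  // Locking mutates the mutex, so accessors guarded by it cannot stay const
  // unless the mutex is declared mutable.
  std::mutex model_mtx_;
  std::map<size_t, std::map<std::string, FakeAddressPtr>> iteration_to_model_;
};

int main() {
  SimpleModelStore store;
  store.StoreModelByIterNum(1, {{"conv1.weight", std::make_shared<int>(0)}});
  return store.GetModelByIterNum(1).size() == 1 ? 0 : 1;
}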