diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc index 93eb7fea0f9..f966fe0cedb 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc @@ -38,6 +38,9 @@ void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptrRegisterMsgCallBack( "getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1)); + communicator_->RegisterMsgCallBack( + "getOneDeviceMeta", + std::bind(&DistributedMetadataStore::HandleGetOneDeviceMetaRequest, this, std::placeholders::_1)); return; } @@ -148,6 +151,42 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { } } +bool DistributedMetadataStore::GetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta) { + if (router_ == nullptr) { + MS_LOG(WARNING) << "The consistent hash ring is not initialized yet."; + return false; + } + const auto &name = kCtxDeviceMetas; + uint32_t stored_rank = router_->Find(name); + MS_LOG(DEBUG) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; + if (local_rank_ == stored_rank) { + return DoGetOneDeviceMeta(fl_id, device_meta); + } else { + GetOneDeviceMetaRequest get_metadata_req; + get_metadata_req.set_fl_id(fl_id); + + std::shared_ptr> get_meta_rsp_msg = nullptr; + if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetOneDeviceMeta, + &get_meta_rsp_msg)) { + MS_LOG(WARNING) << "Sending getting one client metadata message to server " << stored_rank << " failed."; + return false; + } + MS_ERROR_IF_NULL_W_RET_VAL(get_meta_rsp_msg, false); + GetOneDeviceMetaResponse get_metadata_rsp; + auto ret = get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size())); + if (!ret) { + MS_LOG(WARNING) << "Parse response of getting one client metadata message to server " << stored_rank + << " failed."; + return false; + } + if (!get_metadata_rsp.found()) { + return false; + } + *device_meta = get_metadata_rsp.device_meta(); + return true; + } +} + bool DistributedMetadataStore::ReInitForScaling() { // If DistributedMetadataStore is not initialized yet but the scaling event is triggered, do not throw exception. if (server_node_ == nullptr) { @@ -221,6 +260,28 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr &message) { + MS_ERROR_IF_NULL_WO_RET_VAL(message); + GetOneDeviceMetaRequest get_one_metadata_req; + GetOneDeviceMetaResponse response; + auto ret = get_one_metadata_req.ParseFromArray(message->data(), SizeToInt(message->len())); + if (!ret) { + MS_LOG(WARNING) << "Parse GetOneDeviceMetaRequest failed."; + response.set_found(false); + } else { + const std::string &fl_id = get_one_metadata_req.fl_id(); + MS_LOG(DEBUG) << "Getting one client metadata for " << fl_id; + + auto found = DoGetOneDeviceMeta(fl_id, response.mutable_device_meta()); + response.set_found(found); + } + std::string getting_meta_rsp_msg = response.SerializeAsString(); + if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) { + MS_LOG(WARNING) << "Sending response of GetOneDeviceMetaRequest failed."; + return; + } +} + bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) { std::unique_lock lock(mutex_[name]); if (metadata_.count(name) == 0) { @@ -312,6 +373,27 @@ bool DistributedMetadataStore::DoUpdateEncryptMetadata(const std::string &name, return true; } +bool DistributedMetadataStore::DoGetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta) { + if (device_meta == nullptr) { + return false; + } + const auto &name = kCtxDeviceMetas; + std::unique_lock lock(mutex_[name]); + auto meta_it = metadata_.find(name); + if (meta_it == metadata_.end()) { + MS_LOG(WARNING) << "The metadata of " << name << " is not registered."; + return false; + } + auto &stored_meta = meta_it->second; + const auto &fl_id_to_meta = stored_meta.device_metas().fl_id_to_meta(); + auto fl_it = fl_id_to_meta.find(fl_id); + if (fl_it == fl_id_to_meta.end()) { + return false; + } + *device_meta = fl_it->second; + return true; +} + bool DistributedMetadataStore::UpdatePairClientKeys(const std::string &name, const PBMetadata &meta) { auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys(); auto &fl_id = meta.pair_client_keys().fl_id(); diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.h b/mindspore/ccsrc/fl/server/distributed_metadata_store.h index 7b1265bbbb5..ddf6557f812 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.h +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.h @@ -61,6 +61,8 @@ class DistributedMetadataStore { // Get the metadata for the name. PBMetadata GetMetadata(const std::string &name); + bool GetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta); + // Reinitialize the consistency hash ring and clear metadata after scaling operations are done. bool ReInitForScaling(); @@ -85,12 +87,17 @@ class DistributedMetadataStore { // Callback for getting metadata request sent to the server. void HandleGetMetadataRequest(const std::shared_ptr &message); + // Callback for getting metadata item request sent to the server. + void HandleGetOneDeviceMetaRequest(const std::shared_ptr &message); + // Do updating metadata in the server where the metadata for the name is stored. bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta); // Do updating metadata about pairwise-encryption in the server where the metadata for the name is stored. bool DoUpdateEncryptMetadata(const std::string &name, const PBMetadata &meta); + bool DoGetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta); + // Update client keys stored in server bool UpdatePairClientKeys(const std::string &name, const PBMetadata &meta); diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc index 8e24c0203ec..0ca87a7bc8c 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc @@ -107,9 +107,8 @@ bool UpdateModelKernel::Launch(const std::vector &inputs, const std: } MS_LOG(INFO) << "verify signature passed!"; } - - PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas); - result_code = VerifyUpdateModel(update_model_req, fbb, device_metas); + DeviceMeta device_meta; + result_code = VerifyUpdateModel(update_model_req, fbb, &device_meta); if (result_code != ResultCode::kSuccess) { MS_LOG(WARNING) << "Updating model failed."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); @@ -122,7 +121,7 @@ bool UpdateModelKernel::Launch(const std::vector &inputs, const std: return ConvertResultCode(result_code); } - result_code = UpdateModel(update_model_req, fbb, device_metas); + result_code = UpdateModel(update_model_req, fbb, device_meta); if (result_code != ResultCode::kSuccess) { MS_LOG(WARNING) << "Updating model failed."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); @@ -178,8 +177,9 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr } ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req, - const std::shared_ptr &fbb, const PBMetadata &device_metas) { + const std::shared_ptr &fbb, DeviceMeta *device_meta) { MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); + MS_ERROR_IF_NULL_W_RET_VAL(device_meta, ResultCode::kSuccessAndReturn); size_t iteration = IntToSize(update_model_req->iteration()); if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { auto next_req_time = LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp); @@ -191,19 +191,19 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel return ResultCode::kSuccessAndReturn; } - FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas(); std::string update_model_fl_id = update_model_req->fl_id()->str(); MS_LOG(DEBUG) << "UpdateModel for fl id " << update_model_fl_id; - if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) { - if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) { - std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later."; - BuildUpdateModelRsp( - fbb, schema::ResponseCode_OutOfTime, reason, - std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); - MS_LOG(WARNING) << reason; - return ResultCode::kSuccessAndReturn; - } - } else { + + bool found = DistributedMetadataStore::GetInstance().GetOneDeviceMeta(update_model_fl_id, device_meta); + if (!found) { + std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later."; + BuildUpdateModelRsp( + fbb, schema::ResponseCode_OutOfTime, reason, + std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); + MS_LOG(WARNING) << reason; + return ResultCode::kSuccessAndReturn; + } + if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) { std::vector get_secrets_clients; #ifdef ENABLE_ARMOUR mindspore::armour::CipherMetaStorage cipher_meta_storage; @@ -223,13 +223,12 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel } ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, - const std::shared_ptr &fbb, const PBMetadata &device_metas) { + const std::shared_ptr &fbb, const DeviceMeta &device_meta) { MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kSuccessAndReturn); std::string update_model_fl_id = update_model_req->fl_id()->str(); - const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta(); - size_t data_size = fl_id_to_meta.at(update_model_fl_id).data_size(); + size_t data_size = device_meta.data_size(); const auto &feature_map = ParseFeatureMap(update_model_req); if (feature_map.empty()) { std::string reason = "Feature map is empty."; diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h index 42e64e869ae..d547ea21410 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h @@ -54,7 +54,7 @@ class UpdateModelKernel : public RoundKernel { private: ResultCode ReachThresholdForUpdateModel(const std::shared_ptr &fbb); ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr &fbb, - const PBMetadata &device_metas); + const DeviceMeta &device_meta); std::map ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); ResultCode CountForUpdateModel(const std::shared_ptr &fbb, const schema::RequestUpdateModel *update_model_req); @@ -62,7 +62,7 @@ class UpdateModelKernel : public RoundKernel { void BuildUpdateModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const std::string &next_req_time); ResultCode VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req, - const std::shared_ptr &fbb, const PBMetadata &device_metas); + const std::shared_ptr &fbb, DeviceMeta *device_meta); // The executor is for updating the model for updateModel request. Executor *executor_{nullptr}; diff --git a/mindspore/ccsrc/ps/core/communicator/communicator_base.h b/mindspore/ccsrc/ps/core/communicator/communicator_base.h index 41785ac2f74..adc2ba0ca08 100644 --- a/mindspore/ccsrc/ps/core/communicator/communicator_base.h +++ b/mindspore/ccsrc/ps/core/communicator/communicator_base.h @@ -59,7 +59,8 @@ enum class TcpUserCommand { kDisableFLS, kSyncAfterRecover, kExchangeKeys, - kGetKeys + kGetKeys, + kGetOneDeviceMeta, }; // CommunicatorBase is used to receive request and send response for server. diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h index df8a2f4de5e..355be6ff121 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h @@ -44,6 +44,7 @@ const std::unordered_map kUserCommandToMsgType = { {TcpUserCommand::kResetCount, "resetCnt"}, {TcpUserCommand::kGetMetadata, "getMetadata"}, {TcpUserCommand::kUpdateMetadata, "updateMetadata"}, + {TcpUserCommand::kGetOneDeviceMeta, "getOneDeviceMeta"}, {TcpUserCommand::kCounterEvent, "counterEvent"}, {TcpUserCommand::kPullWeight, "pullWeight"}, {TcpUserCommand::kPushWeight, "pushWeight"}, diff --git a/mindspore/ccsrc/ps/core/protos/fl.proto b/mindspore/ccsrc/ps/core/protos/fl.proto index 98338abd71d..42121e1d910 100644 --- a/mindspore/ccsrc/ps/core/protos/fl.proto +++ b/mindspore/ccsrc/ps/core/protos/fl.proto @@ -55,6 +55,15 @@ message GetMetadataResponse { bytes value = 1; } +message GetOneDeviceMetaRequest { + string fl_id = 1; +} + +message GetOneDeviceMetaResponse { + DeviceMeta device_meta = 1; + bool found = 2; +} + enum CounterEventType { FIRST_CNT = 0; LAST_CNT = 1;