!29921 FL, opt update model get client info

Merge pull request !29921 from 徐永飞/r1.6
This commit is contained in:
i-robot 2022-02-11 10:20:20 +00:00 committed by Gitee
commit c3893f52e3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 121 additions and 22 deletions

View File

@ -38,6 +38,9 @@ void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps:
"updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"getOneDeviceMeta",
std::bind(&DistributedMetadataStore::HandleGetOneDeviceMetaRequest, this, std::placeholders::_1));
return;
}
@ -148,6 +151,42 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
}
}
bool DistributedMetadataStore::GetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta) {
if (router_ == nullptr) {
MS_LOG(WARNING) << "The consistent hash ring is not initialized yet.";
return false;
}
const auto &name = kCtxDeviceMetas;
uint32_t stored_rank = router_->Find(name);
MS_LOG(DEBUG) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
if (local_rank_ == stored_rank) {
return DoGetOneDeviceMeta(fl_id, device_meta);
} else {
GetOneDeviceMetaRequest get_metadata_req;
get_metadata_req.set_fl_id(fl_id);
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetOneDeviceMeta,
&get_meta_rsp_msg)) {
MS_LOG(WARNING) << "Sending getting one client metadata message to server " << stored_rank << " failed.";
return false;
}
MS_ERROR_IF_NULL_W_RET_VAL(get_meta_rsp_msg, false);
GetOneDeviceMetaResponse get_metadata_rsp;
auto ret = get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
if (!ret) {
MS_LOG(WARNING) << "Parse response of getting one client metadata message to server " << stored_rank
<< " failed.";
return false;
}
if (!get_metadata_rsp.found()) {
return false;
}
*device_meta = get_metadata_rsp.device_meta();
return true;
}
}
bool DistributedMetadataStore::ReInitForScaling() {
// If DistributedMetadataStore is not initialized yet but the scaling event is triggered, do not throw exception.
if (server_node_ == nullptr) {
@ -221,6 +260,28 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
return;
}
void DistributedMetadataStore::HandleGetOneDeviceMetaRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
GetOneDeviceMetaRequest get_one_metadata_req;
GetOneDeviceMetaResponse response;
auto ret = get_one_metadata_req.ParseFromArray(message->data(), SizeToInt(message->len()));
if (!ret) {
MS_LOG(WARNING) << "Parse GetOneDeviceMetaRequest failed.";
response.set_found(false);
} else {
const std::string &fl_id = get_one_metadata_req.fl_id();
MS_LOG(DEBUG) << "Getting one client metadata for " << fl_id;
auto found = DoGetOneDeviceMeta(fl_id, response.mutable_device_meta());
response.set_found(found);
}
std::string getting_meta_rsp_msg = response.SerializeAsString();
if (!communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message)) {
MS_LOG(WARNING) << "Sending response of GetOneDeviceMetaRequest failed.";
return;
}
}
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
std::unique_lock<std::mutex> lock(mutex_[name]);
if (metadata_.count(name) == 0) {
@ -312,6 +373,27 @@ bool DistributedMetadataStore::DoUpdateEncryptMetadata(const std::string &name,
return true;
}
bool DistributedMetadataStore::DoGetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta) {
if (device_meta == nullptr) {
return false;
}
const auto &name = kCtxDeviceMetas;
std::unique_lock<std::mutex> lock(mutex_[name]);
auto meta_it = metadata_.find(name);
if (meta_it == metadata_.end()) {
MS_LOG(WARNING) << "The metadata of " << name << " is not registered.";
return false;
}
auto &stored_meta = meta_it->second;
const auto &fl_id_to_meta = stored_meta.device_metas().fl_id_to_meta();
auto fl_it = fl_id_to_meta.find(fl_id);
if (fl_it == fl_id_to_meta.end()) {
return false;
}
*device_meta = fl_it->second;
return true;
}
bool DistributedMetadataStore::UpdatePairClientKeys(const std::string &name, const PBMetadata &meta) {
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
auto &fl_id = meta.pair_client_keys().fl_id();

View File

@ -61,6 +61,8 @@ class DistributedMetadataStore {
// Get the metadata for the name.
PBMetadata GetMetadata(const std::string &name);
bool GetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta);
// Reinitialize the consistency hash ring and clear metadata after scaling operations are done.
bool ReInitForScaling();
@ -85,12 +87,17 @@ class DistributedMetadataStore {
// Callback for getting metadata request sent to the server.
void HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Callback for getting metadata item request sent to the server.
void HandleGetOneDeviceMetaRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Do updating metadata in the server where the metadata for the name is stored.
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
// Do updating metadata about pairwise-encryption in the server where the metadata for the name is stored.
bool DoUpdateEncryptMetadata(const std::string &name, const PBMetadata &meta);
bool DoGetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta);
// Update client keys stored in server
bool UpdatePairClientKeys(const std::string &name, const PBMetadata &meta);

View File

@ -107,9 +107,8 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
}
MS_LOG(INFO) << "verify signature passed!";
}
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
result_code = VerifyUpdateModel(update_model_req, fbb, device_metas);
DeviceMeta device_meta;
result_code = VerifyUpdateModel(update_model_req, fbb, &device_meta);
if (result_code != ResultCode::kSuccess) {
MS_LOG(WARNING) << "Updating model failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
@ -122,7 +121,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
return ConvertResultCode(result_code);
}
result_code = UpdateModel(update_model_req, fbb, device_metas);
result_code = UpdateModel(update_model_req, fbb, device_meta);
if (result_code != ResultCode::kSuccess) {
MS_LOG(WARNING) << "Updating model failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
@ -178,8 +177,9 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr
}
ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas) {
const std::shared_ptr<FBBuilder> &fbb, DeviceMeta *device_meta) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(device_meta, ResultCode::kSuccessAndReturn);
size_t iteration = IntToSize(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
@ -191,19 +191,19 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
return ResultCode::kSuccessAndReturn;
}
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
std::string update_model_fl_id = update_model_req->fl_id()->str();
MS_LOG(DEBUG) << "UpdateModel for fl id " << update_model_fl_id;
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
}
} else {
bool found = DistributedMetadataStore::GetInstance().GetOneDeviceMeta(update_model_fl_id, device_meta);
if (!found) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn;
}
if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) {
std::vector<std::string> get_secrets_clients;
#ifdef ENABLE_ARMOUR
mindspore::armour::CipherMetaStorage cipher_meta_storage;
@ -223,13 +223,12 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
}
ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas) {
const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kSuccessAndReturn);
std::string update_model_fl_id = update_model_req->fl_id()->str();
const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta();
size_t data_size = fl_id_to_meta.at(update_model_fl_id).data_size();
size_t data_size = device_meta.data_size();
const auto &feature_map = ParseFeatureMap(update_model_req);
if (feature_map.empty()) {
std::string reason = "Feature map is empty.";

View File

@ -54,7 +54,7 @@ class UpdateModelKernel : public RoundKernel {
private:
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb,
const PBMetadata &device_metas);
const DeviceMeta &device_meta);
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req);
@ -62,7 +62,7 @@ class UpdateModelKernel : public RoundKernel {
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time);
ResultCode VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas);
const std::shared_ptr<FBBuilder> &fbb, DeviceMeta *device_meta);
// The executor is for updating the model for updateModel request.
Executor *executor_{nullptr};

View File

@ -59,7 +59,8 @@ enum class TcpUserCommand {
kDisableFLS,
kSyncAfterRecover,
kExchangeKeys,
kGetKeys
kGetKeys,
kGetOneDeviceMeta,
};
// CommunicatorBase is used to receive request and send response for server.

View File

@ -44,6 +44,7 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kResetCount, "resetCnt"},
{TcpUserCommand::kGetMetadata, "getMetadata"},
{TcpUserCommand::kUpdateMetadata, "updateMetadata"},
{TcpUserCommand::kGetOneDeviceMeta, "getOneDeviceMeta"},
{TcpUserCommand::kCounterEvent, "counterEvent"},
{TcpUserCommand::kPullWeight, "pullWeight"},
{TcpUserCommand::kPushWeight, "pushWeight"},

View File

@ -55,6 +55,15 @@ message GetMetadataResponse {
bytes value = 1;
}
message GetOneDeviceMetaRequest {
string fl_id = 1;
}
message GetOneDeviceMetaResponse {
DeviceMeta device_meta = 1;
bool found = 2;
}
enum CounterEventType {
FIRST_CNT = 0;
LAST_CNT = 1;