!29921 FL, opt update model get client info
Merge pull request !29921 from 徐永飞/r1.6
This commit is contained in:
commit
c3893f52e3
|
@ -38,6 +38,9 @@ void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps:
|
|||
"updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
|
||||
communicator_->RegisterMsgCallBack(
|
||||
"getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
|
||||
communicator_->RegisterMsgCallBack(
|
||||
"getOneDeviceMeta",
|
||||
std::bind(&DistributedMetadataStore::HandleGetOneDeviceMetaRequest, this, std::placeholders::_1));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -148,6 +151,42 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
|
|||
}
|
||||
}
|
||||
|
||||
bool DistributedMetadataStore::GetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta) {
|
||||
if (router_ == nullptr) {
|
||||
MS_LOG(WARNING) << "The consistent hash ring is not initialized yet.";
|
||||
return false;
|
||||
}
|
||||
const auto &name = kCtxDeviceMetas;
|
||||
uint32_t stored_rank = router_->Find(name);
|
||||
MS_LOG(DEBUG) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
|
||||
if (local_rank_ == stored_rank) {
|
||||
return DoGetOneDeviceMeta(fl_id, device_meta);
|
||||
} else {
|
||||
GetOneDeviceMetaRequest get_metadata_req;
|
||||
get_metadata_req.set_fl_id(fl_id);
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
|
||||
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetOneDeviceMeta,
|
||||
&get_meta_rsp_msg)) {
|
||||
MS_LOG(WARNING) << "Sending getting one client metadata message to server " << stored_rank << " failed.";
|
||||
return false;
|
||||
}
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(get_meta_rsp_msg, false);
|
||||
GetOneDeviceMetaResponse get_metadata_rsp;
|
||||
auto ret = get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
|
||||
if (!ret) {
|
||||
MS_LOG(WARNING) << "Parse response of getting one client metadata message to server " << stored_rank
|
||||
<< " failed.";
|
||||
return false;
|
||||
}
|
||||
if (!get_metadata_rsp.found()) {
|
||||
return false;
|
||||
}
|
||||
*device_meta = get_metadata_rsp.device_meta();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool DistributedMetadataStore::ReInitForScaling() {
|
||||
// If DistributedMetadataStore is not initialized yet but the scaling event is triggered, do not throw exception.
|
||||
if (server_node_ == nullptr) {
|
||||
|
@ -221,6 +260,28 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
|
|||
return;
|
||||
}
|
||||
|
||||
// Callback for a "get one device metadata" request received from another server.
// The payload is parsed as a GetOneDeviceMetaRequest, the fl_id is looked up through DoGetOneDeviceMeta, and a
// GetOneDeviceMetaResponse is always sent back: 'found' carries the lookup result, or false when the request could
// not be parsed.
void DistributedMetadataStore::HandleGetOneDeviceMetaRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);

  GetOneDeviceMetaRequest request;
  GetOneDeviceMetaResponse reply;
  const bool parsed = request.ParseFromArray(message->data(), SizeToInt(message->len()));
  if (parsed) {
    const std::string &requested_fl_id = request.fl_id();
    MS_LOG(DEBUG) << "Getting one client metadata for " << requested_fl_id;

    // The lookup writes straight into the response's device_meta field.
    const bool has_entry = DoGetOneDeviceMeta(requested_fl_id, reply.mutable_device_meta());
    reply.set_found(has_entry);
  } else {
    MS_LOG(WARNING) << "Parse GetOneDeviceMetaRequest failed.";
    reply.set_found(false);
  }

  // A response is sent on both paths; a send failure is only logged.
  std::string serialized_reply = reply.SerializeAsString();
  const bool sent = communicator_->SendResponse(serialized_reply.data(), serialized_reply.size(), message);
  if (!sent) {
    MS_LOG(WARNING) << "Sending response of GetOneDeviceMetaRequest failed.";
  }
}
|
||||
|
||||
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
if (metadata_.count(name) == 0) {
|
||||
|
@ -312,6 +373,27 @@ bool DistributedMetadataStore::DoUpdateEncryptMetadata(const std::string &name,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool DistributedMetadataStore::DoGetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta) {
|
||||
if (device_meta == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const auto &name = kCtxDeviceMetas;
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
auto meta_it = metadata_.find(name);
|
||||
if (meta_it == metadata_.end()) {
|
||||
MS_LOG(WARNING) << "The metadata of " << name << " is not registered.";
|
||||
return false;
|
||||
}
|
||||
auto &stored_meta = meta_it->second;
|
||||
const auto &fl_id_to_meta = stored_meta.device_metas().fl_id_to_meta();
|
||||
auto fl_it = fl_id_to_meta.find(fl_id);
|
||||
if (fl_it == fl_id_to_meta.end()) {
|
||||
return false;
|
||||
}
|
||||
*device_meta = fl_it->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DistributedMetadataStore::UpdatePairClientKeys(const std::string &name, const PBMetadata &meta) {
|
||||
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
|
||||
auto &fl_id = meta.pair_client_keys().fl_id();
|
||||
|
|
|
@ -61,6 +61,8 @@ class DistributedMetadataStore {
|
|||
// Get the metadata for the name.
|
||||
PBMetadata GetMetadata(const std::string &name);
|
||||
|
||||
bool GetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta);
|
||||
|
||||
// Reinitialize the consistency hash ring and clear metadata after scaling operations are done.
|
||||
bool ReInitForScaling();
|
||||
|
||||
|
@ -85,12 +87,17 @@ class DistributedMetadataStore {
|
|||
// Callback for getting metadata request sent to the server.
|
||||
void HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
|
||||
// Callback for getting metadata item request sent to the server.
|
||||
void HandleGetOneDeviceMetaRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
|
||||
// Do updating metadata in the server where the metadata for the name is stored.
|
||||
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
// Do updating metadata about pairwise-encryption in the server where the metadata for the name is stored.
|
||||
bool DoUpdateEncryptMetadata(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
bool DoGetOneDeviceMeta(const std::string &fl_id, DeviceMeta *device_meta);
|
||||
|
||||
// Update client keys stored in server
|
||||
bool UpdatePairClientKeys(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
|
|
|
@ -107,9 +107,8 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
}
|
||||
MS_LOG(INFO) << "verify signature passed!";
|
||||
}
|
||||
|
||||
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
|
||||
result_code = VerifyUpdateModel(update_model_req, fbb, device_metas);
|
||||
DeviceMeta device_meta;
|
||||
result_code = VerifyUpdateModel(update_model_req, fbb, &device_meta);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
MS_LOG(WARNING) << "Updating model failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
|
@ -122,7 +121,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
result_code = UpdateModel(update_model_req, fbb, device_metas);
|
||||
result_code = UpdateModel(update_model_req, fbb, device_meta);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
MS_LOG(WARNING) << "Updating model failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
|
@ -178,8 +177,9 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr
|
|||
}
|
||||
|
||||
ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req,
|
||||
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas) {
|
||||
const std::shared_ptr<FBBuilder> &fbb, DeviceMeta *device_meta) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(device_meta, ResultCode::kSuccessAndReturn);
|
||||
size_t iteration = IntToSize(update_model_req->iteration());
|
||||
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
|
||||
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
|
||||
|
@ -191,19 +191,19 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
|
|||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
|
||||
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
|
||||
std::string update_model_fl_id = update_model_req->fl_id()->str();
|
||||
MS_LOG(DEBUG) << "UpdateModel for fl id " << update_model_fl_id;
|
||||
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
|
||||
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
|
||||
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(WARNING) << reason;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
} else {
|
||||
|
||||
bool found = DistributedMetadataStore::GetInstance().GetOneDeviceMeta(update_model_fl_id, device_meta);
|
||||
if (!found) {
|
||||
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(WARNING) << reason;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) {
|
||||
std::vector<std::string> get_secrets_clients;
|
||||
#ifdef ENABLE_ARMOUR
|
||||
mindspore::armour::CipherMetaStorage cipher_meta_storage;
|
||||
|
@ -223,13 +223,12 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
|
|||
}
|
||||
|
||||
ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
|
||||
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas) {
|
||||
const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->fl_id(), ResultCode::kSuccessAndReturn);
|
||||
|
||||
std::string update_model_fl_id = update_model_req->fl_id()->str();
|
||||
const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta();
|
||||
size_t data_size = fl_id_to_meta.at(update_model_fl_id).data_size();
|
||||
size_t data_size = device_meta.data_size();
|
||||
const auto &feature_map = ParseFeatureMap(update_model_req);
|
||||
if (feature_map.empty()) {
|
||||
std::string reason = "Feature map is empty.";
|
||||
|
|
|
@ -54,7 +54,7 @@ class UpdateModelKernel : public RoundKernel {
|
|||
private:
|
||||
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
|
||||
ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb,
|
||||
const PBMetadata &device_metas);
|
||||
const DeviceMeta &device_meta);
|
||||
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
|
||||
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestUpdateModel *update_model_req);
|
||||
|
@ -62,7 +62,7 @@ class UpdateModelKernel : public RoundKernel {
|
|||
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const std::string &next_req_time);
|
||||
ResultCode VerifyUpdateModel(const schema::RequestUpdateModel *update_model_req,
|
||||
const std::shared_ptr<FBBuilder> &fbb, const PBMetadata &device_metas);
|
||||
const std::shared_ptr<FBBuilder> &fbb, DeviceMeta *device_meta);
|
||||
// The executor is for updating the model for updateModel request.
|
||||
Executor *executor_{nullptr};
|
||||
|
||||
|
|
|
@ -59,7 +59,8 @@ enum class TcpUserCommand {
|
|||
kDisableFLS,
|
||||
kSyncAfterRecover,
|
||||
kExchangeKeys,
|
||||
kGetKeys
|
||||
kGetKeys,
|
||||
kGetOneDeviceMeta,
|
||||
};
|
||||
|
||||
// CommunicatorBase is used to receive request and send response for server.
|
||||
|
|
|
@ -44,6 +44,7 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
|||
{TcpUserCommand::kResetCount, "resetCnt"},
|
||||
{TcpUserCommand::kGetMetadata, "getMetadata"},
|
||||
{TcpUserCommand::kUpdateMetadata, "updateMetadata"},
|
||||
{TcpUserCommand::kGetOneDeviceMeta, "getOneDeviceMeta"},
|
||||
{TcpUserCommand::kCounterEvent, "counterEvent"},
|
||||
{TcpUserCommand::kPullWeight, "pullWeight"},
|
||||
{TcpUserCommand::kPushWeight, "pushWeight"},
|
||||
|
|
|
@ -55,6 +55,15 @@ message GetMetadataResponse {
|
|||
bytes value = 1;
|
||||
}
|
||||
|
||||
// Request for the metadata of a single device, sent to the server that stores the device metadata.
message GetOneDeviceMetaRequest {
|
||||
// Id of the device whose metadata is requested.
string fl_id = 1;
|
||||
}
|
||||
|
||||
// Response to GetOneDeviceMetaRequest.
message GetOneDeviceMetaResponse {
|
||||
// Metadata of the requested device; the requester reads it only when found is true.
DeviceMeta device_meta = 1;
|
||||
// True when the server has an entry for the requested fl_id; false otherwise, including when the request
// could not be parsed.
bool found = 2;
|
||||
}
|
||||
|
||||
enum CounterEventType {
|
||||
FIRST_CNT = 0;
|
||||
LAST_CNT = 1;
|
||||
|
|
Loading…
Reference in New Issue