Optimize round kernel's return code.

This commit is contained in:
ZPaC 2021-07-12 15:18:27 +08:00
parent d0dae7bb94
commit c46f4ac8d1
16 changed files with 155 additions and 94 deletions

View File

@ -234,6 +234,32 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size
// Definitions for Federated Learning.
constexpr auto kNetworkError = "Cluster networking failed.";
// The result code used for round kernels.
enum class ResultCode {
// If the method succeeds and the round kernel's remaining methods should be called, return kSuccess.
kSuccess = 0,
// If an error occurs in the method and the remaining methods should not be called, but this iteration should
// continue, return kSuccessAndReturn so that the framework won't drop this iteration.
kSuccessAndReturn,
// If an error occurs and this iteration should be dropped, return kFail.
kFail
};
inline bool ConvertResultCode(ResultCode result_code) {
switch (result_code) {
case ResultCode::kSuccess:
return true;
case ResultCode::kSuccessAndReturn:
return true;
case ResultCode::kFail:
return false;
default:
return true;
}
}
// Definitions for Parameter Server.
} // namespace server
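
For context, the intended use of ResultCode inside a round kernel's Launch is roughly the following. This is a minimal sketch, not part of this commit; CheckRequest, DoWork, and the simplified Launch signature are hypothetical stand-ins for a kernel's real steps.

ResultCode CheckRequest(bool request_ok) {
  // kSuccessAndReturn: skip the remaining steps but keep this iteration alive.
  return request_ok ? ResultCode::kSuccess : ResultCode::kSuccessAndReturn;
}

ResultCode DoWork(bool work_ok) {
  // kFail: this iteration should be dropped by the framework.
  return work_ok ? ResultCode::kSuccess : ResultCode::kFail;
}

bool Launch(bool request_ok, bool work_ok) {
  ResultCode result_code = CheckRequest(request_ok);
  if (result_code != ResultCode::kSuccess) {
    // ConvertResultCode maps kSuccess/kSuccessAndReturn to true and kFail to false.
    return ConvertResultCode(result_code);
  }
  result_code = DoWork(work_ok);
  return ConvertResultCode(result_code);
}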

View File

@ -66,7 +66,7 @@ void DistributedCountService::RegisterCounter(const std::string &name, size_t gl
return;
}
bool DistributedCountService::Count(const std::string &name, const std::string &id) {
bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) {
MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
if (local_rank_ == counting_server_rank_) {
if (global_threshold_count_.count(name) == 0) {
@ -83,7 +83,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
global_current_count_[name].insert(id);
if (!TriggerCounterEvent(name)) {
if (!TriggerCounterEvent(name, reason)) {
MS_LOG(ERROR) << "Leader server trigger count event failed.";
return false;
}
@ -97,6 +97,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
&report_cnt_rsp_msg)) {
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
if (reason != nullptr) {
*reason = kNetworkError;
}
return false;
}
@ -104,6 +107,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
if (!count_rsp.result()) {
MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
*reason = kNetworkError;
}
return false;
}
}
@ -202,13 +208,13 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
// Insert the id for the counter, which means the count for the name is increased.
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
global_current_count_[name].insert(id);
if (!TriggerCounterEvent(name)) {
std::string reason = "Trigger count event for " + name + " of " + id + " failed.";
std::string reason = "success";
if (!TriggerCounterEvent(name, &reason)) {
count_rsp.set_result(false);
count_rsp.set_reason(reason);
} else {
count_rsp.set_result(true);
count_rsp.set_reason("success");
count_rsp.set_reason(reason);
}
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
return;
@ -266,24 +272,24 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
return;
}
bool DistributedCountService::TriggerCounterEvent(const std::string &name) {
bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) {
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
<< ", threshold count is " << global_threshold_count_[name];
// The threshold count may be 1, in which case the first and last count events should both be activated.
if (global_current_count_[name].size() == 1) {
if (!TriggerFirstCountEvent(name)) {
if (!TriggerFirstCountEvent(name, reason)) {
return false;
}
}
if (global_current_count_[name].size() == global_threshold_count_[name]) {
if (!TriggerLastCountEvent(name)) {
if (!TriggerLastCountEvent(name, reason)) {
return false;
}
}
return true;
}
bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, std::string *reason) {
MS_LOG(DEBUG) << "Activating first count event for " << name;
CounterEvent first_count_event;
first_count_event.set_type(CounterEventType::FIRST_CNT);
@ -293,6 +299,9 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
if (reason != nullptr) {
*reason = "Send to rank " + std::to_string(i) + " failed. " + kNetworkError;
}
return false;
}
}
@ -301,7 +310,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
return true;
}
bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) {
MS_LOG(INFO) << "Activating last count event for " << name;
CounterEvent last_count_event;
last_count_event.set_type(CounterEventType::LAST_CNT);
@ -311,6 +320,9 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
if (reason != nullptr) {
*reason = "Send to rank " + std::to_string(i) + " failed. " + kNetworkError;
}
return false;
}
}

View File

@ -63,8 +63,9 @@ class DistributedCountService {
// first/last count event callbacks.
void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);
// Report a count to the counting server. Parameter 'id' is in case of repeated counting.
bool Count(const std::string &name, const std::string &id);
// Report a count to the counting server. Parameter 'id' is used to guard against repeated counting. Parameter
// 'reason' is set to the failure reason when counting fails.
bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr);
// Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count,
// this method returns true.
@ -98,9 +99,9 @@ class DistributedCountService {
void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
// Call the callbacks when the first/last count event is triggered.
bool TriggerCounterEvent(const std::string &name);
bool TriggerFirstCountEvent(const std::string &name);
bool TriggerLastCountEvent(const std::string &name);
bool TriggerCounterEvent(const std::string &name, std::string *reason = nullptr);
bool TriggerFirstCountEvent(const std::string &name, std::string *reason = nullptr);
bool TriggerLastCountEvent(const std::string &name, std::string *reason = nullptr);
// Members for the communication between counting server and other servers.
std::shared_ptr<ps::core::ServerNode> server_node_;
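
For reference, the round kernels changed in this commit call Count with the new optional out-parameter and treat a networking failure as recoverable; a condensed sketch of that call site (id stands for the caller's identifier, e.g. the fl_id or the local rank, and the response building is omitted):

std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, id, &count_reason)) {
  // A networking failure should not drop the iteration; any other failure should.
  return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
}
return ResultCode::kSuccess;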

View File

@ -82,7 +82,7 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) {
return;
}
bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return false;
@ -103,6 +103,9 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
&update_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
if (reason != nullptr) {
*reason = "Send to rank " + std::to_string(stored_rank) + " failed. " + kNetworkError;
}
return false;
}

View File

@ -55,8 +55,8 @@ class DistributedMetadataStore {
// Reset the metadata value for the name.
void ResetMetadata(const std::string &name);
// Update the metadata for the name.
bool UpdateMetadata(const std::string &name, const PBMetadata &meta);
// Update the metadata for the name. Parameter 'reason' is set to the failure reason when updating the metadata fails.
bool UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason = nullptr);
// Get the metadata for the name.
PBMetadata GetMetadata(const std::string &name);
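
UpdateMetadata follows the same convention as Count: the caller passes an optional reason pointer and only tolerates kNetworkError. A condensed sketch mirroring the UpdateModel kernel changed below (response building omitted):

std::string update_reason = "";
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) {
  // Only a networking failure keeps the iteration alive; other failures drop it.
  return update_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
}
return ResultCode::kSuccess;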

View File

@ -390,7 +390,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
end_last_iter_req.set_last_iter_num(last_iter_num);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
MS_LOG(WARNING) << "Sending ending last iteration request to server " << i << " failed.";
continue;
}
}
@ -438,10 +438,10 @@ void Iteration::EndLastIter() {
ModelStore::GetInstance().Reset();
}
Server::GetInstance().CancelSafeMode();
SetIterationCompleted();
pinned_iter_num_ = 0;
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
Server::GetInstance().CancelSafeMode();
SetIterationCompleted();
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
}
} // namespace server

View File

@ -48,9 +48,9 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return false;
}
bool ret = PushWeight(fbb, push_weight_req);
ResultCode result_code = PushWeight(fbb, push_weight_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return ret;
return ConvertResultCode(result_code);
}
bool PushWeightKernel::Reset() {
@ -67,9 +67,10 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH
return;
}
bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) {
ResultCode PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb,
const schema::RequestPushWeight *push_weight_req) {
if (fbb == nullptr || push_weight_req == nullptr) {
return false;
return ResultCode::kSuccessAndReturn;
}
size_t iteration = static_cast<size_t>(push_weight_req->iteration());
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
@ -78,7 +79,7 @@ bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
", current iteration:" + std::to_string(current_iter);
BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
MS_LOG(WARNING) << reason;
return true;
return ResultCode::kSuccessAndReturn;
}
std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
@ -86,25 +87,26 @@ bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
std::string reason = "PushWeight feature_map is empty.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter);
MS_LOG(ERROR) << reason;
return false;
return ResultCode::kSuccessAndReturn;
}
if (!executor_->HandlePushWeight(upload_feature_map)) {
std::string reason = "Pushing weight failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
MS_LOG(ERROR) << reason;
return false;
return ResultCode::kSuccessAndReturn;
}
MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) {
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) {
std::string reason = "Count for push weight request failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
MS_LOG(ERROR) << reason;
return false;
return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
}
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter);
return true;
return ResultCode::kSuccess;
}
std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {

View File

@ -42,7 +42,7 @@ class PushWeightKernel : public RoundKernel {
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private:
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
ResultCode PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
size_t iteration);

View File

@ -72,33 +72,37 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return true;
}
if (ReachThresholdForStartFLJob(fbb)) {
ResultCode result_code = ReachThresholdForStartFLJob(fbb);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
return ConvertResultCode(result_code);
}
const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
if (!ReadyForStartFLJob(fbb, device_meta)) {
result_code = ReadyForStartFLJob(fbb, device_meta);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return ConvertResultCode(result_code);
}
PBMetadata metadata;
*metadata.mutable_device_meta() = device_meta;
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) {
std::string reason = "Updating device metadata failed.";
std::string update_reason = "";
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) {
std::string reason = "Updating device metadata failed. " + update_reason;
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
{});
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
return update_reason == kNetworkError;
}
StartFLJob(fbb, device_meta);
// If ReportCount is called before ReadyForStartFLJob, the result will be inconsistent when the device is not selected.
if (!CountForStartFLJob(fbb, start_fl_job_req)) {
result_code = CountForStartFLJob(fbb, start_fl_job_req);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return ConvertResultCode(result_code);
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
@ -120,16 +124,16 @@ void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::Message
Iteration::GetInstance().SetIterationRunning();
}
bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return true;
return ResultCode::kSuccessAndReturn;
}
return false;
return ResultCode::kSuccess;
}
DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) {
@ -146,34 +150,35 @@ DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *st
return device_meta;
}
bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
bool ret = true;
ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
ResultCode ret = ResultCode::kSuccess;
std::string reason = "";
if (device_meta.data_size() < 1) {
reason = "FL job data size is not enough.";
ret = false;
ret = ResultCode::kSuccessAndReturn;
}
if (!ret) {
if (ret != ResultCode::kSuccess) {
BuildStartFLJobRsp(
fbb, schema::ResponseCode_RequestError, reason, false,
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
MS_LOG(WARNING) << reason;
}
return ret;
}
bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
RETURN_IF_NULL(start_fl_job_req, false);
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
RETURN_IF_NULL(start_fl_job_req, ResultCode::kSuccessAndReturn);
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) {
std::string reason = "Counting start fl job request failed. Please retry later.";
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
}
return true;
return ResultCode::kSuccess;
}
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {

View File

@ -44,17 +44,17 @@ class StartFLJobKernel : public RoundKernel {
private:
// Returns whether the startFLJob count of this iteration has reached the threshold.
bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
ResultCode ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
// The metadata of the device will be stored and queried in the updateModel round.
DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req);
// Returns whether the request is valid for startFLJob. For now, the condition is simple. We will add more conditions
// for the device in later versions.
bool ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
ResultCode ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
// Distributed count service counts for startFLJob.
bool CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
ResultCode CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);

View File

@ -56,21 +56,24 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
}
MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
if (ReachThresholdForUpdateModel(fbb)) {
ResultCode result_code = ReachThresholdForUpdateModel(fbb);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
return ConvertResultCode(result_code);
}
const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
if (!UpdateModel(update_model_req, fbb)) {
result_code = UpdateModel(update_model_req, fbb);
if (result_code != ResultCode::kSuccess) {
MS_LOG(ERROR) << "Updating model failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return ConvertResultCode(result_code);
}
if (!CountForUpdateModel(fbb, update_model_req)) {
result_code = CountForUpdateModel(fbb, update_model_req);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return ConvertResultCode(result_code);
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
@ -102,21 +105,21 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::Message
}
}
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for updateModel is enough. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return true;
return ResultCode::kSuccessAndReturn;
}
return false;
return ResultCode::kSuccess;
}
bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb) {
RETURN_IF_NULL(update_model_req, false);
ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb) {
RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn);
size_t iteration = static_cast<size_t>(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
@ -126,7 +129,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return true;
return ResultCode::kSuccessAndReturn;
}
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
@ -139,7 +142,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
return ResultCode::kSuccessAndReturn;
}
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
@ -150,7 +153,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
fbb, schema::ResponseCode_RequestError, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
return ResultCode::kSuccessAndReturn;
}
for (auto weight : feature_map) {
@ -163,18 +166,19 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
fl_id.set_fl_id(update_model_fl_id);
PBMetadata comm_value;
*comm_value.mutable_fl_id() = fl_id;
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) {
std::string reason = "Updating metadata of UpdateModelClientList failed.";
std::string update_reason = "";
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) {
std::string reason = "Updating metadata of UpdateModelClientList failed. " + update_reason;
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
return update_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
}
BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
return true;
return ResultCode::kSuccess;
}
std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
@ -195,18 +199,19 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
return feature_map;
}
bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
RETURN_IF_NULL(update_model_req, false);
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
std::string reason = "Counting for update model request failed. Please retry later.";
ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn);
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) {
std::string reason = "Counting for update model request failed. Please retry later. " + count_reason;
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
}
return true;
return ResultCode::kSuccess;
}
void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,

View File

@ -47,10 +47,11 @@ class UpdateModelKernel : public RoundKernel {
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private:
bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
bool UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb);
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb);
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
bool CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestUpdateModel *update_model_req);
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req);
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time);

View File

@ -36,6 +36,7 @@ void ModelStore::Initialize(uint32_t max_count) {
}
bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
std::unique_lock<std::mutex> lock(model_mtx_);
if (iteration_to_model_.count(iteration) != 0) {
MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
return false;
@ -88,6 +89,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
}
std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
std::unique_lock<std::mutex> lock(model_mtx_);
std::map<std::string, AddressPtr> model = {};
if (iteration_to_model_.count(iteration) == 0) {
MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored.";
@ -98,13 +100,15 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
}
void ModelStore::Reset() {
std::unique_lock<std::mutex> lock(model_mtx_);
initial_model_ = iteration_to_model_.rbegin()->second;
iteration_to_model_.clear();
iteration_to_model_[kInitIterationNum] = initial_model_;
iteration_to_model_[kResetInitIterNum] = initial_model_;
}
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() {
std::unique_lock<std::mutex> lock(model_mtx_);
return iteration_to_model_;
}
@ -142,6 +146,7 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
}
size_t ModelStore::ComputeModelSize() {
std::unique_lock<std::mutex> lock(model_mtx_);
if (iteration_to_model_.empty()) {
MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. ";
return 0;

View File

@ -56,7 +56,7 @@ class ModelStore {
void Reset();
// Returns all models stored in ModelStore.
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model();
// Returns the model size, which could be calculated at the initializing phase.
size_t model_size() const;
@ -80,7 +80,8 @@ class ModelStore {
// Initial model which is the model of iteration 0.
std::shared_ptr<MemoryRegister> initial_model_;
// The number of all models stpred is max_model_count_.
// The number of all models stored is max_model_count_.
std::mutex model_mtx_;
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
};
} // namespace server
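
The model store changes above serialize every accessor on model_mtx_, which is why iteration_to_model() loses its const qualifier: locking a non-mutable std::mutex is a non-const operation (keeping the method const would require declaring model_mtx_ as mutable). A trimmed-down sketch of the pattern, using a hypothetical TinyModelStore with MemoryRegister as an empty stand-in:

#include <cstddef>
#include <map>
#include <memory>
#include <mutex>

struct MemoryRegister {};  // stand-in for the real MemoryRegister

class TinyModelStore {
 public:
  void Store(size_t iteration, std::shared_ptr<MemoryRegister> model) {
    std::unique_lock<std::mutex> lock(model_mtx_);  // serialize concurrent writers
    iteration_to_model_[iteration] = std::move(model);
  }

  // Non-const: taking the lock mutates model_mtx_.
  const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() {
    std::unique_lock<std::mutex> lock(model_mtx_);
    return iteration_to_model_;
  }

 private:
  std::mutex model_mtx_;
  std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
};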

View File

@ -48,7 +48,7 @@ parser.add_argument("--sts_properties_path", type=str, default="")
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
parser.add_argument("--encrypt_type", type=str, default="NotEncrypt")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
args, _ = parser.parse_known_args()

View File

@ -47,7 +47,7 @@ parser.add_argument("--sts_properties_path", type=str, default="")
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
parser.add_argument("--encrypt_type", type=str, default="NotEncrypt")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)