forked from mindspore-Ecosystem/mindspore
Optimize round kernel's return code.
This commit is contained in:
parent
d0dae7bb94
commit
c46f4ac8d1
|
@ -234,6 +234,32 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size
|
|||
|
||||
// Definitions for Federated Learning.
|
||||
|
||||
constexpr auto kNetworkError = "Cluster networking failed.";
|
||||
|
||||
// The result code used for round kernels.
|
||||
enum class ResultCode {
|
||||
// If the method is successfully called and round kernel's residual methods should be called, return kSuccess.
|
||||
kSuccess = 0,
|
||||
// If there's error happened in the method and residual methods should not be called but this iteration continues,
|
||||
// return kSuccessAndReturn so that framework won't drop this iteration.
|
||||
kSuccessAndReturn,
|
||||
// If there's error happened and this iteration should be dropped, return kFail.
|
||||
kFail
|
||||
};
|
||||
|
||||
bool inline ConvertResultCode(ResultCode result_code) {
|
||||
switch (result_code) {
|
||||
case ResultCode::kSuccess:
|
||||
return true;
|
||||
case ResultCode::kSuccessAndReturn:
|
||||
return true;
|
||||
case ResultCode::kFail:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Definitions for Parameter Server.
|
||||
|
||||
} // namespace server
|
||||
|
|
|
@ -66,7 +66,7 @@ void DistributedCountService::RegisterCounter(const std::string &name, size_t gl
|
|||
return;
|
||||
}
|
||||
|
||||
bool DistributedCountService::Count(const std::string &name, const std::string &id) {
|
||||
bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) {
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
|
||||
if (local_rank_ == counting_server_rank_) {
|
||||
if (global_threshold_count_.count(name) == 0) {
|
||||
|
@ -83,7 +83,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
|
|||
|
||||
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
|
||||
global_current_count_[name].insert(id);
|
||||
if (!TriggerCounterEvent(name)) {
|
||||
if (!TriggerCounterEvent(name, reason)) {
|
||||
MS_LOG(ERROR) << "Leader server trigger count event failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -97,6 +97,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
|
|||
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
|
||||
&report_cnt_rsp_msg)) {
|
||||
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
|
||||
if (reason != nullptr) {
|
||||
*reason = kNetworkError;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -104,6 +107,9 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
|
|||
count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
|
||||
if (!count_rsp.result()) {
|
||||
MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
|
||||
if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
|
||||
*reason = kNetworkError;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -202,13 +208,13 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
|
|||
// Insert the id for the counter, which means the count for the name is increased.
|
||||
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
|
||||
global_current_count_[name].insert(id);
|
||||
if (!TriggerCounterEvent(name)) {
|
||||
std::string reason = "Trigger count event for " + name + " of " + id + " failed.";
|
||||
std::string reason = "success";
|
||||
if (!TriggerCounterEvent(name, &reason)) {
|
||||
count_rsp.set_result(false);
|
||||
count_rsp.set_reason(reason);
|
||||
} else {
|
||||
count_rsp.set_result(true);
|
||||
count_rsp.set_reason("success");
|
||||
count_rsp.set_reason(reason);
|
||||
}
|
||||
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
|
||||
return;
|
||||
|
@ -266,24 +272,24 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
|
|||
return;
|
||||
}
|
||||
|
||||
bool DistributedCountService::TriggerCounterEvent(const std::string &name) {
|
||||
bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) {
|
||||
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
|
||||
<< ", threshold count is " << global_threshold_count_[name];
|
||||
// The threshold count may be 1 so the first and last count event should be both activated.
|
||||
if (global_current_count_[name].size() == 1) {
|
||||
if (!TriggerFirstCountEvent(name)) {
|
||||
if (!TriggerFirstCountEvent(name, reason)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (global_current_count_[name].size() == global_threshold_count_[name]) {
|
||||
if (!TriggerLastCountEvent(name)) {
|
||||
if (!TriggerLastCountEvent(name, reason)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
|
||||
bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, std::string *reason) {
|
||||
MS_LOG(DEBUG) << "Activating first count event for " << name;
|
||||
CounterEvent first_count_event;
|
||||
first_count_event.set_type(CounterEventType::FIRST_CNT);
|
||||
|
@ -293,6 +299,9 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
|
|||
for (uint32_t i = 1; i < server_num_; i++) {
|
||||
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
||||
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
|
||||
if (reason != nullptr) {
|
||||
*reason = "Send to rank " + std::to_string(i) + " failed. " + kNetworkError;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -301,7 +310,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
|
||||
bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) {
|
||||
MS_LOG(INFO) << "Activating last count event for " << name;
|
||||
CounterEvent last_count_event;
|
||||
last_count_event.set_type(CounterEventType::LAST_CNT);
|
||||
|
@ -311,6 +320,9 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
|
|||
for (uint32_t i = 1; i < server_num_; i++) {
|
||||
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
||||
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
|
||||
if (reason != nullptr) {
|
||||
*reason = "Send to rank " + std::to_string(i) + " failed. " + kNetworkError;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,8 +63,9 @@ class DistributedCountService {
|
|||
// first/last count event callbacks.
|
||||
void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);
|
||||
|
||||
// Report a count to the counting server. Parameter 'id' is in case of repeated counting.
|
||||
bool Count(const std::string &name, const std::string &id);
|
||||
// Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the
|
||||
// reason why counting failed.
|
||||
bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr);
|
||||
|
||||
// Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count,
|
||||
// this method returns true.
|
||||
|
@ -98,9 +99,9 @@ class DistributedCountService {
|
|||
void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
|
||||
// Call the callbacks when the first/last count event is triggered.
|
||||
bool TriggerCounterEvent(const std::string &name);
|
||||
bool TriggerFirstCountEvent(const std::string &name);
|
||||
bool TriggerLastCountEvent(const std::string &name);
|
||||
bool TriggerCounterEvent(const std::string &name, std::string *reason = nullptr);
|
||||
bool TriggerFirstCountEvent(const std::string &name, std::string *reason = nullptr);
|
||||
bool TriggerLastCountEvent(const std::string &name, std::string *reason = nullptr);
|
||||
|
||||
// Members for the communication between counting server and other servers.
|
||||
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||
|
|
|
@ -82,7 +82,7 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) {
|
|||
return;
|
||||
}
|
||||
|
||||
bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||
bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason) {
|
||||
if (router_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||
return false;
|
||||
|
@ -103,6 +103,9 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
|
|||
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
|
||||
&update_meta_rsp_msg)) {
|
||||
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
|
||||
if (reason != nullptr) {
|
||||
*reason = "Send to rank " + std::to_string(stored_rank) + " failed. " + kNetworkError;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -55,8 +55,8 @@ class DistributedMetadataStore {
|
|||
// Reset the metadata value for the name.
|
||||
void ResetMetadata(const std::string &name);
|
||||
|
||||
// Update the metadata for the name.
|
||||
bool UpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||
// Update the metadata for the name. Parameter 'reason' is the reason why updating meta data failed.
|
||||
bool UpdateMetadata(const std::string &name, const PBMetadata &meta, std::string *reason = nullptr);
|
||||
|
||||
// Get the metadata for the name.
|
||||
PBMetadata GetMetadata(const std::string &name);
|
||||
|
|
|
@ -390,7 +390,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
|
|||
end_last_iter_req.set_last_iter_num(last_iter_num);
|
||||
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
|
||||
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
|
||||
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
|
||||
MS_LOG(WARNING) << "Sending ending last iteration request to server " << i << " failed.";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -438,10 +438,10 @@ void Iteration::EndLastIter() {
|
|||
ModelStore::GetInstance().Reset();
|
||||
}
|
||||
|
||||
Server::GetInstance().CancelSafeMode();
|
||||
SetIterationCompleted();
|
||||
pinned_iter_num_ = 0;
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||
Server::GetInstance().CancelSafeMode();
|
||||
SetIterationCompleted();
|
||||
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
||||
}
|
||||
} // namespace server
|
||||
|
|
|
@ -48,9 +48,9 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
return false;
|
||||
}
|
||||
|
||||
bool ret = PushWeight(fbb, push_weight_req);
|
||||
ResultCode result_code = PushWeight(fbb, push_weight_req);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ret;
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
bool PushWeightKernel::Reset() {
|
||||
|
@ -67,9 +67,10 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH
|
|||
return;
|
||||
}
|
||||
|
||||
bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) {
|
||||
ResultCode PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb,
|
||||
const schema::RequestPushWeight *push_weight_req) {
|
||||
if (fbb == nullptr || push_weight_req == nullptr) {
|
||||
return false;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
size_t iteration = static_cast<size_t>(push_weight_req->iteration());
|
||||
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
|
@ -78,7 +79,7 @@ bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
|
|||
", current iteration:" + std::to_string(current_iter);
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
|
||||
MS_LOG(WARNING) << reason;
|
||||
return true;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
|
||||
std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
|
||||
|
@ -86,25 +87,26 @@ bool PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
|
|||
std::string reason = "PushWeight feature_map is empty.";
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
|
||||
if (!executor_->HandlePushWeight(upload_feature_map)) {
|
||||
std::string reason = "Pushing weight failed.";
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";
|
||||
|
||||
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_))) {
|
||||
std::string count_reason = "";
|
||||
if (!DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason)) {
|
||||
std::string reason = "Count for push weight request failed.";
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
|
||||
}
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter);
|
||||
return true;
|
||||
return ResultCode::kSuccess;
|
||||
}
|
||||
|
||||
std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {
|
||||
|
|
|
@ -42,7 +42,7 @@ class PushWeightKernel : public RoundKernel {
|
|||
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
|
||||
private:
|
||||
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
|
||||
ResultCode PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
|
||||
std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
|
||||
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
|
||||
size_t iteration);
|
||||
|
|
|
@ -72,33 +72,37 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
return true;
|
||||
}
|
||||
|
||||
if (ReachThresholdForStartFLJob(fbb)) {
|
||||
ResultCode result_code = ReachThresholdForStartFLJob(fbb);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
|
||||
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
|
||||
if (!ReadyForStartFLJob(fbb, device_meta)) {
|
||||
result_code = ReadyForStartFLJob(fbb, device_meta);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
PBMetadata metadata;
|
||||
*metadata.mutable_device_meta() = device_meta;
|
||||
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) {
|
||||
std::string reason = "Updating device metadata failed.";
|
||||
std::string update_reason = "";
|
||||
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata, &update_reason)) {
|
||||
std::string reason = "Updating device metadata failed. " + update_reason;
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
|
||||
{});
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
return update_reason == kNetworkError ? true : false;
|
||||
}
|
||||
|
||||
StartFLJob(fbb, device_meta);
|
||||
// If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected.
|
||||
if (!CountForStartFLJob(fbb, start_fl_job_req)) {
|
||||
result_code = CountForStartFLJob(fbb, start_fl_job_req);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
|
@ -120,16 +124,16 @@ void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::Message
|
|||
Iteration::GetInstance().SetIterationRunning();
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
ResultCode StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
|
||||
BuildStartFLJobRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(WARNING) << reason;
|
||||
return true;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
return false;
|
||||
return ResultCode::kSuccess;
|
||||
}
|
||||
|
||||
DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) {
|
||||
|
@ -146,34 +150,35 @@ DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *st
|
|||
return device_meta;
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||
bool ret = true;
|
||||
ResultCode StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||
ResultCode ret = ResultCode::kSuccess;
|
||||
std::string reason = "";
|
||||
if (device_meta.data_size() < 1) {
|
||||
reason = "FL job data size is not enough.";
|
||||
ret = false;
|
||||
ret = ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
if (!ret) {
|
||||
if (ret != ResultCode::kSuccess) {
|
||||
BuildStartFLJobRsp(
|
||||
fbb, schema::ResponseCode_RequestError, reason, false,
|
||||
fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
MS_LOG(WARNING) << reason;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestFLJob *start_fl_job_req) {
|
||||
RETURN_IF_NULL(start_fl_job_req, false);
|
||||
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
|
||||
ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestFLJob *start_fl_job_req) {
|
||||
RETURN_IF_NULL(start_fl_job_req, ResultCode::kSuccessAndReturn);
|
||||
std::string count_reason = "";
|
||||
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str(), &count_reason)) {
|
||||
std::string reason = "Counting start fl job request failed. Please retry later.";
|
||||
BuildStartFLJobRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
|
||||
}
|
||||
return true;
|
||||
return ResultCode::kSuccess;
|
||||
}
|
||||
|
||||
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||
|
|
|
@ -44,17 +44,17 @@ class StartFLJobKernel : public RoundKernel {
|
|||
|
||||
private:
|
||||
// Returns whether the startFLJob count of this iteration has reached the threshold.
|
||||
bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
|
||||
ResultCode ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
|
||||
|
||||
// The metadata of device will be stored and queried in updateModel round.
|
||||
DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req);
|
||||
|
||||
// Returns whether the request is valid for startFLJob.For now, the condition is simple. We will add more conditions
|
||||
// to device in later versions.
|
||||
bool ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
||||
ResultCode ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
||||
|
||||
// Distributed count service counts for startFLJob.
|
||||
bool CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
|
||||
ResultCode CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
|
||||
|
||||
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
||||
|
||||
|
|
|
@ -56,21 +56,24 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
}
|
||||
|
||||
MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
|
||||
if (ReachThresholdForUpdateModel(fbb)) {
|
||||
ResultCode result_code = ReachThresholdForUpdateModel(fbb);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
|
||||
if (!UpdateModel(update_model_req, fbb)) {
|
||||
result_code = UpdateModel(update_model_req, fbb);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
MS_LOG(ERROR) << "Updating model failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
if (!CountForUpdateModel(fbb, update_model_req)) {
|
||||
result_code = CountForUpdateModel(fbb, update_model_req);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
|
@ -102,21 +105,21 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::Message
|
|||
}
|
||||
}
|
||||
|
||||
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
std::string reason = "Current amount for updateModel is enough. Please retry later.";
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(WARNING) << reason;
|
||||
return true;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
return false;
|
||||
return ResultCode::kSuccess;
|
||||
}
|
||||
|
||||
bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
|
||||
const std::shared_ptr<FBBuilder> &fbb) {
|
||||
RETURN_IF_NULL(update_model_req, false);
|
||||
ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
|
||||
const std::shared_ptr<FBBuilder> &fbb) {
|
||||
RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn);
|
||||
size_t iteration = static_cast<size_t>(update_model_req->iteration());
|
||||
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
|
||||
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
|
||||
|
@ -126,7 +129,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
|
|||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(WARNING) << reason;
|
||||
return true;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
|
||||
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
|
||||
|
@ -139,7 +142,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
|
|||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
|
||||
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
|
||||
|
@ -150,7 +153,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
|
|||
fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
|
||||
for (auto weight : feature_map) {
|
||||
|
@ -163,18 +166,19 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
|
|||
fl_id.set_fl_id(update_model_fl_id);
|
||||
PBMetadata comm_value;
|
||||
*comm_value.mutable_fl_id() = fl_id;
|
||||
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) {
|
||||
std::string reason = "Updating metadata of UpdateModelClientList failed.";
|
||||
std::string update_reason = "";
|
||||
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value, &update_reason)) {
|
||||
std::string reason = "Updating metadata of UpdateModelClientList failed. " + update_reason;
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
return update_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
|
||||
}
|
||||
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
return true;
|
||||
return ResultCode::kSuccess;
|
||||
}
|
||||
|
||||
std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
|
||||
|
@ -195,18 +199,19 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
|
|||
return feature_map;
|
||||
}
|
||||
|
||||
bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestUpdateModel *update_model_req) {
|
||||
RETURN_IF_NULL(update_model_req, false);
|
||||
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
|
||||
std::string reason = "Counting for update model request failed. Please retry later.";
|
||||
ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestUpdateModel *update_model_req) {
|
||||
RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn);
|
||||
std::string count_reason = "";
|
||||
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str(), &count_reason)) {
|
||||
std::string reason = "Counting for update model request failed. Please retry later. " + count_reason;
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
return count_reason == kNetworkError ? ResultCode::kSuccessAndReturn : ResultCode::kFail;
|
||||
}
|
||||
return true;
|
||||
return ResultCode::kSuccess;
|
||||
}
|
||||
|
||||
void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
|
|
|
@ -47,10 +47,11 @@ class UpdateModelKernel : public RoundKernel {
|
|||
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
|
||||
private:
|
||||
bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
|
||||
bool UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb);
|
||||
ResultCode ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
|
||||
ResultCode UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb);
|
||||
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
|
||||
bool CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestUpdateModel *update_model_req);
|
||||
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestUpdateModel *update_model_req);
|
||||
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const std::string &next_req_time);
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ void ModelStore::Initialize(uint32_t max_count) {
|
|||
}
|
||||
|
||||
bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
if (iteration_to_model_.count(iteration) != 0) {
|
||||
MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
|
||||
return false;
|
||||
|
@ -88,6 +89,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
|
|||
}
|
||||
|
||||
std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
std::map<std::string, AddressPtr> model = {};
|
||||
if (iteration_to_model_.count(iteration) == 0) {
|
||||
MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored.";
|
||||
|
@ -98,13 +100,15 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
|
|||
}
|
||||
|
||||
void ModelStore::Reset() {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
initial_model_ = iteration_to_model_.rbegin()->second;
|
||||
iteration_to_model_.clear();
|
||||
iteration_to_model_[kInitIterationNum] = initial_model_;
|
||||
iteration_to_model_[kResetInitIterNum] = initial_model_;
|
||||
}
|
||||
|
||||
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
|
||||
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
return iteration_to_model_;
|
||||
}
|
||||
|
||||
|
@ -142,6 +146,7 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
|
|||
}
|
||||
|
||||
size_t ModelStore::ComputeModelSize() {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
if (iteration_to_model_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. ";
|
||||
return 0;
|
||||
|
|
|
@ -56,7 +56,7 @@ class ModelStore {
|
|||
void Reset();
|
||||
|
||||
// Returns all models stored in ModelStore.
|
||||
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
|
||||
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model();
|
||||
|
||||
// Returns the model size, which could be calculated at the initializing phase.
|
||||
size_t model_size() const;
|
||||
|
@ -80,7 +80,8 @@ class ModelStore {
|
|||
// Initial model which is the model of iteration 0.
|
||||
std::shared_ptr<MemoryRegister> initial_model_;
|
||||
|
||||
// The number of all models stpred is max_model_count_.
|
||||
// The number of all models stored is max_model_count_.
|
||||
std::mutex model_mtx_;
|
||||
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
|
||||
};
|
||||
} // namespace server
|
||||
|
|
|
@ -48,7 +48,7 @@ parser.add_argument("--sts_properties_path", type=str, default="")
|
|||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
parser.add_argument("--encrypt_type", type=str, default="NotEncrypt")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
|
|
@ -47,7 +47,7 @@ parser.add_argument("--sts_properties_path", type=str, default="")
|
|||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
parser.add_argument("--encrypt_type", type=str, default="NotEncrypt")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue