!29057 fix ISSUE I4QCJM

Merge pull request !29057 from tan-wei-cheng-3260/develop-fix
This commit is contained in:
i-robot 2022-01-18 01:12:26 +00:00 committed by Gitee
commit 6eae53eb34
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
20 changed files with 276 additions and 60 deletions

View File

@ -86,6 +86,13 @@ enum class InstanceState {
kFinish
};
enum class IterationResult {
// The iteration is timeout because of startfljob or updatemodel timeout.
kTimeout,
// The iteration is successful aggregation.
kSuccess
};
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
@ -133,6 +140,15 @@ constexpr auto kFtrlLinear = "linear";
constexpr auto kDataSize = "data_size";
constexpr auto kNewDataSize = "new_data_size";
constexpr auto kStat = "stat";
constexpr auto kStartFLJobTotalClientNum = "startFLJobTotalClientNum";
constexpr auto kStartFLJobAcceptClientNum = "startFLJobAcceptClientNum";
constexpr auto kStartFLJobRejectClientNum = "startFLJobRejectClientNum";
constexpr auto kUpdateModelTotalClientNum = "updateModelTotalClientNum";
constexpr auto kUpdateModelAcceptClientNum = "updateModelAcceptClientNum";
constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum";
constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum";
constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum";
constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum";
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
// launched.

View File

@ -162,6 +162,8 @@ void Iteration::SetIterationRunning() {
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
iteration_state_ = IterationState::kRunning;
start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
MS_LOG(INFO) << "Iteratoin " << iteration_num_ << " start global timer.";
global_iter_timer_->Start(std::chrono::milliseconds(global_iteration_time_window_));
}
void Iteration::SetIterationEnd() {
@ -577,6 +579,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
// Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
iteration_result_ = IterationResult::kSuccess;
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else {
// Store last iteration's model because this iteration is considered as invalid.
@ -584,6 +587,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
size_t latest_iter_num = iter_to_model.rbegin()->first;
const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
iteration_result_ = IterationResult::kTimeout;
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
}
@ -591,6 +595,13 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
MS_ERROR_IF_NULL_WO_RET_VAL(round);
round->Reset();
}
MS_LOG(INFO) << "Iteratoin " << iteration_num_ << " stop global timer.";
global_iter_timer_->Stop();
for (const auto &round : rounds_) {
MS_ERROR_IF_NULL_WO_RET_VAL(round);
round->KernelSummarize();
}
}
bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
@ -598,11 +609,14 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
MS_LOG(INFO) << "Notify all follower servers to end last iteration.";
EndLastIterRequest end_last_iter_req;
end_last_iter_req.set_last_iter_num(last_iter_num);
std::shared_ptr<std::vector<unsigned char>> client_info_rsp_msg = nullptr;
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter,
&client_info_rsp_msg)) {
MS_LOG(WARNING) << "Sending ending last iteration request to server " << i << " failed.";
continue;
}
UpdateRoundClientNumMap(client_info_rsp_msg);
}
EndLastIter();
@ -629,10 +643,29 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
return;
}
EndLastIter();
EndLastIterResponse end_last_iter_rsp;
end_last_iter_rsp.set_result("success");
for (const auto &round : rounds_) {
if (round == nullptr) {
continue;
}
if (round->name() == "startFLJob") {
end_last_iter_rsp.set_startfljob_total_client_num(round->kernel_total_client_num());
end_last_iter_rsp.set_startfljob_accept_client_num(round->kernel_accept_client_num());
end_last_iter_rsp.set_startfljob_reject_client_num(round->kernel_reject_client_num());
} else if (round->name() == "updateModel") {
end_last_iter_rsp.set_updatemodel_total_client_num(round->kernel_total_client_num());
end_last_iter_rsp.set_updatemodel_accept_client_num(round->kernel_accept_client_num());
end_last_iter_rsp.set_updatemodel_reject_client_num(round->kernel_reject_client_num());
} else if (round->name() == "getModel") {
end_last_iter_rsp.set_getmodel_total_client_num(round->kernel_total_client_num());
end_last_iter_rsp.set_getmodel_accept_client_num(round->kernel_accept_client_num());
end_last_iter_rsp.set_getmodel_reject_client_num(round->kernel_reject_client_num());
}
}
EndLastIter();
if (!communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
@ -668,7 +701,10 @@ void Iteration::EndLastIter() {
MS_LOG(WARNING) << "Can't save current iteration number into persistent storage.";
}
}
for (const auto &round : rounds_) {
round->InitkernelClientVisitedNum();
}
round_client_num_map_.clear();
Server::GetInstance().CancelSafeMode();
iteration_state_cv_.notify_all();
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
@ -692,12 +728,8 @@ bool Iteration::SummarizeIteration() {
metrics_->set_instance_state(instance_state_.load());
metrics_->set_loss(loss_);
metrics_->set_accuracy(accuracy_);
// The joined client number is equal to the threshold of updateModel.
size_t update_model_threshold = static_cast<size_t>(
std::ceil(ps::PSContext::instance()->start_fl_job_threshold() * ps::PSContext::instance()->update_model_ratio()));
metrics_->set_joined_client_num(update_model_threshold);
// The rejected client number is equal to threshold of startFLJob minus threshold of updateModel.
metrics_->set_rejected_client_num(ps::PSContext::instance()->start_fl_job_threshold() - update_model_threshold);
metrics_->set_round_client_num_map(round_client_num_map_);
metrics_->set_iteration_result(iteration_result_.load());
if (complete_timestamp_ < start_timestamp_) {
MS_LOG(ERROR) << "The complete_timestamp_: " << complete_timestamp_ << ", start_timestamp_: " << start_timestamp_
@ -749,7 +781,21 @@ bool Iteration::UpdateHyperParams(const nlohmann::json &json) {
ps::PSContext::instance()->set_client_learning_rate(item.value().get<float>());
continue;
}
if (key == "global_iteration_time_window") {
ps::PSContext::instance()->set_global_iteration_time_window(item.value().get<uint64_t>());
continue;
}
}
MS_LOG(INFO) << "start_fl_job_threshold: " << ps::PSContext::instance()->start_fl_job_threshold();
MS_LOG(INFO) << "start_fl_job_time_window: " << ps::PSContext::instance()->start_fl_job_time_window();
MS_LOG(INFO) << "update_model_ratio: " << ps::PSContext::instance()->update_model_ratio();
MS_LOG(INFO) << "update_model_time_window: " << ps::PSContext::instance()->update_model_time_window();
MS_LOG(INFO) << "fl_iteration_num: " << ps::PSContext::instance()->fl_iteration_num();
MS_LOG(INFO) << "client_epoch_num: " << ps::PSContext::instance()->client_epoch_num();
MS_LOG(INFO) << "client_batch_size: " << ps::PSContext::instance()->client_batch_size();
MS_LOG(INFO) << "client_learning_rate: " << ps::PSContext::instance()->client_learning_rate();
MS_LOG(INFO) << "global_iteration_time_window: " << ps::PSContext::instance()->global_iteration_time_window();
return true;
}
@ -784,6 +830,36 @@ bool Iteration::ReInitRounds() {
}
return true;
}
// Create and configure the timer that bounds the duration of one whole global iteration.
// @param timeout_cb Callback invoked (with is_iteration_valid=false) when the window elapses.
void Iteration::InitGlobalIterTimer(const TimeOutCb &timeout_cb) {
  // Cache the configured window (ms) and allocate the timer object.
  global_iteration_time_window_ = ps::PSContext::instance()->global_iteration_time_window();
  global_iter_timer_ = std::make_shared<IterationTimer>();
  // On expiry, invalidate the current iteration through the shared timeout callback.
  auto on_expire = [this, timeout_cb](bool, const std::string &) -> void {
    std::string reason = "Global Iteration " + std::to_string(iteration_num_) +
                         " timeout! This iteration is invalid. Proceed to next iteration.";
    timeout_cb(false, reason);
  };
  global_iter_timer_->SetTimeOutCallBack(on_expire);
}
void Iteration::UpdateRoundClientNumMap(const std::shared_ptr<std::vector<unsigned char>> &client_info_rsp_msg) {
MS_ERROR_IF_NULL_WO_RET_VAL(client_info_rsp_msg);
EndLastIterResponse end_last_iter_rsp;
(void)end_last_iter_rsp.ParseFromArray(client_info_rsp_msg->data(), SizeToInt(client_info_rsp_msg->size()));
round_client_num_map_[kStartFLJobTotalClientNum] += end_last_iter_rsp.startfljob_total_client_num();
round_client_num_map_[kStartFLJobAcceptClientNum] += end_last_iter_rsp.startfljob_accept_client_num();
round_client_num_map_[kStartFLJobRejectClientNum] += end_last_iter_rsp.startfljob_reject_client_num();
round_client_num_map_[kUpdateModelTotalClientNum] += end_last_iter_rsp.updatemodel_total_client_num();
round_client_num_map_[kUpdateModelAcceptClientNum] += end_last_iter_rsp.updatemodel_accept_client_num();
round_client_num_map_[kUpdateModelRejectClientNum] += end_last_iter_rsp.updatemodel_reject_client_num();
round_client_num_map_[kGetModelTotalClientNum] += end_last_iter_rsp.getmodel_total_client_num();
round_client_num_map_[kGetModelAcceptClientNum] += end_last_iter_rsp.getmodel_accept_client_num();
round_client_num_map_[kGetModelRejectClientNum] += end_last_iter_rsp.getmodel_reject_client_num();
}
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -20,6 +20,7 @@
#include <memory>
#include <vector>
#include <string>
#include <map>
#include "ps/core/communicator/communicator_base.h"
#include "fl/server/common.h"
#include "fl/server/round.h"
@ -125,9 +126,15 @@ class Iteration {
// Synchronize server iteration after another server's recovery is completed.
bool SyncAfterRecovery(uint64_t iteration_num);
// Initialize global iteration timer.
void InitGlobalIterTimer(const TimeOutCb &timeout_cb);
// The round kernels whose Launch method has not returned yet.
std::atomic_uint32_t running_round_num_;
// Update count with client visited num in round
void UpdateRoundClientNumMap(const std::string &name, const size_t num);
private:
Iteration()
: running_round_num_(0),
@ -147,9 +154,18 @@ class Iteration {
is_instance_being_updated_(false),
loss_(0.0),
accuracy_(0.0),
joined_client_num_(0),
rejected_client_num_(0),
time_cost_(0) {
time_cost_(0),
global_iteration_time_window_(0),
round_client_num_map_({{kStartFLJobTotalClientNum, 0},
{kUpdateModelTotalClientNum, 0},
{kGetModelTotalClientNum, 0},
{kStartFLJobAcceptClientNum, 0},
{kUpdateModelAcceptClientNum, 0},
{kGetModelAcceptClientNum, 0},
{kStartFLJobRejectClientNum, 0},
{kUpdateModelRejectClientNum, 0},
{kGetModelRejectClientNum, 0}}),
iteration_result_(IterationResult::kSuccess) {
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
~Iteration();
@ -202,6 +218,8 @@ class Iteration {
// Reinitialize rounds and round kernels.
bool ReInitRounds();
void UpdateRoundClientNumMap(const std::shared_ptr<std::vector<unsigned char>> &client_info_rsp_msg);
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
@ -246,6 +264,7 @@ class Iteration {
// Every instance is not reentrant.
// This flag represents whether the instance is being updated.
std::mutex instance_mtx_;
bool is_instance_being_updated_;
// The training loss after this federated learning iteration, passed by worker.
@ -254,14 +273,20 @@ class Iteration {
// The evaluation result after this federated learning iteration, passed by worker.
float accuracy_;
// The number of clients which join the federated aggregation.
size_t joined_client_num_;
// The number of clients which are not involved in federated aggregation.
size_t rejected_client_num_;
// The time cost in millisecond for this completed iteration.
uint64_t time_cost_;
// global iteration time window
uint64_t global_iteration_time_window_;
// for example: "startFLJobTotalClientNum" -> startFLJob total client num
std::map<std::string, size_t> round_client_num_map_;
// Iteration global timer.
std::shared_ptr<IterationTimer> global_iter_timer_;
// The result for current iteration result.
std::atomic<IterationResult> iteration_result_;
};
} // namespace server
} // namespace fl

View File

@ -79,11 +79,12 @@ bool IterationMetrics::Summarize() {
js_[kInstanceStatus] = kInstanceStateName.at(instance_state_);
js_[kFLIterationNum] = fl_iteration_num_;
js_[kCurIteration] = cur_iteration_num_;
js_[kJoinedClientNum] = joined_client_num_;
js_[kRejectedClientNum] = rejected_client_num_;
js_[kMetricsAuc] = accuracy_;
js_[kMetricsLoss] = loss_;
js_[kIterExecutionTime] = iteration_time_cost_;
js_[kClientVisitedInfo] = round_client_num_map_;
js_[kIterationResult] = kIterationResultName.at(iteration_result_);
metrics_file_ << js_ << "\n";
(void)metrics_file_.flush();
metrics_file_.close();
@ -111,15 +112,15 @@ void IterationMetrics::set_loss(float loss) { loss_ = loss; }
// Record the evaluation accuracy reported by the worker for this iteration.
void IterationMetrics::set_accuracy(float acc) { accuracy_ = acc; }
// Record how many clients joined the federated aggregation this iteration.
void IterationMetrics::set_joined_client_num(size_t joined_client_num) { joined_client_num_ = joined_client_num; }
// Record how many clients were rejected from federated aggregation this iteration.
void IterationMetrics::set_rejected_client_num(size_t rejected_client_num) {
rejected_client_num_ = rejected_client_num;
}
// Record the wall-clock cost (ms) of the completed iteration.
void IterationMetrics::set_iteration_time_cost(uint64_t iteration_time_cost) {
iteration_time_cost_ = iteration_time_cost;
}
// Record the per-round client visit counters (key e.g. "startFLJobTotalClientNum").
// NOTE(review): the parameter is a const *value*, so the whole map is copied at the call;
// passing by const reference would avoid the copy, but requires a matching header change.
void IterationMetrics::set_round_client_num_map(const std::map<std::string, size_t> round_client_num_map) {
round_client_num_map_ = round_client_num_map;
}
// Record whether the iteration ended by successful aggregation or by timeout.
void IterationMetrics::set_iteration_result(IterationResult iteration_result) { iteration_result_ = iteration_result; }
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -34,16 +34,19 @@ constexpr auto kFLName = "flName";
constexpr auto kInstanceStatus = "instanceStatus";
constexpr auto kFLIterationNum = "flIterationNum";
constexpr auto kCurIteration = "currentIteration";
constexpr auto kJoinedClientNum = "joinedClientNum";
constexpr auto kRejectedClientNum = "rejectedClientNum";
constexpr auto kMetricsAuc = "metricsAuc";
constexpr auto kMetricsLoss = "metricsLoss";
constexpr auto kIterExecutionTime = "iterationExecutionTime";
constexpr auto kMetrics = "metrics";
constexpr auto kClientVisitedInfo = "clientVisitedInfo";
constexpr auto kIterationResult = "iterationResult";
const std::map<InstanceState, std::string> kInstanceStateName = {
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
const std::map<IterationResult, std::string> kIterationResultName = {{IterationResult::kSuccess, "success"},
{IterationResult::kTimeout, "timeout"}};
class IterationMetrics {
public:
explicit IterationMetrics(const std::string &config_file)
@ -55,9 +58,8 @@ class IterationMetrics {
instance_state_(InstanceState::kFinish),
loss_(0.0),
accuracy_(0.0),
joined_client_num_(0),
rejected_client_num_(0),
iteration_time_cost_(0) {}
iteration_time_cost_(0),
iteration_result_(IterationResult::kSuccess) {}
~IterationMetrics() = default;
bool Initialize();
@ -75,9 +77,9 @@ class IterationMetrics {
void set_instance_state(InstanceState state);
void set_loss(float loss);
void set_accuracy(float acc);
void set_joined_client_num(size_t joined_client_num);
void set_rejected_client_num(size_t rejected_client_num);
void set_iteration_time_cost(uint64_t iteration_time_cost);
void set_round_client_num_map(const std::map<std::string, size_t> round_client_num_map);
void set_iteration_result(IterationResult iteration_result);
private:
// This is the main config file set by ps context.
@ -112,14 +114,14 @@ class IterationMetrics {
// The evaluation result after this federated learning iteration, passed by worker.
float accuracy_;
// The number of clients which join the federated aggregation.
size_t joined_client_num_;
// The number of clients which are not involved in federated aggregation.
size_t rejected_client_num_;
// for example: "startFLJobTotalClientNum" -> startFLJob total client num
std::map<std::string, size_t> round_client_num_map_;
// The time cost in millisecond for this completed iteration.
uint64_t iteration_time_cost_;
// Current iteration running result.
IterationResult iteration_result_;
};
} // namespace server
} // namespace fl

View File

@ -30,7 +30,7 @@ void GetModelKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
InitClientVisitedNum();
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
@ -119,7 +119,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
} else {
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
}
IncreaseAcceptClientNum();
MS_LOG(INFO) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),

View File

@ -125,8 +125,32 @@ void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const v
std::unique_lock<std::mutex> lock(heap_data_mtx_);
(void)heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
IncreaseTotalClientNum();
return;
}
void RoundKernel::IncreaseTotalClientNum() { total_client_num_ += 1; }
void RoundKernel::IncreaseAcceptClientNum() { accept_client_num_ += 1; }
// Log the client visit statistics of this kernel at the end of an iteration.
// Only the three client-facing rounds report statistics.
void RoundKernel::Summarize() {
  if (name_ != "startFLJob" && name_ != "updateModel" && name_ != "getModel") {
    return;
  }
  // Load each atomic counter once so the logged numbers are mutually consistent.
  const size_t total = total_client_num_;
  const size_t accept = accept_client_num_;
  MS_LOG(INFO) << "Round kernel " << name_ << " total client num is: " << total
               << ", accept client num is: " << accept << ", reject client num is: " << (total - accept);
}
// Total number of clients that visited this round kernel in the current iteration.
size_t RoundKernel::total_client_num() const { return total_client_num_; }
// Number of visiting clients whose request was accepted.
size_t RoundKernel::accept_client_num() const { return accept_client_num_; }
// Rejected count is derived: visited clients that were not accepted.
size_t RoundKernel::reject_client_num() const { return total_client_num_ - accept_client_num_; }
// Reset both per-iteration visit counters so the next iteration starts from zero.
void RoundKernel::InitClientVisitedNum() {
  accept_client_num_ = 0;
  total_client_num_ = 0;
}
} // namespace kernel
} // namespace server
} // namespace fl

View File

@ -89,6 +89,20 @@ class RoundKernel : virtual public CPUKernel {
void set_stop_timer_cb(const StopTimerCb &timer_stopper);
void set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb);
void Summarize();
void IncreaseTotalClientNum();
void IncreaseAcceptClientNum();
size_t total_client_num() const;
size_t accept_client_num() const;
size_t reject_client_num() const;
void InitClientVisitedNum();
protected:
// Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent
// back to worker.
@ -121,6 +135,9 @@ class RoundKernel : virtual public CPUKernel {
std::queue<AddressPtr> heap_data_to_release_;
std::mutex heap_data_mtx_;
std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
std::atomic<size_t> total_client_num_;
std::atomic<size_t> accept_client_num_;
};
} // namespace kernel
} // namespace server

View File

@ -36,14 +36,13 @@ void StartFLJobKernel::InitKernel(size_t) {
}
iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
InitClientVisitedNum();
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
PBMetadata devices_metas;
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas);
@ -132,6 +131,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
IncreaseAcceptClientNum();
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -29,7 +29,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
InitClientVisitedNum();
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
@ -121,6 +121,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
IncreaseAcceptClientNum();
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -220,14 +220,27 @@ bool Round::IsServerAvailable(std::string *reason) {
return false;
}
// If the server is still in the process of scaling, reject the request.
// If the server is still in safemode, reject the request.
if (Server::GetInstance().IsSafeMode()) {
MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later.";
*reason = ps::kClusterSafeMode;
return false;
}
return true;
}
// Ask this round's kernel to log its client visit statistics.
// No-op (with error log) if the kernel has not been initialized.
void Round::KernelSummarize() {
  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
  // Summarize() returns void, so the former (void) cast of its result was meaningless.
  kernel_->Summarize();
}
// Forward the kernel's per-iteration visit counters.
// NOTE(review): unlike KernelSummarize(), these dereference kernel_ without a null check —
// presumably kernel_ is always set before they are called; verify against callers.
size_t Round::kernel_total_client_num() const { return kernel_->total_client_num(); }
size_t Round::kernel_accept_client_num() const { return kernel_->accept_client_num(); }
size_t Round::kernel_reject_client_num() const { return kernel_->reject_client_num(); }
// Reset the kernel's per-iteration visit counters.
// NOTE(review): name breaks the CamelCase convention ("Initkernel..."); consider
// InitKernelClientVisitedNum in a follow-up (requires header + caller updates).
void Round::InitkernelClientVisitedNum() { kernel_->InitClientVisitedNum(); }
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -56,11 +56,24 @@ class Round {
// Round needs to be reset after each iteration is finished or its timer expires.
void Reset();
void KernelSummarize();
const std::string &name() const;
size_t threshold_count() const;
bool check_timeout() const;
size_t time_window() const;
size_t kernel_total_client_num() const;
size_t kernel_accept_client_num() const;
size_t kernel_reject_client_num() const;
void InitkernelClientVisitedNum();
private:
// The callbacks which will be set to DistributedCounterService.
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);

View File

@ -242,6 +242,8 @@ void Server::InitIteration() {
FinishIterCb finish_iter_cb =
std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
iteration_->InitGlobalIterTimer(time_out_cb);
return;
}

View File

@ -475,8 +475,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_encrypt_type", &PSContext::set_encrypt_type,
"Set encrypt type for federated learning secure aggregation.")
.def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.")
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.");
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.")
.def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window,
"Set global iteration time window.")
.def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.");
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");

View File

@ -237,6 +237,15 @@ message EndLastIterRequest {
// Response to EndLastIterRequest: carries the responding server's per-round client
// visit counters so the leader can aggregate cluster-wide statistics.
message EndLastIterResponse {
// "success" when the last iteration was ended on this server.
string result = 1;
// startFLJob round counters: visited / accepted / rejected clients.
uint64 startFLJob_total_client_num = 2;
uint64 startFLJob_accept_client_num = 3;
uint64 startFLJob_reject_client_num = 4;
// updateModel round counters: visited / accepted / rejected clients.
uint64 updateModel_total_client_num = 5;
uint64 updateModel_accept_client_num = 6;
uint64 updateModel_reject_client_num = 7;
// getModel round counters: visited / accepted / rejected clients.
uint64 getModel_total_client_num = 8;
uint64 getModel_accept_client_num = 9;
uint64 getModel_reject_client_num = 10;
}
message SyncAfterRecover {

View File

@ -476,5 +476,11 @@ void PSContext::set_server_password(const std::string &password) { server_passwo
std::string PSContext::http_url_prefix() const { return http_url_prefix_; }
void PSContext::set_http_url_prefix(const std::string &http_url_prefix) { http_url_prefix_ = http_url_prefix; }
// Set the time window (ms) allowed for one whole global iteration.
// NOTE(review): a scalar uint64_t is cheaper passed by value than by const reference;
// changing this requires a matching update of the header declaration.
void PSContext::set_global_iteration_time_window(const uint64_t &global_iteration_time_window) {
global_iteration_time_window_ = global_iteration_time_window;
}
// Time window (ms) configured for one global iteration.
uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; }
} // namespace ps
} // namespace mindspore

View File

@ -211,6 +211,9 @@ class PSContext {
std::string http_url_prefix() const;
void set_http_url_prefix(const std::string &http_url_prefix);
void set_global_iteration_time_window(const uint64_t &global_iteration_time_window);
uint64_t global_iteration_time_window() const;
private:
PSContext()
: ps_enabled_(false),
@ -229,9 +232,9 @@ class PSContext {
fl_client_enable_(false),
fl_name_(""),
start_fl_job_threshold_(0),
start_fl_job_time_window_(3000),
start_fl_job_time_window_(300000),
update_model_ratio_(1.0),
update_model_time_window_(3000),
update_model_time_window_(300000),
share_secrets_ratio_(1.0),
cipher_time_window_(300000),
reconstruct_secrets_threshold_(2000),
@ -257,7 +260,8 @@ class PSContext {
enable_ssl_(false),
client_password_(""),
server_password_(""),
http_url_prefix_("") {}
http_url_prefix_(""),
global_iteration_time_window_(21600000) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@ -372,6 +376,9 @@ class PSContext {
std::string server_password_;
// http url prefix for http communication
std::string http_url_prefix_;
// The time window of one whole global iteration in millisecond.
uint64_t global_iteration_time_window_;
};
} // namespace ps
} // namespace mindspore

View File

@ -1101,7 +1101,10 @@ def set_fl_context(**kwargs):
pki_verify is True. Default: "".
replay_attack_time_diff (int): The maximum tolerable error of certificate timestamp verification (ms).
Default: 600000.
http_url_prefix (string): The http url prefix for http server.
Default: "".
global_iteration_time_window (unsigned long): The global iteration time window for one iteration
with rounds(ms). Default: 21600000.
Raises:
ValueError: If input key is not the attribute in federated learning mode context.

View File

@ -71,7 +71,8 @@ _set_ps_context_func_map = {
"dp_delta": ps_context().set_dp_delta,
"dp_norm_clip": ps_context().set_dp_norm_clip,
"encrypt_type": ps_context().set_encrypt_type,
"http_url_prefix": ps_context().set_http_url_prefix
"http_url_prefix": ps_context().set_http_url_prefix,
"global_iteration_time_window": ps_context().set_global_iteration_time_window
}
_get_ps_context_func_map = {
@ -112,7 +113,8 @@ _get_ps_context_func_map = {
"server_password": ps_context().server_password,
"scheduler_manage_port": ps_context().scheduler_manage_port,
"config_file_path": ps_context().config_file_path,
"http_url_prefix": ps_context().http_url_prefix
"http_url_prefix": ps_context().http_url_prefix,
"global_iteration_time_window": ps_context().global_iteration_time_window
}
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",

View File

@ -178,11 +178,10 @@ def call_get_instance_detail():
return process_self_define_json(Status.FAILED.value, "error. metrics file is not existed.")
ans_json_obj = {}
joined_client_num_list = []
rejected_client_num_list = []
metrics_auc_list = []
metrics_loss_list = []
iteration_execution_time_list = []
client_visited_info_list = []
with open(metrics_file_path, 'r') as f:
metrics_list = f.readlines()
@ -193,8 +192,7 @@ def call_get_instance_detail():
for metrics in metrics_list:
json_obj = json.loads(metrics)
iteration_execution_time_list.append(json_obj['iterationExecutionTime'])
joined_client_num_list.append(json_obj['joinedClientNum'])
rejected_client_num_list.append(json_obj['rejectedClientNum'])
client_visited_info_list.append(json_obj['clientVisitedInfo'])
metrics_auc_list.append(json_obj['metricsAuc'])
metrics_loss_list.append(json_obj['metricsLoss'])
@ -210,8 +208,7 @@ def call_get_instance_detail():
ans_json_result['flName'] = last_metrics_obj['flName']
ans_json_result['instanceStatus'] = last_metrics_obj['instanceStatus']
ans_json_result['iterationExecutionTime'] = iteration_execution_time_list
ans_json_result['joinedClientNum'] = joined_client_num_list
ans_json_result['rejectedClientNum'] = rejected_client_num_list
ans_json_result['clientVisitedInfo'] = client_visited_info_list
ans_json_result['metricsAuc'] = metrics_auc_list
ans_json_result['metricsLoss'] = metrics_loss_list