!29057 fix ISSUE I4QCJM
Merge pull request !29057 from tan-wei-cheng-3260/develop-fix
This commit is contained in:
commit
6eae53eb34
|
@ -86,6 +86,13 @@ enum class InstanceState {
|
|||
kFinish
|
||||
};
|
||||
|
||||
enum class IterationResult {
|
||||
// The iteration is timeout because of startfljob or updatemodel timeout.
|
||||
kTimeout,
|
||||
// The iteration is successful aggregation.
|
||||
kSuccess
|
||||
};
|
||||
|
||||
using mindspore::kernel::Address;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
using mindspore::kernel::CPUKernel;
|
||||
|
@ -133,6 +140,15 @@ constexpr auto kFtrlLinear = "linear";
|
|||
constexpr auto kDataSize = "data_size";
|
||||
constexpr auto kNewDataSize = "new_data_size";
|
||||
constexpr auto kStat = "stat";
|
||||
constexpr auto kStartFLJobTotalClientNum = "startFLJobTotalClientNum";
|
||||
constexpr auto kStartFLJobAcceptClientNum = "startFLJobAcceptClientNum";
|
||||
constexpr auto kStartFLJobRejectClientNum = "startFLJobRejectClientNum";
|
||||
constexpr auto kUpdateModelTotalClientNum = "updateModelTotalClientNum";
|
||||
constexpr auto kUpdateModelAcceptClientNum = "updateModelAcceptClientNum";
|
||||
constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum";
|
||||
constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum";
|
||||
constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum";
|
||||
constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum";
|
||||
|
||||
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
|
||||
// launched.
|
||||
|
|
|
@ -162,6 +162,8 @@ void Iteration::SetIterationRunning() {
|
|||
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
|
||||
iteration_state_ = IterationState::kRunning;
|
||||
start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
|
||||
MS_LOG(INFO) << "Iteratoin " << iteration_num_ << " start global timer.";
|
||||
global_iter_timer_->Start(std::chrono::milliseconds(global_iteration_time_window_));
|
||||
}
|
||||
|
||||
void Iteration::SetIterationEnd() {
|
||||
|
@ -577,6 +579,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
|||
// Store the model which is successfully aggregated for this iteration.
|
||||
const auto &model = Executor::GetInstance().GetModel();
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
iteration_result_ = IterationResult::kSuccess;
|
||||
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
|
||||
} else {
|
||||
// Store last iteration's model because this iteration is considered as invalid.
|
||||
|
@ -584,6 +587,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
|||
size_t latest_iter_num = iter_to_model.rbegin()->first;
|
||||
const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
iteration_result_ = IterationResult::kTimeout;
|
||||
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
|
||||
}
|
||||
|
||||
|
@ -591,6 +595,13 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
|||
MS_ERROR_IF_NULL_WO_RET_VAL(round);
|
||||
round->Reset();
|
||||
}
|
||||
MS_LOG(INFO) << "Iteratoin " << iteration_num_ << " stop global timer.";
|
||||
global_iter_timer_->Stop();
|
||||
|
||||
for (const auto &round : rounds_) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(round);
|
||||
round->KernelSummarize();
|
||||
}
|
||||
}
|
||||
|
||||
bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
|
||||
|
@ -598,11 +609,14 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
|
|||
MS_LOG(INFO) << "Notify all follower servers to end last iteration.";
|
||||
EndLastIterRequest end_last_iter_req;
|
||||
end_last_iter_req.set_last_iter_num(last_iter_num);
|
||||
std::shared_ptr<std::vector<unsigned char>> client_info_rsp_msg = nullptr;
|
||||
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
|
||||
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
|
||||
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter,
|
||||
&client_info_rsp_msg)) {
|
||||
MS_LOG(WARNING) << "Sending ending last iteration request to server " << i << " failed.";
|
||||
continue;
|
||||
}
|
||||
UpdateRoundClientNumMap(client_info_rsp_msg);
|
||||
}
|
||||
|
||||
EndLastIter();
|
||||
|
@ -629,10 +643,29 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
|
|||
return;
|
||||
}
|
||||
|
||||
EndLastIter();
|
||||
|
||||
EndLastIterResponse end_last_iter_rsp;
|
||||
end_last_iter_rsp.set_result("success");
|
||||
|
||||
for (const auto &round : rounds_) {
|
||||
if (round == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (round->name() == "startFLJob") {
|
||||
end_last_iter_rsp.set_startfljob_total_client_num(round->kernel_total_client_num());
|
||||
end_last_iter_rsp.set_startfljob_accept_client_num(round->kernel_accept_client_num());
|
||||
end_last_iter_rsp.set_startfljob_reject_client_num(round->kernel_reject_client_num());
|
||||
} else if (round->name() == "updateModel") {
|
||||
end_last_iter_rsp.set_updatemodel_total_client_num(round->kernel_total_client_num());
|
||||
end_last_iter_rsp.set_updatemodel_accept_client_num(round->kernel_accept_client_num());
|
||||
end_last_iter_rsp.set_updatemodel_reject_client_num(round->kernel_reject_client_num());
|
||||
} else if (round->name() == "getModel") {
|
||||
end_last_iter_rsp.set_getmodel_total_client_num(round->kernel_total_client_num());
|
||||
end_last_iter_rsp.set_getmodel_accept_client_num(round->kernel_accept_client_num());
|
||||
end_last_iter_rsp.set_getmodel_reject_client_num(round->kernel_reject_client_num());
|
||||
}
|
||||
}
|
||||
|
||||
EndLastIter();
|
||||
if (!communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
|
||||
end_last_iter_rsp.SerializeAsString().size(), message)) {
|
||||
MS_LOG(ERROR) << "Sending response failed.";
|
||||
|
@ -668,7 +701,10 @@ void Iteration::EndLastIter() {
|
|||
MS_LOG(WARNING) << "Can't save current iteration number into persistent storage.";
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &round : rounds_) {
|
||||
round->InitkernelClientVisitedNum();
|
||||
}
|
||||
round_client_num_map_.clear();
|
||||
Server::GetInstance().CancelSafeMode();
|
||||
iteration_state_cv_.notify_all();
|
||||
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
||||
|
@ -692,12 +728,8 @@ bool Iteration::SummarizeIteration() {
|
|||
metrics_->set_instance_state(instance_state_.load());
|
||||
metrics_->set_loss(loss_);
|
||||
metrics_->set_accuracy(accuracy_);
|
||||
// The joined client number is equal to the threshold of updateModel.
|
||||
size_t update_model_threshold = static_cast<size_t>(
|
||||
std::ceil(ps::PSContext::instance()->start_fl_job_threshold() * ps::PSContext::instance()->update_model_ratio()));
|
||||
metrics_->set_joined_client_num(update_model_threshold);
|
||||
// The rejected client number is equal to threshold of startFLJob minus threshold of updateModel.
|
||||
metrics_->set_rejected_client_num(ps::PSContext::instance()->start_fl_job_threshold() - update_model_threshold);
|
||||
metrics_->set_round_client_num_map(round_client_num_map_);
|
||||
metrics_->set_iteration_result(iteration_result_.load());
|
||||
|
||||
if (complete_timestamp_ < start_timestamp_) {
|
||||
MS_LOG(ERROR) << "The complete_timestamp_: " << complete_timestamp_ << ", start_timestamp_: " << start_timestamp_
|
||||
|
@ -749,7 +781,21 @@ bool Iteration::UpdateHyperParams(const nlohmann::json &json) {
|
|||
ps::PSContext::instance()->set_client_learning_rate(item.value().get<float>());
|
||||
continue;
|
||||
}
|
||||
if (key == "global_iteration_time_window") {
|
||||
ps::PSContext::instance()->set_global_iteration_time_window(item.value().get<uint64_t>());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "start_fl_job_threshold: " << ps::PSContext::instance()->start_fl_job_threshold();
|
||||
MS_LOG(INFO) << "start_fl_job_time_window: " << ps::PSContext::instance()->start_fl_job_time_window();
|
||||
MS_LOG(INFO) << "update_model_ratio: " << ps::PSContext::instance()->update_model_ratio();
|
||||
MS_LOG(INFO) << "update_model_time_window: " << ps::PSContext::instance()->update_model_time_window();
|
||||
MS_LOG(INFO) << "fl_iteration_num: " << ps::PSContext::instance()->fl_iteration_num();
|
||||
MS_LOG(INFO) << "client_epoch_num: " << ps::PSContext::instance()->client_epoch_num();
|
||||
MS_LOG(INFO) << "client_batch_size: " << ps::PSContext::instance()->client_batch_size();
|
||||
MS_LOG(INFO) << "client_learning_rate: " << ps::PSContext::instance()->client_learning_rate();
|
||||
MS_LOG(INFO) << "global_iteration_time_window: " << ps::PSContext::instance()->global_iteration_time_window();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -784,6 +830,36 @@ bool Iteration::ReInitRounds() {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Iteration::InitGlobalIterTimer(const TimeOutCb &timeout_cb) {
|
||||
global_iteration_time_window_ = ps::PSContext::instance()->global_iteration_time_window();
|
||||
global_iter_timer_ = std::make_shared<IterationTimer>();
|
||||
|
||||
// Set the timeout callback for the timer.
|
||||
global_iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool, const std::string &) -> void {
|
||||
std::string reason = "Global Iteration " + std::to_string(iteration_num_) +
|
||||
" timeout! This iteration is invalid. Proceed to next iteration.";
|
||||
timeout_cb(false, reason);
|
||||
});
|
||||
}
|
||||
|
||||
void Iteration::UpdateRoundClientNumMap(const std::shared_ptr<std::vector<unsigned char>> &client_info_rsp_msg) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(client_info_rsp_msg);
|
||||
EndLastIterResponse end_last_iter_rsp;
|
||||
(void)end_last_iter_rsp.ParseFromArray(client_info_rsp_msg->data(), SizeToInt(client_info_rsp_msg->size()));
|
||||
|
||||
round_client_num_map_[kStartFLJobTotalClientNum] += end_last_iter_rsp.startfljob_total_client_num();
|
||||
round_client_num_map_[kStartFLJobAcceptClientNum] += end_last_iter_rsp.startfljob_accept_client_num();
|
||||
round_client_num_map_[kStartFLJobRejectClientNum] += end_last_iter_rsp.startfljob_reject_client_num();
|
||||
|
||||
round_client_num_map_[kUpdateModelTotalClientNum] += end_last_iter_rsp.updatemodel_total_client_num();
|
||||
round_client_num_map_[kUpdateModelAcceptClientNum] += end_last_iter_rsp.updatemodel_accept_client_num();
|
||||
round_client_num_map_[kUpdateModelRejectClientNum] += end_last_iter_rsp.updatemodel_reject_client_num();
|
||||
|
||||
round_client_num_map_[kGetModelTotalClientNum] += end_last_iter_rsp.getmodel_total_client_num();
|
||||
round_client_num_map_[kGetModelAcceptClientNum] += end_last_iter_rsp.getmodel_accept_client_num();
|
||||
round_client_num_map_[kGetModelRejectClientNum] += end_last_iter_rsp.getmodel_reject_client_num();
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/round.h"
|
||||
|
@ -125,9 +126,15 @@ class Iteration {
|
|||
// Synchronize server iteration after another server's recovery is completed.
|
||||
bool SyncAfterRecovery(uint64_t iteration_num);
|
||||
|
||||
// Initialize global iteration timer.
|
||||
void InitGlobalIterTimer(const TimeOutCb &timeout_cb);
|
||||
|
||||
// The round kernels whose Launch method has not returned yet.
|
||||
std::atomic_uint32_t running_round_num_;
|
||||
|
||||
// Update count with client visited num in round
|
||||
void UpdateRoundClientNumMap(const std::string &name, const size_t num);
|
||||
|
||||
private:
|
||||
Iteration()
|
||||
: running_round_num_(0),
|
||||
|
@ -147,9 +154,18 @@ class Iteration {
|
|||
is_instance_being_updated_(false),
|
||||
loss_(0.0),
|
||||
accuracy_(0.0),
|
||||
joined_client_num_(0),
|
||||
rejected_client_num_(0),
|
||||
time_cost_(0) {
|
||||
time_cost_(0),
|
||||
global_iteration_time_window_(0),
|
||||
round_client_num_map_({{kStartFLJobTotalClientNum, 0},
|
||||
{kUpdateModelTotalClientNum, 0},
|
||||
{kGetModelTotalClientNum, 0},
|
||||
{kStartFLJobAcceptClientNum, 0},
|
||||
{kUpdateModelAcceptClientNum, 0},
|
||||
{kGetModelAcceptClientNum, 0},
|
||||
{kStartFLJobRejectClientNum, 0},
|
||||
{kUpdateModelRejectClientNum, 0},
|
||||
{kGetModelRejectClientNum, 0}}),
|
||||
iteration_result_(IterationResult::kSuccess) {
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||
}
|
||||
~Iteration();
|
||||
|
@ -202,6 +218,8 @@ class Iteration {
|
|||
// Reinitialize rounds and round kernels.
|
||||
bool ReInitRounds();
|
||||
|
||||
void UpdateRoundClientNumMap(const std::shared_ptr<std::vector<unsigned char>> &client_info_rsp_msg);
|
||||
|
||||
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
|
||||
|
||||
|
@ -246,6 +264,7 @@ class Iteration {
|
|||
// Every instance is not reentrant.
|
||||
// This flag represents whether the instance is being updated.
|
||||
std::mutex instance_mtx_;
|
||||
|
||||
bool is_instance_being_updated_;
|
||||
|
||||
// The training loss after this federated learning iteration, passed by worker.
|
||||
|
@ -254,14 +273,20 @@ class Iteration {
|
|||
// The evaluation result after this federated learning iteration, passed by worker.
|
||||
float accuracy_;
|
||||
|
||||
// The number of clients which join the federated aggregation.
|
||||
size_t joined_client_num_;
|
||||
|
||||
// The number of clients which are not involved in federated aggregation.
|
||||
size_t rejected_client_num_;
|
||||
|
||||
// The time cost in millisecond for this completed iteration.
|
||||
uint64_t time_cost_;
|
||||
|
||||
// global iteration time window
|
||||
uint64_t global_iteration_time_window_;
|
||||
|
||||
// for example: "startFLJobTotalClientNum" -> startFLJob total client num
|
||||
std::map<std::string, size_t> round_client_num_map_;
|
||||
|
||||
// Iteration global timer.
|
||||
std::shared_ptr<IterationTimer> global_iter_timer_;
|
||||
|
||||
// The result for current iteration result.
|
||||
std::atomic<IterationResult> iteration_result_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -79,11 +79,12 @@ bool IterationMetrics::Summarize() {
|
|||
js_[kInstanceStatus] = kInstanceStateName.at(instance_state_);
|
||||
js_[kFLIterationNum] = fl_iteration_num_;
|
||||
js_[kCurIteration] = cur_iteration_num_;
|
||||
js_[kJoinedClientNum] = joined_client_num_;
|
||||
js_[kRejectedClientNum] = rejected_client_num_;
|
||||
js_[kMetricsAuc] = accuracy_;
|
||||
js_[kMetricsLoss] = loss_;
|
||||
js_[kIterExecutionTime] = iteration_time_cost_;
|
||||
js_[kClientVisitedInfo] = round_client_num_map_;
|
||||
js_[kIterationResult] = kIterationResultName.at(iteration_result_);
|
||||
|
||||
metrics_file_ << js_ << "\n";
|
||||
(void)metrics_file_.flush();
|
||||
metrics_file_.close();
|
||||
|
@ -111,15 +112,15 @@ void IterationMetrics::set_loss(float loss) { loss_ = loss; }
|
|||
|
||||
void IterationMetrics::set_accuracy(float acc) { accuracy_ = acc; }
|
||||
|
||||
void IterationMetrics::set_joined_client_num(size_t joined_client_num) { joined_client_num_ = joined_client_num; }
|
||||
|
||||
void IterationMetrics::set_rejected_client_num(size_t rejected_client_num) {
|
||||
rejected_client_num_ = rejected_client_num;
|
||||
}
|
||||
|
||||
void IterationMetrics::set_iteration_time_cost(uint64_t iteration_time_cost) {
|
||||
iteration_time_cost_ = iteration_time_cost;
|
||||
}
|
||||
|
||||
void IterationMetrics::set_round_client_num_map(const std::map<std::string, size_t> round_client_num_map) {
|
||||
round_client_num_map_ = round_client_num_map;
|
||||
}
|
||||
|
||||
void IterationMetrics::set_iteration_result(IterationResult iteration_result) { iteration_result_ = iteration_result; }
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,16 +34,19 @@ constexpr auto kFLName = "flName";
|
|||
constexpr auto kInstanceStatus = "instanceStatus";
|
||||
constexpr auto kFLIterationNum = "flIterationNum";
|
||||
constexpr auto kCurIteration = "currentIteration";
|
||||
constexpr auto kJoinedClientNum = "joinedClientNum";
|
||||
constexpr auto kRejectedClientNum = "rejectedClientNum";
|
||||
constexpr auto kMetricsAuc = "metricsAuc";
|
||||
constexpr auto kMetricsLoss = "metricsLoss";
|
||||
constexpr auto kIterExecutionTime = "iterationExecutionTime";
|
||||
constexpr auto kMetrics = "metrics";
|
||||
constexpr auto kClientVisitedInfo = "clientVisitedInfo";
|
||||
constexpr auto kIterationResult = "iterationResult";
|
||||
|
||||
const std::map<InstanceState, std::string> kInstanceStateName = {
|
||||
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
|
||||
|
||||
const std::map<IterationResult, std::string> kIterationResultName = {{IterationResult::kSuccess, "success"},
|
||||
{IterationResult::kTimeout, "timeout"}};
|
||||
|
||||
class IterationMetrics {
|
||||
public:
|
||||
explicit IterationMetrics(const std::string &config_file)
|
||||
|
@ -55,9 +58,8 @@ class IterationMetrics {
|
|||
instance_state_(InstanceState::kFinish),
|
||||
loss_(0.0),
|
||||
accuracy_(0.0),
|
||||
joined_client_num_(0),
|
||||
rejected_client_num_(0),
|
||||
iteration_time_cost_(0) {}
|
||||
iteration_time_cost_(0),
|
||||
iteration_result_(IterationResult::kSuccess) {}
|
||||
~IterationMetrics() = default;
|
||||
|
||||
bool Initialize();
|
||||
|
@ -75,9 +77,9 @@ class IterationMetrics {
|
|||
void set_instance_state(InstanceState state);
|
||||
void set_loss(float loss);
|
||||
void set_accuracy(float acc);
|
||||
void set_joined_client_num(size_t joined_client_num);
|
||||
void set_rejected_client_num(size_t rejected_client_num);
|
||||
void set_iteration_time_cost(uint64_t iteration_time_cost);
|
||||
void set_round_client_num_map(const std::map<std::string, size_t> round_client_num_map);
|
||||
void set_iteration_result(IterationResult iteration_result);
|
||||
|
||||
private:
|
||||
// This is the main config file set by ps context.
|
||||
|
@ -112,14 +114,14 @@ class IterationMetrics {
|
|||
// The evaluation result after this federated learning iteration, passed by worker.
|
||||
float accuracy_;
|
||||
|
||||
// The number of clients which join the federated aggregation.
|
||||
size_t joined_client_num_;
|
||||
|
||||
// The number of clients which are not involved in federated aggregation.
|
||||
size_t rejected_client_num_;
|
||||
// for example: "startFLJobTotalClientNum" -> startFLJob total client num
|
||||
std::map<std::string, size_t> round_client_num_map_;
|
||||
|
||||
// The time cost in millisecond for this completed iteration.
|
||||
uint64_t iteration_time_cost_;
|
||||
|
||||
// Current iteration running result.
|
||||
IterationResult iteration_result_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -30,7 +30,7 @@ void GetModelKernel::InitKernel(size_t) {
|
|||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
InitClientVisitedNum();
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
|
@ -119,7 +119,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
|
|||
} else {
|
||||
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
|
||||
}
|
||||
|
||||
IncreaseAcceptClientNum();
|
||||
MS_LOG(INFO) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
|
||||
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
|
||||
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
|
||||
|
|
|
@ -125,8 +125,32 @@ void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const v
|
|||
|
||||
std::unique_lock<std::mutex> lock(heap_data_mtx_);
|
||||
(void)heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
|
||||
IncreaseTotalClientNum();
|
||||
return;
|
||||
}
|
||||
|
||||
void RoundKernel::IncreaseTotalClientNum() { total_client_num_ += 1; }
|
||||
|
||||
void RoundKernel::IncreaseAcceptClientNum() { accept_client_num_ += 1; }
|
||||
|
||||
void RoundKernel::Summarize() {
|
||||
if (name_ == "startFLJob" || name_ == "updateModel" || name_ == "getModel") {
|
||||
MS_LOG(INFO) << "Round kernel " << name_ << " total client num is: " << total_client_num_
|
||||
<< ", accept client num is: " << accept_client_num_
|
||||
<< ", reject client num is: " << (total_client_num_ - accept_client_num_);
|
||||
}
|
||||
}
|
||||
|
||||
size_t RoundKernel::total_client_num() const { return total_client_num_; }
|
||||
|
||||
size_t RoundKernel::accept_client_num() const { return accept_client_num_; }
|
||||
|
||||
size_t RoundKernel::reject_client_num() const { return total_client_num_ - accept_client_num_; }
|
||||
|
||||
void RoundKernel::InitClientVisitedNum() {
|
||||
total_client_num_ = 0;
|
||||
accept_client_num_ = 0;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -89,6 +89,20 @@ class RoundKernel : virtual public CPUKernel {
|
|||
void set_stop_timer_cb(const StopTimerCb &timer_stopper);
|
||||
void set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb);
|
||||
|
||||
void Summarize();
|
||||
|
||||
void IncreaseTotalClientNum();
|
||||
|
||||
void IncreaseAcceptClientNum();
|
||||
|
||||
size_t total_client_num() const;
|
||||
|
||||
size_t accept_client_num() const;
|
||||
|
||||
size_t reject_client_num() const;
|
||||
|
||||
void InitClientVisitedNum();
|
||||
|
||||
protected:
|
||||
// Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent
|
||||
// back to worker.
|
||||
|
@ -121,6 +135,9 @@ class RoundKernel : virtual public CPUKernel {
|
|||
std::queue<AddressPtr> heap_data_to_release_;
|
||||
std::mutex heap_data_mtx_;
|
||||
std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
|
||||
|
||||
std::atomic<size_t> total_client_num_;
|
||||
std::atomic<size_t> accept_client_num_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -36,14 +36,13 @@ void StartFLJobKernel::InitKernel(size_t) {
|
|||
}
|
||||
iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
|
||||
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
|
||||
|
||||
InitClientVisitedNum();
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
|
||||
PBMetadata devices_metas;
|
||||
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas);
|
||||
|
||||
|
@ -132,6 +131,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
IncreaseAcceptClientNum();
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
|
|||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
InitClientVisitedNum();
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
|
@ -121,6 +121,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
IncreaseAcceptClientNum();
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -220,14 +220,27 @@ bool Round::IsServerAvailable(std::string *reason) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// If the server is still in the process of scaling, reject the request.
|
||||
// If the server is still in safemode, reject the request.
|
||||
if (Server::GetInstance().IsSafeMode()) {
|
||||
MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
|
||||
MS_LOG(WARNING) << "The cluster is still in safemode, please retry " << name_ << " later.";
|
||||
*reason = ps::kClusterSafeMode;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Round::KernelSummarize() {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
|
||||
(void)kernel_->Summarize();
|
||||
}
|
||||
|
||||
size_t Round::kernel_total_client_num() const { return kernel_->total_client_num(); }
|
||||
|
||||
size_t Round::kernel_accept_client_num() const { return kernel_->accept_client_num(); }
|
||||
|
||||
size_t Round::kernel_reject_client_num() const { return kernel_->reject_client_num(); }
|
||||
|
||||
void Round::InitkernelClientVisitedNum() { kernel_->InitClientVisitedNum(); }
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,11 +56,24 @@ class Round {
|
|||
// Round needs to be reset after each iteration is finished or its timer expires.
|
||||
void Reset();
|
||||
|
||||
void KernelSummarize();
|
||||
|
||||
const std::string &name() const;
|
||||
|
||||
size_t threshold_count() const;
|
||||
|
||||
bool check_timeout() const;
|
||||
|
||||
size_t time_window() const;
|
||||
|
||||
size_t kernel_total_client_num() const;
|
||||
|
||||
size_t kernel_accept_client_num() const;
|
||||
|
||||
size_t kernel_reject_client_num() const;
|
||||
|
||||
void InitkernelClientVisitedNum();
|
||||
|
||||
private:
|
||||
// The callbacks which will be set to DistributedCounterService.
|
||||
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
|
|
|
@ -242,6 +242,8 @@ void Server::InitIteration() {
|
|||
FinishIterCb finish_iter_cb =
|
||||
std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
|
||||
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
|
||||
|
||||
iteration_->InitGlobalIterTimer(time_out_cb);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -475,8 +475,10 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_encrypt_type", &PSContext::set_encrypt_type,
|
||||
"Set encrypt type for federated learning secure aggregation.")
|
||||
.def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.")
|
||||
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.");
|
||||
|
||||
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.")
|
||||
.def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window,
|
||||
"Set global iteration time window.")
|
||||
.def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.");
|
||||
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
|
||||
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
|
||||
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
|
||||
|
|
|
@ -237,6 +237,15 @@ message EndLastIterRequest {
|
|||
|
||||
message EndLastIterResponse {
|
||||
string result = 1;
|
||||
uint64 startFLJob_total_client_num = 2;
|
||||
uint64 startFLJob_accept_client_num = 3;
|
||||
uint64 startFLJob_reject_client_num = 4;
|
||||
uint64 updateModel_total_client_num = 5;
|
||||
uint64 updateModel_accept_client_num = 6;
|
||||
uint64 updateModel_reject_client_num = 7;
|
||||
uint64 getModel_total_client_num = 8;
|
||||
uint64 getModel_accept_client_num = 9;
|
||||
uint64 getModel_reject_client_num = 10;
|
||||
}
|
||||
|
||||
message SyncAfterRecover {
|
||||
|
|
|
@ -476,5 +476,11 @@ void PSContext::set_server_password(const std::string &password) { server_passwo
|
|||
std::string PSContext::http_url_prefix() const { return http_url_prefix_; }
|
||||
|
||||
void PSContext::set_http_url_prefix(const std::string &http_url_prefix) { http_url_prefix_ = http_url_prefix; }
|
||||
|
||||
void PSContext::set_global_iteration_time_window(const uint64_t &global_iteration_time_window) {
|
||||
global_iteration_time_window_ = global_iteration_time_window;
|
||||
}
|
||||
|
||||
uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; }
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -211,6 +211,9 @@ class PSContext {
|
|||
std::string http_url_prefix() const;
|
||||
void set_http_url_prefix(const std::string &http_url_prefix);
|
||||
|
||||
void set_global_iteration_time_window(const uint64_t &global_iteration_time_window);
|
||||
uint64_t global_iteration_time_window() const;
|
||||
|
||||
private:
|
||||
PSContext()
|
||||
: ps_enabled_(false),
|
||||
|
@ -229,9 +232,9 @@ class PSContext {
|
|||
fl_client_enable_(false),
|
||||
fl_name_(""),
|
||||
start_fl_job_threshold_(0),
|
||||
start_fl_job_time_window_(3000),
|
||||
start_fl_job_time_window_(300000),
|
||||
update_model_ratio_(1.0),
|
||||
update_model_time_window_(3000),
|
||||
update_model_time_window_(300000),
|
||||
share_secrets_ratio_(1.0),
|
||||
cipher_time_window_(300000),
|
||||
reconstruct_secrets_threshold_(2000),
|
||||
|
@ -257,7 +260,8 @@ class PSContext {
|
|||
enable_ssl_(false),
|
||||
client_password_(""),
|
||||
server_password_(""),
|
||||
http_url_prefix_("") {}
|
||||
http_url_prefix_(""),
|
||||
global_iteration_time_window_(21600000) {}
|
||||
bool ps_enabled_;
|
||||
bool is_worker_;
|
||||
bool is_pserver_;
|
||||
|
@ -372,6 +376,9 @@ class PSContext {
|
|||
std::string server_password_;
|
||||
// http url prefix for http communication
|
||||
std::string http_url_prefix_;
|
||||
|
||||
// The time window of startFLJob round in millisecond.
|
||||
uint64_t global_iteration_time_window_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1101,7 +1101,10 @@ def set_fl_context(**kwargs):
|
|||
pki_verify is True. Default: "".
|
||||
replay_attack_time_diff (int): The maximum tolerable error of certificate timestamp verification (ms).
|
||||
Default: 600000.
|
||||
|
||||
http_url_prefix (string): The http url prefix for http server.
|
||||
Default: "".
|
||||
global_iteration_time_window (unsigned long): The global iteration time window for one iteration
|
||||
with rounds(ms). Default: 21600000.
|
||||
Raises:
|
||||
ValueError: If input key is not the attribute in federated learning mode context.
|
||||
|
||||
|
|
|
@ -71,7 +71,8 @@ _set_ps_context_func_map = {
|
|||
"dp_delta": ps_context().set_dp_delta,
|
||||
"dp_norm_clip": ps_context().set_dp_norm_clip,
|
||||
"encrypt_type": ps_context().set_encrypt_type,
|
||||
"http_url_prefix": ps_context().set_http_url_prefix
|
||||
"http_url_prefix": ps_context().set_http_url_prefix,
|
||||
"global_iteration_time_window": ps_context().set_global_iteration_time_window
|
||||
}
|
||||
|
||||
_get_ps_context_func_map = {
|
||||
|
@ -112,7 +113,8 @@ _get_ps_context_func_map = {
|
|||
"server_password": ps_context().server_password,
|
||||
"scheduler_manage_port": ps_context().scheduler_manage_port,
|
||||
"config_file_path": ps_context().config_file_path,
|
||||
"http_url_prefix": ps_context().http_url_prefix
|
||||
"http_url_prefix": ps_context().http_url_prefix,
|
||||
"global_iteration_time_window": ps_context().global_iteration_time_window
|
||||
}
|
||||
|
||||
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
|
||||
|
|
|
@ -178,11 +178,10 @@ def call_get_instance_detail():
|
|||
return process_self_define_json(Status.FAILED.value, "error. metrics file is not existed.")
|
||||
|
||||
ans_json_obj = {}
|
||||
joined_client_num_list = []
|
||||
rejected_client_num_list = []
|
||||
metrics_auc_list = []
|
||||
metrics_loss_list = []
|
||||
iteration_execution_time_list = []
|
||||
client_visited_info_list = []
|
||||
|
||||
with open(metrics_file_path, 'r') as f:
|
||||
metrics_list = f.readlines()
|
||||
|
@ -193,8 +192,7 @@ def call_get_instance_detail():
|
|||
for metrics in metrics_list:
|
||||
json_obj = json.loads(metrics)
|
||||
iteration_execution_time_list.append(json_obj['iterationExecutionTime'])
|
||||
joined_client_num_list.append(json_obj['joinedClientNum'])
|
||||
rejected_client_num_list.append(json_obj['rejectedClientNum'])
|
||||
client_visited_info_list.append(json_obj['clientVisitedInfo'])
|
||||
metrics_auc_list.append(json_obj['metricsAuc'])
|
||||
metrics_loss_list.append(json_obj['metricsLoss'])
|
||||
|
||||
|
@ -210,8 +208,7 @@ def call_get_instance_detail():
|
|||
ans_json_result['flName'] = last_metrics_obj['flName']
|
||||
ans_json_result['instanceStatus'] = last_metrics_obj['instanceStatus']
|
||||
ans_json_result['iterationExecutionTime'] = iteration_execution_time_list
|
||||
ans_json_result['joinedClientNum'] = joined_client_num_list
|
||||
ans_json_result['rejectedClientNum'] = rejected_client_num_list
|
||||
ans_json_result['clientVisitedInfo'] = client_visited_info_list
|
||||
ans_json_result['metricsAuc'] = metrics_auc_list
|
||||
ans_json_result['metricsLoss'] = metrics_loss_list
|
||||
|
||||
|
|
Loading…
Reference in New Issue