forked from mindspore-Ecosystem/mindspore
!16230 Optimize server reliability in multiple scenarios.
From: @zpac Reviewed-by: @cristoval,@limingqi107 Signed-off-by: @limingqi107
This commit is contained in:
commit
94ca479fbe
|
@ -629,14 +629,17 @@ bool StartServerAction(const ResourcePtr &res) {
|
|||
uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();
|
||||
|
||||
// Update model threshold is a certain ratio of start_fl_job threshold.
|
||||
// update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_.
|
||||
// update_model_threshold = start_fl_job_threshold * update_model_ratio.
|
||||
size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
|
||||
float percent_for_update_model = 1;
|
||||
size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));
|
||||
float update_model_ratio = ps::PSContext::instance()->update_model_ratio();
|
||||
size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
|
||||
uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
|
||||
uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
|
||||
|
||||
std::vector<ps::server::RoundConfig> rounds_config = {{"startFLJob", false, 3000, true, start_fl_job_threshold},
|
||||
{"updateModel", false, 3000, true, update_model_threshold},
|
||||
{"getModel", false, 3000}};
|
||||
std::vector<ps::server::RoundConfig> rounds_config = {
|
||||
{"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
|
||||
{"updateModel", true, update_model_time_window, true, update_model_threshold},
|
||||
{"getModel"}};
|
||||
|
||||
size_t executor_threshold = 0;
|
||||
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
||||
|
|
|
@ -345,11 +345,20 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
|
||||
.def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
|
||||
.def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
|
||||
.def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.")
|
||||
.def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold,
|
||||
"Set threshold count for startFLJob round.")
|
||||
.def("set_start_fl_job_time_window", &PSContext::set_start_fl_job_time_window,
|
||||
"Set time window for startFLJob round.")
|
||||
.def("set_update_model_ratio", &PSContext::set_update_model_ratio,
|
||||
"Set threshold count ratio for updateModel round.")
|
||||
.def("set_update_model_time_window", &PSContext::set_update_model_time_window,
|
||||
"Set time window for updateModel round.")
|
||||
.def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
|
||||
.def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
|
||||
.def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
|
||||
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
|
||||
.def("set_client_learning_rate", &PSContext::set_client_learning_rate,
|
||||
"Set federated learning client learning rate.")
|
||||
.def("set_secure_aggregation", &PSContext::set_secure_aggregation,
|
||||
"Set federated learning client using secure aggregation.")
|
||||
.def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");
|
||||
|
|
|
@ -196,7 +196,14 @@ void PSContext::set_ms_role(const std::string &role) {
|
|||
role_ = role;
|
||||
}
|
||||
|
||||
void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
|
||||
void PSContext::set_worker_num(uint32_t worker_num) {
|
||||
// Hybrid training mode only supports one worker for now.
|
||||
if (server_mode_ == kServerModeHybrid && worker_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
|
||||
return;
|
||||
}
|
||||
worker_num_ = worker_num;
|
||||
}
|
||||
uint32_t PSContext::worker_num() const { return worker_num_; }
|
||||
|
||||
void PSContext::set_server_num(uint32_t server_num) {
|
||||
|
@ -235,7 +242,7 @@ void PSContext::GenerateResetterRound() {
|
|||
}
|
||||
|
||||
binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
|
||||
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_upload_weights_ << 4);
|
||||
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3);
|
||||
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
|
||||
resetter_round_ = ResetterRound::kNoNeedToReset;
|
||||
} else {
|
||||
|
@ -255,11 +262,27 @@ void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled
|
|||
|
||||
bool PSContext::fl_client_enable() { return fl_client_enable_; }
|
||||
|
||||
void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) {
|
||||
void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
|
||||
start_fl_job_threshold_ = start_fl_job_threshold;
|
||||
}
|
||||
|
||||
size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
|
||||
uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
|
||||
|
||||
void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
|
||||
start_fl_job_time_window_ = start_fl_job_time_window;
|
||||
}
|
||||
|
||||
uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
|
||||
|
||||
void PSContext::set_update_model_ratio(float update_model_ratio) { update_model_ratio_ = update_model_ratio; }
|
||||
|
||||
float PSContext::update_model_ratio() const { return update_model_ratio_; }
|
||||
|
||||
void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
|
||||
update_model_time_window_ = update_model_time_window;
|
||||
}
|
||||
|
||||
uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
|
||||
|
||||
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
|
||||
|
||||
|
@ -277,11 +300,9 @@ void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch
|
|||
|
||||
uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
|
||||
|
||||
void PSContext::set_worker_upload_weights(uint64_t worker_upload_weights) {
|
||||
worker_upload_weights_ = worker_upload_weights;
|
||||
}
|
||||
void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }
|
||||
|
||||
uint64_t PSContext::worker_upload_weights() const { return worker_upload_weights_; }
|
||||
float PSContext::client_learning_rate() const { return client_learning_rate_; }
|
||||
|
||||
void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }
|
||||
|
||||
|
|
|
@ -41,15 +41,13 @@ constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
|
|||
// 1: Server is in federated learning mode.
|
||||
// 2: Server is in mixed training mode.
|
||||
// 3: Server enables sucure aggregation.
|
||||
// 4: Server needs worker to overwrite weights.
|
||||
// For example: 01010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
|
||||
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerOverwriteWeights };
|
||||
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {
|
||||
{0b00010, ResetterRound::kUpdateModel},
|
||||
{0b01010, ResetterRound::kReconstructSeccrets},
|
||||
{0b11100, ResetterRound::kWorkerOverwriteWeights},
|
||||
{0b10100, ResetterRound::kWorkerOverwriteWeights},
|
||||
{0b00100, ResetterRound::kUpdateModel}};
|
||||
// For example: 1010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
|
||||
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerUploadWeights };
|
||||
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
|
||||
{0b1010, ResetterRound::kReconstructSeccrets},
|
||||
{0b1100, ResetterRound::kWorkerUploadWeights},
|
||||
{0b0100, ResetterRound::kWorkerUploadWeights},
|
||||
{0b0100, ResetterRound::kUpdateModel}};
|
||||
|
||||
class PSContext {
|
||||
public:
|
||||
|
@ -115,8 +113,17 @@ class PSContext {
|
|||
void set_fl_client_enable(bool enabled);
|
||||
bool fl_client_enable();
|
||||
|
||||
void set_start_fl_job_threshold(size_t start_fl_job_threshold);
|
||||
size_t start_fl_job_threshold() const;
|
||||
void set_start_fl_job_threshold(uint64_t start_fl_job_threshold);
|
||||
uint64_t start_fl_job_threshold() const;
|
||||
|
||||
void set_start_fl_job_time_window(uint64_t start_fl_job_time_window);
|
||||
uint64_t start_fl_job_time_window() const;
|
||||
|
||||
void set_update_model_ratio(float update_model_ratio);
|
||||
float update_model_ratio() const;
|
||||
|
||||
void set_update_model_time_window(uint64_t update_model_time_window);
|
||||
uint64_t update_model_time_window() const;
|
||||
|
||||
void set_fl_name(const std::string &fl_name);
|
||||
const std::string &fl_name() const;
|
||||
|
@ -133,9 +140,8 @@ class PSContext {
|
|||
void set_client_batch_size(uint64_t client_batch_size);
|
||||
uint64_t client_batch_size() const;
|
||||
|
||||
// Set true if worker will overwrite weights on server. Used in hybrid training.
|
||||
void set_worker_upload_weights(uint64_t worker_upload_weights);
|
||||
uint64_t worker_upload_weights() const;
|
||||
void set_client_learning_rate(float client_learning_rate);
|
||||
float client_learning_rate() const;
|
||||
|
||||
// Set true if using secure aggregation for federated learning.
|
||||
void set_secure_aggregation(bool secure_aggregation);
|
||||
|
@ -160,11 +166,14 @@ class PSContext {
|
|||
fl_client_enable_(false),
|
||||
fl_name_(""),
|
||||
start_fl_job_threshold_(0),
|
||||
fl_iteration_num_(0),
|
||||
client_epoch_num_(0),
|
||||
client_batch_size_(0),
|
||||
secure_aggregation_(false),
|
||||
worker_upload_weights_(false) {}
|
||||
start_fl_job_time_window_(3000),
|
||||
update_model_ratio_(1.0),
|
||||
update_model_time_window_(3000),
|
||||
fl_iteration_num_(20),
|
||||
client_epoch_num_(25),
|
||||
client_batch_size_(32),
|
||||
client_learning_rate_(0.001),
|
||||
secure_aggregation_(false) {}
|
||||
bool ps_enabled_;
|
||||
bool is_worker_;
|
||||
bool is_pserver_;
|
||||
|
@ -195,7 +204,16 @@ class PSContext {
|
|||
std::string fl_name_;
|
||||
|
||||
// The threshold count of startFLJob round. Used in federated learning for now.
|
||||
size_t start_fl_job_threshold_;
|
||||
uint64_t start_fl_job_threshold_;
|
||||
|
||||
// The time window of startFLJob round in millisecond.
|
||||
uint64_t start_fl_job_time_window_;
|
||||
|
||||
// Update model threshold is a certain ratio of start_fl_job threshold which is set as update_model_ratio_.
|
||||
float update_model_ratio_;
|
||||
|
||||
// The time window of updateModel round in millisecond.
|
||||
uint64_t update_model_time_window_;
|
||||
|
||||
// Iteration number of federeated learning, which is the number of interactions between client and server.
|
||||
uint64_t fl_iteration_num_;
|
||||
|
@ -206,12 +224,11 @@ class PSContext {
|
|||
// Client training data batch size. Used in federated learning for now.
|
||||
uint64_t client_batch_size_;
|
||||
|
||||
// Client training learning rate. Used in federated learning for now.
|
||||
float client_learning_rate_;
|
||||
|
||||
// Whether to use secure aggregation algorithm. Used in federated learning for now.
|
||||
bool secure_aggregation_;
|
||||
|
||||
// Whether there's a federated learning worker uploading weights to federated learning server. Used in hybrid training
|
||||
// mode for now.
|
||||
bool worker_upload_weights_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,9 +56,9 @@ using mindspore::kernel::Address;
|
|||
using mindspore::kernel::AddressPtr;
|
||||
using mindspore::kernel::CPUKernel;
|
||||
using FBBuilder = flatbuffers::FlatBufferBuilder;
|
||||
using TimeOutCb = std::function<void(void)>;
|
||||
using TimeOutCb = std::function<void(bool)>;
|
||||
using StopTimerCb = std::function<void(void)>;
|
||||
using FinishIterCb = std::function<void(void)>;
|
||||
using FinishIterCb = std::function<void(bool)>;
|
||||
using FinalizeCb = std::function<void(void)>;
|
||||
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
|
||||
|
||||
|
@ -148,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32;
|
|||
constexpr int kHttpSuccess = 200;
|
||||
constexpr auto kPBProtocol = "PB";
|
||||
constexpr auto kFBSProtocol = "FBS";
|
||||
constexpr auto kSuccess = "Success";
|
||||
constexpr auto kFedAvg = "FedAvg";
|
||||
constexpr auto kAggregationKernelType = "Aggregation";
|
||||
constexpr auto kOptimizerKernelType = "Optimizer";
|
||||
|
@ -155,6 +156,7 @@ constexpr auto kCtxFuncGraph = "FuncGraph";
|
|||
constexpr auto kCtxIterNum = "iteration";
|
||||
constexpr auto kCtxDeviceMetas = "device_metas";
|
||||
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
|
||||
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
|
||||
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
|
||||
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
|
||||
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
|
||||
|
|
|
@ -130,7 +130,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
|
|||
|
||||
void DistributedCountService::ResetCounter(const std::string &name) {
|
||||
if (local_rank_ == counting_server_rank_) {
|
||||
MS_LOG(INFO) << "Leader server reset count for " << name;
|
||||
MS_LOG(DEBUG) << "Leader server reset count for " << name;
|
||||
global_current_count_[name].clear();
|
||||
}
|
||||
return;
|
||||
|
@ -233,7 +233,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::Mes
|
|||
const auto &type = counter_event.type();
|
||||
const auto &name = counter_event.name();
|
||||
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
|
||||
MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
|
||||
if (type == CounterEventType::FIRST_CNT) {
|
||||
counter_handlers_[name].first_count_handler(message);
|
||||
} else if (type == CounterEventType::LAST_CNT) {
|
||||
|
@ -259,7 +259,7 @@ void DistributedCountService::TriggerCounterEvent(const std::string &name) {
|
|||
}
|
||||
|
||||
void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
|
||||
MS_LOG(INFO) << "Activating first count event for " << name;
|
||||
MS_LOG(DEBUG) << "Activating first count event for " << name;
|
||||
CounterEvent first_count_event;
|
||||
first_count_event.set_type(CounterEventType::FIRST_CNT);
|
||||
first_count_event.set_name(name);
|
||||
|
|
|
@ -79,10 +79,10 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) {
|
|||
return;
|
||||
}
|
||||
|
||||
void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||
bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||
if (router_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t stored_rank = router_->Find(name);
|
||||
|
@ -90,18 +90,26 @@ void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
|
|||
if (local_rank_ == stored_rank) {
|
||||
if (!DoUpdateMetadata(name, meta)) {
|
||||
MS_LOG(ERROR) << "Updating meta data failed.";
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
PBMetadataWithName metadata_with_name;
|
||||
metadata_with_name.set_name(name);
|
||||
*metadata_with_name.mutable_metadata() = meta;
|
||||
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) {
|
||||
std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
|
||||
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata,
|
||||
&update_meta_rsp_msg)) {
|
||||
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string update_meta_rsp = reinterpret_cast<const char *>(update_meta_rsp_msg->data());
|
||||
if (update_meta_rsp != kSuccess) {
|
||||
MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
|
||||
|
@ -166,6 +174,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
|
|||
std::string update_meta_rsp_msg;
|
||||
if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
|
||||
update_meta_rsp_msg = "Updating meta data failed.";
|
||||
MS_LOG(ERROR) << update_meta_rsp_msg;
|
||||
} else {
|
||||
update_meta_rsp_msg = "Success";
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ class DistributedMetadataStore {
|
|||
void ResetMetadata(const std::string &name);
|
||||
|
||||
// Update the metadata for the name.
|
||||
void UpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||
bool UpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
// Get the metadata for the name.
|
||||
PBMetadata GetMetadata(const std::string &name);
|
||||
|
|
|
@ -23,8 +23,6 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); }
|
||||
|
||||
void Iteration::AddRound(const std::shared_ptr<Round> &round) {
|
||||
MS_EXCEPTION_IF_NULL(round);
|
||||
rounds_.push_back(round);
|
||||
|
@ -49,28 +47,48 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
|
|||
|
||||
// The time window for one iteration, which will be used in some round kernels.
|
||||
size_t iteration_time_window =
|
||||
std::accumulate(rounds_.begin(), rounds_.end(), 0,
|
||||
[](size_t total, const std::shared_ptr<Round> &round) { return total + round->time_window(); });
|
||||
std::accumulate(rounds_.begin(), rounds_.end(), 0, [](size_t total, const std::shared_ptr<Round> &round) {
|
||||
return round->check_timeout() ? total + round->time_window() : total;
|
||||
});
|
||||
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
|
||||
MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;
|
||||
return;
|
||||
}
|
||||
|
||||
void Iteration::ProceedToNextIter() {
|
||||
void Iteration::ProceedToNextIter(bool is_iteration_valid) {
|
||||
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
// Store the model for each iteration.
|
||||
const auto &model = Executor::GetInstance().GetModel();
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
if (is_iteration_valid) {
|
||||
// Store the model which is successfully aggregated for this iteration.
|
||||
const auto &model = Executor::GetInstance().GetModel();
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
|
||||
} else {
|
||||
// Store last iteration's model because this iteration is considered as invalid.
|
||||
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid.";
|
||||
}
|
||||
|
||||
for (auto &round : rounds_) {
|
||||
round->Reset();
|
||||
}
|
||||
|
||||
iteration_num_++;
|
||||
// After the job is done, reset the iteration to the initial number and reset ModelStore.
|
||||
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
|
||||
MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed.";
|
||||
iteration_num_ = 1;
|
||||
ModelStore::GetInstance().Reset();
|
||||
}
|
||||
|
||||
is_last_iteration_valid_ = is_iteration_valid;
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
|
||||
|
||||
bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,8 +31,10 @@ namespace server {
|
|||
// Rounds, only after all the rounds are finished, this iteration is considered as completed.
|
||||
class Iteration {
|
||||
public:
|
||||
Iteration();
|
||||
~Iteration() = default;
|
||||
static Iteration &GetInstance() {
|
||||
static Iteration instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Add a round for the iteration. This method will be called multiple times for each round.
|
||||
void AddRound(const std::shared_ptr<Round> &round);
|
||||
|
@ -41,16 +43,29 @@ class Iteration {
|
|||
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
|
||||
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
|
||||
|
||||
// The server proceeds to the next iteration only after the last iteration finishes.
|
||||
void ProceedToNextIter();
|
||||
// The server proceeds to the next iteration only after the last round finishes or the timer expires.
|
||||
// If the timer expires, we consider this iteration as invalid.
|
||||
void ProceedToNextIter(bool is_iteration_valid);
|
||||
|
||||
const std::vector<std::shared_ptr<Round>> &rounds();
|
||||
|
||||
bool is_last_iteration_valid() const;
|
||||
|
||||
private:
|
||||
Iteration() : iteration_num_(1), is_last_iteration_valid_(true) {
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||
}
|
||||
~Iteration() = default;
|
||||
Iteration(const Iteration &) = delete;
|
||||
Iteration &operator=(const Iteration &) = delete;
|
||||
|
||||
std::vector<std::shared_ptr<Round>> rounds_;
|
||||
|
||||
// Server's current iteration number.
|
||||
size_t iteration_num_;
|
||||
|
||||
// Last iteration is successfully finished.
|
||||
bool is_last_iteration_valid_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
|
|
|
@ -29,7 +29,7 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
|
|||
monitor_thread_ = std::thread([&]() {
|
||||
while (running_.load()) {
|
||||
if (CURRENT_TIME_MILLI > end_time_) {
|
||||
timeout_callback_();
|
||||
timeout_callback_(false);
|
||||
running_ = false;
|
||||
}
|
||||
// The time tick is 1 millisecond.
|
||||
|
|
|
@ -47,6 +47,7 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
|
|||
}
|
||||
|
||||
void GenerateReuseKernelNodeInfo() override {
|
||||
MS_LOG(INFO) << "FedAvg reuse 'weight', 'accumulation', 'learning rate' and 'momentum' of the kernel node.";
|
||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0));
|
||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 1));
|
||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2));
|
||||
|
|
|
@ -92,7 +92,6 @@ class FedAvgKernel : public AggregationKernel {
|
|||
weight_addr[i] /= data_size_addr[0];
|
||||
}
|
||||
done_ = true;
|
||||
DistributedCountService::GetInstance().ResetCounter(name_);
|
||||
return;
|
||||
};
|
||||
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
|
||||
|
@ -125,6 +124,7 @@ class FedAvgKernel : public AggregationKernel {
|
|||
participated_ = true;
|
||||
DistributedCountService::GetInstance().Count(
|
||||
name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
|
||||
GenerateReuseKernelNodeInfo();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -149,6 +149,7 @@ class FedAvgKernel : public AggregationKernel {
|
|||
|
||||
private:
|
||||
void GenerateReuseKernelNodeInfo() override {
|
||||
MS_LOG(INFO) << "FedAvg reuse 'weight' of the kernel node.";
|
||||
// Only the trainable parameter is reused for federated average.
|
||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
|
||||
return;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ps/server/iteration.h"
|
||||
#include "ps/server/model_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -67,27 +68,31 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
|
|||
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
|
||||
size_t latest_iter_num = iter_to_model.rbegin()->first;
|
||||
|
||||
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
|
||||
if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
|
||||
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
|
||||
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(WARNING) << reason;
|
||||
return;
|
||||
}
|
||||
|
||||
if (iter_to_model.count(get_model_iter) == 0) {
|
||||
std::string reason = "The iteration of GetModel request" + std::to_string(get_model_iter) +
|
||||
" is invalid. Current iteration is " + std::to_string(current_iter);
|
||||
BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return;
|
||||
// If the model of get_model_iter is not stored, return the latest version of model and current iteration number.
|
||||
MS_LOG(WARNING) << "The iteration of GetModel request " << std::to_string(get_model_iter)
|
||||
<< " is invalid. Current iteration is " << std::to_string(current_iter);
|
||||
feature_maps = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
|
||||
} else {
|
||||
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
|
||||
}
|
||||
|
||||
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
|
||||
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED,
|
||||
"Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter,
|
||||
feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
// If the iteration of this model is invalid, return ResponseCode_OutOfTime to the clients could startFLJob according
|
||||
// to next_req_time.
|
||||
auto response_code =
|
||||
Iteration::GetInstance().is_last_iteration_valid() ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
|
||||
BuildGetModelRsp(fbb, response_code, "Get model for iteration " + std::to_string(get_model_iter), current_iter,
|
||||
feature_maps,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ void RoundKernel::StopTimer() {
|
|||
|
||||
void RoundKernel::FinishIteration() {
|
||||
if (finish_iteration_cb_) {
|
||||
finish_iteration_cb_();
|
||||
finish_iteration_cb_(true);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -61,15 +61,12 @@ class RoundKernel : virtual public CPUKernel {
|
|||
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) = 0;
|
||||
|
||||
// The callbacks when first message and last message for this round kernel is received.
|
||||
// These methods is called by class DistributedCountService and triggered by leader server(Rank 0).
|
||||
// virtual void OnFirstCountEvent(std::shared_ptr<core::MessageHandler> message);
|
||||
// virtual void OnLastCnt(std::shared_ptr<core::MessageHandler> message);
|
||||
|
||||
// Some rounds could be stateful in a iteration. Reset method resets the status of this round.
|
||||
virtual bool Reset() = 0;
|
||||
|
||||
// The counter event handlers for DistributedCountService.
|
||||
// The callbacks when first message and last message for this round kernel is received.
|
||||
// These methods is called by class DistributedCountService and triggered by counting server.
|
||||
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
|
|
|
@ -25,9 +25,12 @@ namespace ps {
|
|||
namespace server {
|
||||
namespace kernel {
|
||||
void StartFLJobKernel::InitKernel(size_t) {
|
||||
// The time window of one iteration should be started at the first message of startFLJob round.
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
|
||||
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
|
||||
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
|
@ -85,11 +88,17 @@ bool StartFLJobKernel::Reset() {
|
|||
return true;
|
||||
}
|
||||
|
||||
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
|
||||
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
|
||||
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
BuildStartFLJobRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return true;
|
||||
}
|
||||
|
@ -117,8 +126,9 @@ bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
|
|||
ret = false;
|
||||
}
|
||||
if (!ret) {
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
BuildStartFLJobRsp(
|
||||
fbb, schema::ResponseCode_NotSelected, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
}
|
||||
return ret;
|
||||
|
@ -128,8 +138,9 @@ bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
|
|||
const schema::RequestFLJob *start_fl_job_req) {
|
||||
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
|
||||
std::string reason = "startFLJob counting failed.";
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
BuildStartFLJobRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
@ -139,11 +150,18 @@ bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
|
|||
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||
PBMetadata metadata;
|
||||
*metadata.mutable_device_meta() = device_meta;
|
||||
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata);
|
||||
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) {
|
||||
std::string reason = "Updating device metadata failed.";
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_SystemError, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
|
||||
{});
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps);
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
|
||||
feature_maps);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -153,13 +171,16 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
|||
std::map<std::string, AddressPtr> feature_maps) {
|
||||
auto fbs_reason = fbb->CreateString(reason);
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
|
||||
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
|
||||
|
||||
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
|
||||
fl_plan_builder.add_fl_name(fbs_fl_name);
|
||||
fl_plan_builder.add_server_mode(fbs_server_mode);
|
||||
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
|
||||
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
|
||||
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
|
||||
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
|
||||
auto fbs_fl_plan = fl_plan_builder.Finish();
|
||||
|
||||
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace server {
|
|||
namespace kernel {
|
||||
class StartFLJobKernel : public RoundKernel {
|
||||
public:
|
||||
StartFLJobKernel() = default;
|
||||
StartFLJobKernel() : executor_(nullptr), iteration_time_window_(0), iter_next_req_timestamp_(0) {}
|
||||
~StartFLJobKernel() override = default;
|
||||
|
||||
void InitKernel(size_t threshold_count) override;
|
||||
|
@ -40,6 +40,8 @@ class StartFLJobKernel : public RoundKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Reset() override;
|
||||
|
||||
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
|
||||
|
||||
private:
|
||||
// Returns whether the startFLJob count of this iteration has reached the threshold.
|
||||
bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
|
||||
|
@ -66,6 +68,9 @@ class StartFLJobKernel : public RoundKernel {
|
|||
|
||||
// The time window of one iteration.
|
||||
size_t iteration_time_window_;
|
||||
|
||||
// Timestamp of next request time for this iteration.
|
||||
uint64_t iter_next_req_timestamp_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -39,6 +39,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
|
|||
PBMetadata client_list;
|
||||
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
|
||||
LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
|
||||
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
|
||||
}
|
||||
|
||||
bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
|
@ -103,8 +104,9 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
|
|||
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
std::string reason = "Current amount for updateModel is enough.";
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
@ -117,8 +119,9 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
|
|||
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
|
||||
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
|
||||
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
@ -128,14 +131,24 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
|
|||
std::string update_model_fl_id = update_model_req->fl_id()->str();
|
||||
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
|
||||
std::string reason = "devices_meta for " + update_model_fl_id + " is not set.";
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
|
||||
auto feature_map = ParseFeatureMap(update_model_req);
|
||||
if (feature_map.empty()) {
|
||||
std::string reason = "Feature map is empty.";
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto weight : feature_map) {
|
||||
weight.second[kNewDataSize].addr = &data_size;
|
||||
weight.second[kNewDataSize].size = sizeof(size_t);
|
||||
|
@ -146,10 +159,17 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
|
|||
fl_id.set_fl_id(update_model_fl_id);
|
||||
PBMetadata comm_value;
|
||||
*comm_value.mutable_fl_id() = fl_id;
|
||||
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value);
|
||||
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) {
|
||||
std::string reason = "Updating metadata of UpdateModelClientList failed.";
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_SystemError, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready",
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -174,8 +194,9 @@ bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fb
|
|||
const schema::RequestUpdateModel *update_model_req) {
|
||||
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
|
||||
std::string reason = "UpdateModel counting failed.";
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -30,6 +30,9 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
|
||||
constexpr uint64_t kInitialDataSizeSum = 0;
|
||||
|
||||
class UpdateModelKernel : public RoundKernel {
|
||||
public:
|
||||
UpdateModelKernel() = default;
|
||||
|
|
|
@ -30,7 +30,8 @@ void ModelStore::Initialize(uint32_t max_count) {
|
|||
}
|
||||
|
||||
max_model_count_ = max_count;
|
||||
iteration_to_model_[kInitIterationNum] = AssignNewModelMemory();
|
||||
initial_model_ = AssignNewModelMemory();
|
||||
iteration_to_model_[kInitIterationNum] = initial_model_;
|
||||
model_size_ = ComputeModelSize();
|
||||
}
|
||||
|
||||
|
@ -52,7 +53,6 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
|
|||
MS_LOG(ERROR) << "Memory for the new model is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
iteration_to_model_[iteration] = memory_register;
|
||||
} else {
|
||||
// If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
|
||||
|
@ -97,6 +97,12 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
|
|||
return model;
|
||||
}
|
||||
|
||||
void ModelStore::Reset() {
|
||||
initial_model_ = iteration_to_model_.rbegin()->second;
|
||||
iteration_to_model_.clear();
|
||||
iteration_to_model_[kInitIterationNum] = initial_model_;
|
||||
}
|
||||
|
||||
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
|
||||
return iteration_to_model_;
|
||||
}
|
||||
|
@ -121,6 +127,14 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto src_data_size = weight_size;
|
||||
auto dst_data_size = weight_size;
|
||||
int ret = memcpy_s(weight_data.get(), dst_data_size, weight.second->addr, src_data_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
memory_register->RegisterArray(weight_name, &weight_data, weight_size);
|
||||
}
|
||||
return memory_register;
|
||||
|
|
|
@ -49,6 +49,9 @@ class ModelStore {
|
|||
// Get model of the given iteration.
|
||||
std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);
|
||||
|
||||
// Reset the stored models. Called when federated learning job finishes.
|
||||
void Reset();
|
||||
|
||||
// Returns all models stored in ModelStore.
|
||||
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
|
||||
|
||||
|
@ -70,6 +73,11 @@ class ModelStore {
|
|||
|
||||
size_t max_model_count_;
|
||||
size_t model_size_;
|
||||
|
||||
// Initial model which is the model of iteration 0.
|
||||
std::shared_ptr<MemoryRegister> initial_model_;
|
||||
|
||||
// The number of all models stpred is max_model_count_.
|
||||
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
|
||||
};
|
||||
} // namespace server
|
||||
|
|
|
@ -38,9 +38,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
|
|||
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
|
||||
|
||||
// Callback when the iteration is finished.
|
||||
finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void {
|
||||
MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration.";
|
||||
finish_iteration_cb();
|
||||
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid) -> void {
|
||||
MS_LOG(INFO) << "Round " << name_ << " finished! This iteration is valid. Proceed to next iteration.";
|
||||
finish_iteration_cb(is_iteration_valid);
|
||||
};
|
||||
|
||||
// Callback for finalizing the server. This can only be called once.
|
||||
|
@ -50,9 +50,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
|
|||
iter_timer_ = std::make_shared<IterationTimer>();
|
||||
|
||||
// 1.Set the timeout callback for the timer.
|
||||
iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void {
|
||||
MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration.";
|
||||
timeout_cb();
|
||||
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid) -> void {
|
||||
MS_LOG(INFO) << "Round " << name_ << " timeout! This iteration is invalid. Proceed to next iteration.";
|
||||
timeout_cb(is_iteration_valid);
|
||||
});
|
||||
|
||||
// 2.Stopping timer callback which will be set to the round kernel.
|
||||
|
@ -112,14 +112,19 @@ const std::string &Round::name() const { return name_; }
|
|||
|
||||
size_t Round::threshold_count() const { return threshold_count_; }
|
||||
|
||||
bool Round::check_timeout() const { return check_timeout_; }
|
||||
|
||||
size_t Round::time_window() const { return time_window_; }
|
||||
|
||||
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
|
||||
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
|
||||
// The timer starts only after the first count event is triggered by DistributedCountService.
|
||||
if (check_timeout_) {
|
||||
iter_timer_->Start(std::chrono::milliseconds(time_window_));
|
||||
}
|
||||
|
||||
// Some kernels override the OnFirstCountEvent method.
|
||||
kernel_->OnFirstCountEvent(message);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ class Round {
|
|||
|
||||
const std::string &name() const;
|
||||
size_t threshold_count() const;
|
||||
bool check_timeout() const;
|
||||
size_t time_window() const;
|
||||
|
||||
private:
|
||||
|
|
|
@ -174,21 +174,22 @@ bool Server::InitCommunicatorWithWorker() {
|
|||
}
|
||||
|
||||
void Server::InitIteration() {
|
||||
iteration_ = std::make_shared<Iteration>();
|
||||
iteration_ = &Iteration::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(iteration_);
|
||||
|
||||
// 1.Add rounds to the iteration according to the server mode.
|
||||
for (const RoundConfig &config : rounds_config_) {
|
||||
std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window,
|
||||
config.check_count, config.threshold_count);
|
||||
MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count
|
||||
<< ", threshold:" << config.threshold_count;
|
||||
MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
|
||||
<< ", time window: " << config.time_window << ", check_count: " << config.check_count
|
||||
<< ", threshold: " << config.threshold_count;
|
||||
iteration_->AddRound(round);
|
||||
}
|
||||
|
||||
// 2.Initialize all the rounds.
|
||||
TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
|
||||
FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
|
||||
TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
|
||||
FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
|
||||
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -117,7 +117,7 @@ class Server {
|
|||
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
|
||||
|
||||
// Iteration consists of multiple kinds of rounds.
|
||||
std::shared_ptr<Iteration> iteration_;
|
||||
Iteration *iteration_;
|
||||
|
||||
// Variables set by ps context.
|
||||
std::string scheduler_ip_;
|
||||
|
|
|
@ -787,3 +787,59 @@ def reset_ps_context():
|
|||
- enable_ps: False.
|
||||
"""
|
||||
_reset_ps_context()
|
||||
|
||||
def set_fl_context(**kwargs):
|
||||
"""
|
||||
Set federated learning training mode context.
|
||||
|
||||
Args:
|
||||
enable_fl (bool): Whether to enable federated learning training mode.
|
||||
Default: False.
|
||||
server_mode (string): Describe the server mode, which must one of 'FEDERATED_LEARNING' and 'HYBRID_TRAINING'.
|
||||
Default: 'FEDERATED_LEARNING'.
|
||||
ms_role (string): The process's role in the federated learning mode,
|
||||
which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
|
||||
Default: 'MS_NOT_PS'.
|
||||
worker_num (int): The number of workers. Default: 0.
|
||||
server_num (int): The number of federated learning servers. Default: 0.
|
||||
scheduler_ip (string): The scheduler IP. Default: ''.
|
||||
scheduler_port (int): The scheduler port. Default: 0.
|
||||
fl_server_port (int): The http port of the federated learning server.
|
||||
Normally for each server this should be set to the same value. Default: 0.
|
||||
enable_fl_client (bool): Whether this process is federated learning client. Default: False.
|
||||
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0.
|
||||
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
|
||||
update_model_ratio (float): The ratio for computing the threshold count of updateModel
|
||||
which will be multiplied by start_fl_job_threshold. Default: 1.0.
|
||||
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
|
||||
fl_name (string): The federated learning job name. Default: ''.
|
||||
fl_iteration_num (int): Iteration number of federeated learning,
|
||||
which is the number of interactions between client and server. Default: 20.
|
||||
client_epoch_num (int): Client training epoch number. Default: 25.
|
||||
client_batch_size (int): Client training data batch size. Default: 32.
|
||||
client_learning_rate (float): Client training learning rate. Default: 0.001.
|
||||
secure_aggregation (bool): Whether to use secure aggregation algorithm. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not the attribute in federated learning mode context.
|
||||
|
||||
Examples:
|
||||
>>> context.set_fl_context(enable_fl=True, server_mode='FEDERATED_LEARNING')
|
||||
"""
|
||||
_set_ps_context(**kwargs)
|
||||
|
||||
|
||||
def get_fl_context(attr_key):
|
||||
"""
|
||||
Get federated learning mode context attribute value according to the key.
|
||||
|
||||
Args:
|
||||
attr_key (str): The key of the attribute.
|
||||
|
||||
Returns:
|
||||
Returns attribute value according to the key.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in federated learning mode context.
|
||||
"""
|
||||
return _get_ps_context(attr_key)
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include <unistd.h>
|
||||
#include <sys/time.h>
|
||||
#include <map>
|
||||
#include <iomanip>
|
||||
#include <thread>
|
||||
|
||||
// namespace to support utils module definition
|
||||
namespace mindspore {
|
||||
|
@ -117,8 +119,8 @@ void LogWriter::OutputLog(const std::ostringstream &msg) const {
|
|||
#define google mindspore_private
|
||||
auto submodule_name = GetSubModuleName(submodule_);
|
||||
google::LogMessage("", 0, GetGlogLevel(log_level_)).stream()
|
||||
<< "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << GetProcName()
|
||||
<< "):" << GetTimeString() << " "
|
||||
<< "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << std::hex
|
||||
<< std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " "
|
||||
<< "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl;
|
||||
#undef google
|
||||
#else
|
||||
|
|
|
@ -36,6 +36,7 @@ _set_ps_context_func_map = {
|
|||
"server_mode": ps_context().set_server_mode,
|
||||
"ms_role": ps_context().set_ms_role,
|
||||
"enable_ps": ps_context().set_ps_enable,
|
||||
"enable_fl": ps_context().set_ps_enable,
|
||||
"worker_num": ps_context().set_worker_num,
|
||||
"server_num": ps_context().set_server_num,
|
||||
"scheduler_ip": ps_context().set_scheduler_ip,
|
||||
|
@ -43,10 +44,14 @@ _set_ps_context_func_map = {
|
|||
"fl_server_port": ps_context().set_fl_server_port,
|
||||
"enable_fl_client": ps_context().set_fl_client_enable,
|
||||
"start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
|
||||
"start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
|
||||
"update_model_ratio": ps_context().set_update_model_ratio,
|
||||
"update_model_time_window": ps_context().set_update_model_time_window,
|
||||
"fl_name": ps_context().set_fl_name,
|
||||
"fl_iteration_num": ps_context().set_fl_iteration_num,
|
||||
"client_epoch_num": ps_context().set_client_epoch_num,
|
||||
"client_batch_size": ps_context().set_client_batch_size,
|
||||
"client_learning_rate": ps_context().set_client_learning_rate,
|
||||
"secure_aggregation": ps_context().set_secure_aggregation,
|
||||
"enable_ps_ssl": ps_context().set_enable_ssl
|
||||
}
|
||||
|
|
|
@ -69,6 +69,7 @@ table ResponseFLJob {
|
|||
}
|
||||
|
||||
table FLPlan {
|
||||
server_mode:string;
|
||||
fl_name:string;
|
||||
iterations:int;
|
||||
epochs:int;
|
||||
|
|
|
@ -26,10 +26,14 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
|
|||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
parser.add_argument("--fl_server_port", type=int, default=6666)
|
||||
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
||||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--update_model_time_window", type=int, default=3000)
|
||||
parser.add_argument("--fl_name", type=str, default="Lenet")
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
parser.add_argument("--client_batch_size", type=int, default=32)
|
||||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
|
||||
parser.add_argument("--local_server_num", type=int, default=-1)
|
||||
|
||||
|
@ -43,10 +47,14 @@ if __name__ == "__main__":
|
|||
scheduler_port = args.scheduler_port
|
||||
fl_server_port = args.fl_server_port
|
||||
start_fl_job_threshold = args.start_fl_job_threshold
|
||||
start_fl_job_time_window = args.start_fl_job_time_window
|
||||
update_model_ratio = args.update_model_ratio
|
||||
update_model_time_window = args.update_model_time_window
|
||||
fl_name = args.fl_name
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
client_epoch_num = args.client_epoch_num
|
||||
client_batch_size = args.client_batch_size
|
||||
client_learning_rate = args.client_learning_rate
|
||||
secure_aggregation = args.secure_aggregation
|
||||
local_server_num = args.local_server_num
|
||||
|
||||
|
@ -70,10 +78,14 @@ if __name__ == "__main__":
|
|||
cmd_server += " --scheduler_port=" + str(scheduler_port)
|
||||
cmd_server += " --fl_server_port=" + str(fl_server_port + i)
|
||||
cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
|
||||
cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
|
||||
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
|
||||
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
|
||||
cmd_server += " --fl_name=" + fl_name
|
||||
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
|
||||
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
|
||||
cmd_server += " --client_batch_size=" + str(client_batch_size)
|
||||
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
|
||||
cmd_server += " --secure_aggregation=" + str(secure_aggregation)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import argparse
|
||||
import time
|
||||
import datetime
|
||||
import random
|
||||
import sys
|
||||
import requests
|
||||
|
@ -129,7 +130,15 @@ def build_get_model(iteration):
|
|||
buf = builder_get_model.Output()
|
||||
return buf
|
||||
|
||||
weight_name_to_idx = {
|
||||
def datetime_to_timestamp(datetime_obj):
|
||||
"""将本地(local) datetime 格式的时间 (含毫秒) 转为毫秒时间戳
|
||||
:param datetime_obj: {datetime}2016-02-25 20:21:04.242000
|
||||
:return: 13 位的毫秒时间戳 1456402864242
|
||||
"""
|
||||
local_timestamp = time.mktime(datetime_obj.timetuple()) * 1000.0 + datetime_obj.microsecond // 1000.0
|
||||
return local_timestamp
|
||||
|
||||
weight_to_idx = {
|
||||
"conv1.weight": 0,
|
||||
"conv2.weight": 1,
|
||||
"fc1.weight": 2,
|
||||
|
@ -149,11 +158,12 @@ while True:
|
|||
print("start url is ", url1)
|
||||
x = requests.post(url1, data=build_start_fl_job(current_iteration))
|
||||
rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
|
||||
print("start fl job iteration:", current_iteration, ", id:", args.pid)
|
||||
while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
|
||||
x = requests.post(url1, data=build_start_fl_job(current_iteration))
|
||||
rsp_fl_job = rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
|
||||
rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
|
||||
print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
|
||||
print("iteration is", rsp_fl_job.Iteration())
|
||||
current_iteration = rsp_fl_job.Iteration()
|
||||
sys.stdout.flush()
|
||||
|
||||
url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
|
||||
|
@ -170,22 +180,40 @@ while True:
|
|||
print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
|
||||
sys.stdout.flush()
|
||||
|
||||
repeat_time = 0
|
||||
while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
|
||||
time.sleep(0.1)
|
||||
x = session.post(url3, data=build_get_model(current_iteration))
|
||||
rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
|
||||
repeat_time += 1
|
||||
if repeat_time > 1000:
|
||||
print("GetModel try timeout ", args.pid)
|
||||
sys.exit(0)
|
||||
|
||||
for i in range(0, 1):
|
||||
print(rsp_get_model.FeatureMap(i).WeightFullname())
|
||||
origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
|
||||
after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
|
||||
print("Before update model", args.pid, origin[0:10])
|
||||
print("After get model", args.pid, after[0:10])
|
||||
next_req_timestamp = 0
|
||||
if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
|
||||
next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
|
||||
print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
|
||||
sys.stdout.flush()
|
||||
assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
|
||||
current_iteration += 1
|
||||
elif rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
|
||||
repeat_time = 0
|
||||
while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
|
||||
time.sleep(0.2)
|
||||
x = session.post(url3, data=build_get_model(current_iteration))
|
||||
rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
|
||||
if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
|
||||
next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
|
||||
print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
|
||||
sys.stdout.flush()
|
||||
break
|
||||
repeat_time += 1
|
||||
if repeat_time > 1000:
|
||||
print("GetModel try timeout ", args.pid)
|
||||
sys.exit(0)
|
||||
else:
|
||||
pass
|
||||
|
||||
if next_req_timestamp == 0:
|
||||
for i in range(0, 1):
|
||||
print(rsp_get_model.FeatureMap(i).WeightFullname())
|
||||
origin = update_model_np_data[weight_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
|
||||
after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
|
||||
print("Before update model", args.pid, origin[0:10])
|
||||
print("After get model", args.pid, after[0:10])
|
||||
sys.stdout.flush()
|
||||
assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
|
||||
else:
|
||||
# Sleep to the next request timestamp
|
||||
current_ts = datetime_to_timestamp(datetime.datetime.now())
|
||||
duration = next_req_timestamp - current_ts
|
||||
time.sleep(duration / 1000)
|
||||
|
|
|
@ -34,10 +34,14 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
|
|||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
parser.add_argument("--fl_server_port", type=int, default=6666)
|
||||
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
||||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--update_model_time_window", type=int, default=3000)
|
||||
parser.add_argument("--fl_name", type=str, default="Lenet")
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
parser.add_argument("--client_batch_size", type=int, default=32)
|
||||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
@ -50,14 +54,18 @@ scheduler_ip = args.scheduler_ip
|
|||
scheduler_port = args.scheduler_port
|
||||
fl_server_port = args.fl_server_port
|
||||
start_fl_job_threshold = args.start_fl_job_threshold
|
||||
start_fl_job_time_window = args.start_fl_job_time_window
|
||||
update_model_ratio = args.update_model_ratio
|
||||
update_model_time_window = args.update_model_time_window
|
||||
fl_name = args.fl_name
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
client_epoch_num = args.client_epoch_num
|
||||
client_batch_size = args.client_batch_size
|
||||
client_learning_rate = args.client_learning_rate
|
||||
secure_aggregation = args.secure_aggregation
|
||||
|
||||
ctx = {
|
||||
"enable_ps": False,
|
||||
"enable_fl": True,
|
||||
"server_mode": server_mode,
|
||||
"ms_role": ms_role,
|
||||
"worker_num": worker_num,
|
||||
|
@ -66,15 +74,19 @@ ctx = {
|
|||
"scheduler_port": scheduler_port,
|
||||
"fl_server_port": fl_server_port,
|
||||
"start_fl_job_threshold": start_fl_job_threshold,
|
||||
"start_fl_job_time_window": start_fl_job_time_window,
|
||||
"update_model_ratio": update_model_ratio,
|
||||
"update_model_time_window": update_model_time_window,
|
||||
"fl_name": fl_name,
|
||||
"fl_iteration_num": fl_iteration_num,
|
||||
"client_epoch_num": client_epoch_num,
|
||||
"client_batch_size": client_batch_size,
|
||||
"client_learning_rate": client_learning_rate,
|
||||
"secure_aggregation": secure_aggregation
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
|
||||
context.set_ps_context(**ctx)
|
||||
context.set_fl_context(**ctx)
|
||||
|
||||
if __name__ == "__main__":
|
||||
epoch = 5
|
||||
|
|
Loading…
Reference in New Issue