!16230 Optimize server reliability in multiple scenarios.

From: @zpac
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
This commit is contained in:
mindspore-ci-bot 2021-05-15 16:50:40 +08:00 committed by Gitee
commit 94ca479fbe
33 changed files with 435 additions and 142 deletions

View File

@ -629,14 +629,17 @@ bool StartServerAction(const ResourcePtr &res) {
uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port(); uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();
// Update model threshold is a certain ratio of start_fl_job threshold. // Update model threshold is a certain ratio of start_fl_job threshold.
// update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_. // update_model_threshold = start_fl_job_threshold * update_model_ratio.
size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold(); size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
float percent_for_update_model = 1; float update_model_ratio = ps::PSContext::instance()->update_model_ratio();
size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model)); size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
std::vector<ps::server::RoundConfig> rounds_config = {{"startFLJob", false, 3000, true, start_fl_job_threshold}, std::vector<ps::server::RoundConfig> rounds_config = {
{"updateModel", false, 3000, true, update_model_threshold}, {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
{"getModel", false, 3000}}; {"updateModel", true, update_model_time_window, true, update_model_threshold},
{"getModel"}};
size_t executor_threshold = 0; size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {

View File

@ -345,11 +345,20 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.") .def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
.def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.") .def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
.def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.") .def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
.def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.") .def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold,
"Set threshold count for startFLJob round.")
.def("set_start_fl_job_time_window", &PSContext::set_start_fl_job_time_window,
"Set time window for startFLJob round.")
.def("set_update_model_ratio", &PSContext::set_update_model_ratio,
"Set threshold count ratio for updateModel round.")
.def("set_update_model_time_window", &PSContext::set_update_model_time_window,
"Set time window for updateModel round.")
.def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.") .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
.def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.") .def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
.def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.") .def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.") .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
.def("set_client_learning_rate", &PSContext::set_client_learning_rate,
"Set federated learning client learning rate.")
.def("set_secure_aggregation", &PSContext::set_secure_aggregation, .def("set_secure_aggregation", &PSContext::set_secure_aggregation,
"Set federated learning client using secure aggregation.") "Set federated learning client using secure aggregation.")
.def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled."); .def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");

View File

@ -196,7 +196,14 @@ void PSContext::set_ms_role(const std::string &role) {
role_ = role; role_ = role;
} }
void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; } void PSContext::set_worker_num(uint32_t worker_num) {
// Hybrid training mode only supports one worker for now.
if (server_mode_ == kServerModeHybrid && worker_num != 1) {
MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
return;
}
worker_num_ = worker_num;
}
uint32_t PSContext::worker_num() const { return worker_num_; } uint32_t PSContext::worker_num() const { return worker_num_; }
void PSContext::set_server_num(uint32_t server_num) { void PSContext::set_server_num(uint32_t server_num) {
@ -235,7 +242,7 @@ void PSContext::GenerateResetterRound() {
} }
binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) | binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_upload_weights_ << 4); (is_mixed_training_mode << 2) | (secure_aggregation_ << 3);
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
resetter_round_ = ResetterRound::kNoNeedToReset; resetter_round_ = ResetterRound::kNoNeedToReset;
} else { } else {
@ -255,11 +262,27 @@ void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled
bool PSContext::fl_client_enable() { return fl_client_enable_; } bool PSContext::fl_client_enable() { return fl_client_enable_; }
void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) { void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
start_fl_job_threshold_ = start_fl_job_threshold; start_fl_job_threshold_ = start_fl_job_threshold;
} }
size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; } uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
start_fl_job_time_window_ = start_fl_job_time_window;
}
uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
void PSContext::set_update_model_ratio(float update_model_ratio) { update_model_ratio_ = update_model_ratio; }
float PSContext::update_model_ratio() const { return update_model_ratio_; }
void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
update_model_time_window_ = update_model_time_window;
}
uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; } void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
@ -277,11 +300,9 @@ void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch
uint64_t PSContext::client_batch_size() const { return client_batch_size_; } uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
void PSContext::set_worker_upload_weights(uint64_t worker_upload_weights) { void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }
worker_upload_weights_ = worker_upload_weights;
}
uint64_t PSContext::worker_upload_weights() const { return worker_upload_weights_; } float PSContext::client_learning_rate() const { return client_learning_rate_; }
void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; } void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }

View File

@ -41,15 +41,13 @@ constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
// 1: Server is in federated learning mode. // 1: Server is in federated learning mode.
// 2: Server is in mixed training mode. // 2: Server is in mixed training mode.
// 3: Server enables secure aggregation. // 3: Server enables secure aggregation.
// 4: Server needs worker to overwrite weights. // For example: 1010 stands for that the server is in federated learning mode and secure aggregation is enabled.
// For example: 01010 stands for that the server is in federated learning mode and secure aggregation is enabled. enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerUploadWeights };
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerOverwriteWeights }; const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = { {0b1010, ResetterRound::kReconstructSeccrets},
{0b00010, ResetterRound::kUpdateModel}, {0b1100, ResetterRound::kWorkerUploadWeights},
{0b01010, ResetterRound::kReconstructSeccrets}, {0b0100, ResetterRound::kWorkerUploadWeights},
{0b11100, ResetterRound::kWorkerOverwriteWeights}, {0b0100, ResetterRound::kUpdateModel}};
{0b10100, ResetterRound::kWorkerOverwriteWeights},
{0b00100, ResetterRound::kUpdateModel}};
class PSContext { class PSContext {
public: public:
@ -115,8 +113,17 @@ class PSContext {
void set_fl_client_enable(bool enabled); void set_fl_client_enable(bool enabled);
bool fl_client_enable(); bool fl_client_enable();
void set_start_fl_job_threshold(size_t start_fl_job_threshold); void set_start_fl_job_threshold(uint64_t start_fl_job_threshold);
size_t start_fl_job_threshold() const; uint64_t start_fl_job_threshold() const;
void set_start_fl_job_time_window(uint64_t start_fl_job_time_window);
uint64_t start_fl_job_time_window() const;
void set_update_model_ratio(float update_model_ratio);
float update_model_ratio() const;
void set_update_model_time_window(uint64_t update_model_time_window);
uint64_t update_model_time_window() const;
void set_fl_name(const std::string &fl_name); void set_fl_name(const std::string &fl_name);
const std::string &fl_name() const; const std::string &fl_name() const;
@ -133,9 +140,8 @@ class PSContext {
void set_client_batch_size(uint64_t client_batch_size); void set_client_batch_size(uint64_t client_batch_size);
uint64_t client_batch_size() const; uint64_t client_batch_size() const;
// Set true if worker will overwrite weights on server. Used in hybrid training. void set_client_learning_rate(float client_learning_rate);
void set_worker_upload_weights(uint64_t worker_upload_weights); float client_learning_rate() const;
uint64_t worker_upload_weights() const;
// Set true if using secure aggregation for federated learning. // Set true if using secure aggregation for federated learning.
void set_secure_aggregation(bool secure_aggregation); void set_secure_aggregation(bool secure_aggregation);
@ -160,11 +166,14 @@ class PSContext {
fl_client_enable_(false), fl_client_enable_(false),
fl_name_(""), fl_name_(""),
start_fl_job_threshold_(0), start_fl_job_threshold_(0),
fl_iteration_num_(0), start_fl_job_time_window_(3000),
client_epoch_num_(0), update_model_ratio_(1.0),
client_batch_size_(0), update_model_time_window_(3000),
secure_aggregation_(false), fl_iteration_num_(20),
worker_upload_weights_(false) {} client_epoch_num_(25),
client_batch_size_(32),
client_learning_rate_(0.001),
secure_aggregation_(false) {}
bool ps_enabled_; bool ps_enabled_;
bool is_worker_; bool is_worker_;
bool is_pserver_; bool is_pserver_;
@ -195,7 +204,16 @@ class PSContext {
std::string fl_name_; std::string fl_name_;
// The threshold count of startFLJob round. Used in federated learning for now. // The threshold count of startFLJob round. Used in federated learning for now.
size_t start_fl_job_threshold_; uint64_t start_fl_job_threshold_;
// The time window of startFLJob round in millisecond.
uint64_t start_fl_job_time_window_;
// Update model threshold is a certain ratio of start_fl_job threshold which is set as update_model_ratio_.
float update_model_ratio_;
// The time window of updateModel round in millisecond.
uint64_t update_model_time_window_;
// Iteration number of federated learning, which is the number of interactions between client and server. // Iteration number of federated learning, which is the number of interactions between client and server.
uint64_t fl_iteration_num_; uint64_t fl_iteration_num_;
@ -206,12 +224,11 @@ class PSContext {
// Client training data batch size. Used in federated learning for now. // Client training data batch size. Used in federated learning for now.
uint64_t client_batch_size_; uint64_t client_batch_size_;
// Client training learning rate. Used in federated learning for now.
float client_learning_rate_;
// Whether to use secure aggregation algorithm. Used in federated learning for now. // Whether to use secure aggregation algorithm. Used in federated learning for now.
bool secure_aggregation_; bool secure_aggregation_;
// Whether there's a federated learning worker uploading weights to federated learning server. Used in hybrid training
// mode for now.
bool worker_upload_weights_;
}; };
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

View File

@ -56,9 +56,9 @@ using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel; using mindspore::kernel::CPUKernel;
using FBBuilder = flatbuffers::FlatBufferBuilder; using FBBuilder = flatbuffers::FlatBufferBuilder;
using TimeOutCb = std::function<void(void)>; using TimeOutCb = std::function<void(bool)>;
using StopTimerCb = std::function<void(void)>; using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(void)>; using FinishIterCb = std::function<void(bool)>;
using FinalizeCb = std::function<void(void)>; using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>; using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
@ -148,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32;
constexpr int kHttpSuccess = 200; constexpr int kHttpSuccess = 200;
constexpr auto kPBProtocol = "PB"; constexpr auto kPBProtocol = "PB";
constexpr auto kFBSProtocol = "FBS"; constexpr auto kFBSProtocol = "FBS";
constexpr auto kSuccess = "Success";
constexpr auto kFedAvg = "FedAvg"; constexpr auto kFedAvg = "FedAvg";
constexpr auto kAggregationKernelType = "Aggregation"; constexpr auto kAggregationKernelType = "Aggregation";
constexpr auto kOptimizerKernelType = "Optimizer"; constexpr auto kOptimizerKernelType = "Optimizer";
@ -155,6 +156,7 @@ constexpr auto kCtxFuncGraph = "FuncGraph";
constexpr auto kCtxIterNum = "iteration"; constexpr auto kCtxIterNum = "iteration";
constexpr auto kCtxDeviceMetas = "device_metas"; constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration"; constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list"; constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num"; constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
constexpr auto kCtxUpdateModelThld = "update_model_threshold"; constexpr auto kCtxUpdateModelThld = "update_model_threshold";

View File

@ -130,7 +130,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
void DistributedCountService::ResetCounter(const std::string &name) { void DistributedCountService::ResetCounter(const std::string &name) {
if (local_rank_ == counting_server_rank_) { if (local_rank_ == counting_server_rank_) {
MS_LOG(INFO) << "Leader server reset count for " << name; MS_LOG(DEBUG) << "Leader server reset count for " << name;
global_current_count_[name].clear(); global_current_count_[name].clear();
} }
return; return;
@ -233,7 +233,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::Mes
const auto &type = counter_event.type(); const auto &type = counter_event.type();
const auto &name = counter_event.name(); const auto &name = counter_event.name();
MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name; MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
if (type == CounterEventType::FIRST_CNT) { if (type == CounterEventType::FIRST_CNT) {
counter_handlers_[name].first_count_handler(message); counter_handlers_[name].first_count_handler(message);
} else if (type == CounterEventType::LAST_CNT) { } else if (type == CounterEventType::LAST_CNT) {
@ -259,7 +259,7 @@ void DistributedCountService::TriggerCounterEvent(const std::string &name) {
} }
void DistributedCountService::TriggerFirstCountEvent(const std::string &name) { void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
MS_LOG(INFO) << "Activating first count event for " << name; MS_LOG(DEBUG) << "Activating first count event for " << name;
CounterEvent first_count_event; CounterEvent first_count_event;
first_count_event.set_type(CounterEventType::FIRST_CNT); first_count_event.set_type(CounterEventType::FIRST_CNT);
first_count_event.set_name(name); first_count_event.set_name(name);

View File

@ -79,10 +79,10 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) {
return; return;
} }
void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) { bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
if (router_ == nullptr) { if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return; return false;
} }
uint32_t stored_rank = router_->Find(name); uint32_t stored_rank = router_->Find(name);
@ -90,18 +90,26 @@ void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
if (local_rank_ == stored_rank) { if (local_rank_ == stored_rank) {
if (!DoUpdateMetadata(name, meta)) { if (!DoUpdateMetadata(name, meta)) {
MS_LOG(ERROR) << "Updating meta data failed."; MS_LOG(ERROR) << "Updating meta data failed.";
return; return false;
} }
} else { } else {
PBMetadataWithName metadata_with_name; PBMetadataWithName metadata_with_name;
metadata_with_name.set_name(name); metadata_with_name.set_name(name);
*metadata_with_name.mutable_metadata() = meta; *metadata_with_name.mutable_metadata() = meta;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) { std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata,
&update_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed."; MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
return; return false;
}
std::string update_meta_rsp = reinterpret_cast<const char *>(update_meta_rsp_msg->data());
if (update_meta_rsp != kSuccess) {
MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed.";
return false;
} }
} }
return; return true;
} }
PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
@ -166,6 +174,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
std::string update_meta_rsp_msg; std::string update_meta_rsp_msg;
if (!DoUpdateMetadata(name, meta_with_name.metadata())) { if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
update_meta_rsp_msg = "Updating meta data failed."; update_meta_rsp_msg = "Updating meta data failed.";
MS_LOG(ERROR) << update_meta_rsp_msg;
} else { } else {
update_meta_rsp_msg = "Success"; update_meta_rsp_msg = "Success";
} }

View File

@ -52,7 +52,7 @@ class DistributedMetadataStore {
void ResetMetadata(const std::string &name); void ResetMetadata(const std::string &name);
// Update the metadata for the name. // Update the metadata for the name.
void UpdateMetadata(const std::string &name, const PBMetadata &meta); bool UpdateMetadata(const std::string &name, const PBMetadata &meta);
// Get the metadata for the name. // Get the metadata for the name.
PBMetadata GetMetadata(const std::string &name); PBMetadata GetMetadata(const std::string &name);

View File

@ -23,8 +23,6 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace server { namespace server {
Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); }
void Iteration::AddRound(const std::shared_ptr<Round> &round) { void Iteration::AddRound(const std::shared_ptr<Round> &round) {
MS_EXCEPTION_IF_NULL(round); MS_EXCEPTION_IF_NULL(round);
rounds_.push_back(round); rounds_.push_back(round);
@ -49,28 +47,48 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
// The time window for one iteration, which will be used in some round kernels. // The time window for one iteration, which will be used in some round kernels.
size_t iteration_time_window = size_t iteration_time_window =
std::accumulate(rounds_.begin(), rounds_.end(), 0, std::accumulate(rounds_.begin(), rounds_.end(), 0, [](size_t total, const std::shared_ptr<Round> &round) {
[](size_t total, const std::shared_ptr<Round> &round) { return total + round->time_window(); }); return round->check_timeout() ? total + round->time_window() : total;
});
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window); LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;
return; return;
} }
void Iteration::ProceedToNextIter() { void Iteration::ProceedToNextIter(bool is_iteration_valid) {
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num(); iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
// Store the model for each iteration. if (is_iteration_valid) {
const auto &model = Executor::GetInstance().GetModel(); // Store the model which is successfully aggregated for this iteration.
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else {
// Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid.";
}
for (auto &round : rounds_) { for (auto &round : rounds_) {
round->Reset(); round->Reset();
} }
iteration_num_++; iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore.
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed.";
iteration_num_ = 1;
ModelStore::GetInstance().Reset();
}
is_last_iteration_valid_ = is_iteration_valid;
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n"; MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
} }
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; } const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }
} // namespace server } // namespace server
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

View File

@ -31,8 +31,10 @@ namespace server {
// Rounds, only after all the rounds are finished, this iteration is considered as completed. // Rounds, only after all the rounds are finished, this iteration is considered as completed.
class Iteration { class Iteration {
public: public:
Iteration(); static Iteration &GetInstance() {
~Iteration() = default; static Iteration instance;
return instance;
}
// Add a round for the iteration. This method will be called multiple times for each round. // Add a round for the iteration. This method will be called multiple times for each round.
void AddRound(const std::shared_ptr<Round> &round); void AddRound(const std::shared_ptr<Round> &round);
@ -41,16 +43,29 @@ class Iteration {
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb); const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
// The server proceeds to the next iteration only after the last iteration finishes. // The server proceeds to the next iteration only after the last round finishes or the timer expires.
void ProceedToNextIter(); // If the timer expires, we consider this iteration as invalid.
void ProceedToNextIter(bool is_iteration_valid);
const std::vector<std::shared_ptr<Round>> &rounds(); const std::vector<std::shared_ptr<Round>> &rounds();
bool is_last_iteration_valid() const;
private: private:
Iteration() : iteration_num_(1), is_last_iteration_valid_(true) {
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
~Iteration() = default;
Iteration(const Iteration &) = delete;
Iteration &operator=(const Iteration &) = delete;
std::vector<std::shared_ptr<Round>> rounds_; std::vector<std::shared_ptr<Round>> rounds_;
// Server's current iteration number. // Server's current iteration number.
size_t iteration_num_; size_t iteration_num_;
// Last iteration is successfully finished.
bool is_last_iteration_valid_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace ps

View File

@ -29,7 +29,7 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
monitor_thread_ = std::thread([&]() { monitor_thread_ = std::thread([&]() {
while (running_.load()) { while (running_.load()) {
if (CURRENT_TIME_MILLI > end_time_) { if (CURRENT_TIME_MILLI > end_time_) {
timeout_callback_(); timeout_callback_(false);
running_ = false; running_ = false;
} }
// The time tick is 1 millisecond. // The time tick is 1 millisecond.

View File

@ -47,6 +47,7 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
} }
void GenerateReuseKernelNodeInfo() override { void GenerateReuseKernelNodeInfo() override {
MS_LOG(INFO) << "FedAvg reuse 'weight', 'accumulation', 'learning rate' and 'momentum' of the kernel node.";
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0)); reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0));
reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 1)); reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 1));
reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2)); reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2));

View File

@ -92,7 +92,6 @@ class FedAvgKernel : public AggregationKernel {
weight_addr[i] /= data_size_addr[0]; weight_addr[i] /= data_size_addr[0];
} }
done_ = true; done_ = true;
DistributedCountService::GetInstance().ResetCounter(name_);
return; return;
}; };
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler}); DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
@ -125,6 +124,7 @@ class FedAvgKernel : public AggregationKernel {
participated_ = true; participated_ = true;
DistributedCountService::GetInstance().Count( DistributedCountService::GetInstance().Count(
name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_)); name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
GenerateReuseKernelNodeInfo();
return true; return true;
} }
@ -149,6 +149,7 @@ class FedAvgKernel : public AggregationKernel {
private: private:
void GenerateReuseKernelNodeInfo() override { void GenerateReuseKernelNodeInfo() override {
MS_LOG(INFO) << "FedAvg reuse 'weight' of the kernel node.";
// Only the trainable parameter is reused for federated average. // Only the trainable parameter is reused for federated average.
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_)); reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
return; return;

View File

@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "ps/server/iteration.h"
#include "ps/server/model_store.h" #include "ps/server/model_store.h"
namespace mindspore { namespace mindspore {
@ -67,27 +68,31 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model(); const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
size_t latest_iter_num = iter_to_model.rbegin()->first; size_t latest_iter_num = iter_to_model.rbegin()->first;
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) { if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter); std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps, BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason; MS_LOG(WARNING) << reason;
return; return;
} }
if (iter_to_model.count(get_model_iter) == 0) { if (iter_to_model.count(get_model_iter) == 0) {
std::string reason = "The iteration of GetModel request" + std::to_string(get_model_iter) + // If the model of get_model_iter is not stored, return the latest version of model and current iteration number.
" is invalid. Current iteration is " + std::to_string(current_iter); MS_LOG(WARNING) << "The iteration of GetModel request " << std::to_string(get_model_iter)
BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps, << " is invalid. Current iteration is " << std::to_string(current_iter);
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); feature_maps = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
MS_LOG(ERROR) << reason; } else {
return; feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
} }
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter); // If the iteration of this model is invalid, return ResponseCode_OutOfTime to the clients could startFLJob according
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, // to next_req_time.
"Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter, auto response_code =
feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); Iteration::GetInstance().is_last_iteration_valid() ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
BuildGetModelRsp(fbb, response_code, "Get model for iteration " + std::to_string(get_model_iter), current_iter,
feature_maps,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
return; return;
} }

View File

@ -68,7 +68,7 @@ void RoundKernel::StopTimer() {
void RoundKernel::FinishIteration() { void RoundKernel::FinishIteration() {
if (finish_iteration_cb_) { if (finish_iteration_cb_) {
finish_iteration_cb_(); finish_iteration_cb_(true);
} }
return; return;
} }

View File

@ -61,15 +61,12 @@ class RoundKernel : virtual public CPUKernel {
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) = 0; const std::vector<AddressPtr> &outputs) = 0;
// The callbacks when first message and last message for this round kernel is received.
// These methods is called by class DistributedCountService and triggered by leader server(Rank 0).
// virtual void OnFirstCountEvent(std::shared_ptr<core::MessageHandler> message);
// virtual void OnLastCnt(std::shared_ptr<core::MessageHandler> message);
// Some rounds could be stateful in a iteration. Reset method resets the status of this round. // Some rounds could be stateful in a iteration. Reset method resets the status of this round.
virtual bool Reset() = 0; virtual bool Reset() = 0;
// The counter event handlers for DistributedCountService. // The counter event handlers for DistributedCountService.
// The callbacks when first message and last message for this round kernel is received.
// These methods is called by class DistributedCountService and triggered by counting server.
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message); virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message); virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);

View File

@ -25,9 +25,12 @@ namespace ps {
namespace server { namespace server {
namespace kernel { namespace kernel {
void StartFLJobKernel::InitKernel(size_t) { void StartFLJobKernel::InitKernel(size_t) {
// The time window of one iteration should be started at the first message of startFLJob round.
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
} }
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
executor_ = &Executor::GetInstance(); executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_); MS_EXCEPTION_IF_NULL(executor_);
@ -85,11 +88,17 @@ bool StartFLJobKernel::Reset() {
return true; return true;
} }
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
}
bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) { bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later."; std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, BuildStartFLJobRsp(
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return true; return true;
} }
@ -117,8 +126,9 @@ bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
ret = false; ret = false;
} }
if (!ret) { if (!ret) {
BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false, BuildStartFLJobRsp(
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); fbb, schema::ResponseCode_NotSelected, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
} }
return ret; return ret;
@ -128,8 +138,9 @@ bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) { const schema::RequestFLJob *start_fl_job_req) {
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) { if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
std::string reason = "startFLJob counting failed."; std::string reason = "startFLJob counting failed.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, BuildStartFLJobRsp(
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
@ -139,11 +150,18 @@ bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) { void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
PBMetadata metadata; PBMetadata metadata;
*metadata.mutable_device_meta() = device_meta; *metadata.mutable_device_meta() = device_meta;
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata); if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) {
std::string reason = "Updating device metadata failed.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_SystemError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
{});
return;
}
std::map<std::string, AddressPtr> feature_maps = executor_->GetModel(); std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true, BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps); std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
feature_maps);
return; return;
} }
@ -153,13 +171,16 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
std::map<std::string, AddressPtr> feature_maps) { std::map<std::string, AddressPtr> feature_maps) {
auto fbs_reason = fbb->CreateString(reason); auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time); auto fbs_next_req_time = fbb->CreateString(next_req_time);
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name()); auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name); fl_plan_builder.add_fl_name(fbs_fl_name);
fl_plan_builder.add_server_mode(fbs_server_mode);
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num()); fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num()); fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size()); fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
auto fbs_fl_plan = fl_plan_builder.Finish(); auto fbs_fl_plan = fl_plan_builder.Finish();
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps; std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;

View File

@ -32,7 +32,7 @@ namespace server {
namespace kernel { namespace kernel {
class StartFLJobKernel : public RoundKernel { class StartFLJobKernel : public RoundKernel {
public: public:
StartFLJobKernel() = default; StartFLJobKernel() : executor_(nullptr), iteration_time_window_(0), iter_next_req_timestamp_(0) {}
~StartFLJobKernel() override = default; ~StartFLJobKernel() override = default;
void InitKernel(size_t threshold_count) override; void InitKernel(size_t threshold_count) override;
@ -40,6 +40,8 @@ class StartFLJobKernel : public RoundKernel {
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
bool Reset() override; bool Reset() override;
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
private: private:
// Returns whether the startFLJob count of this iteration has reached the threshold. // Returns whether the startFLJob count of this iteration has reached the threshold.
bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb); bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
@ -66,6 +68,9 @@ class StartFLJobKernel : public RoundKernel {
// The time window of one iteration. // The time window of one iteration.
size_t iteration_time_window_; size_t iteration_time_window_;
// Timestamp of next request time for this iteration.
uint64_t iter_next_req_timestamp_;
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -39,6 +39,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
PBMetadata client_list; PBMetadata client_list;
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list); DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count); LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
} }
bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@ -103,8 +104,9 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) { bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for updateModel is enough."; std::string reason = "Current amount for updateModel is enough.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, BuildUpdateModelRsp(
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
@ -117,8 +119,9 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()); ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, BuildUpdateModelRsp(
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
@ -128,14 +131,24 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
std::string update_model_fl_id = update_model_req->fl_id()->str(); std::string update_model_fl_id = update_model_req->fl_id()->str();
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) { if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set."; std::string reason = "devices_meta for " + update_model_fl_id + " is not set.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, BuildUpdateModelRsp(
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size(); size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
auto feature_map = ParseFeatureMap(update_model_req); auto feature_map = ParseFeatureMap(update_model_req);
if (feature_map.empty()) {
std::string reason = "Feature map is empty.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_RequestError, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
for (auto weight : feature_map) { for (auto weight : feature_map) {
weight.second[kNewDataSize].addr = &data_size; weight.second[kNewDataSize].addr = &data_size;
weight.second[kNewDataSize].size = sizeof(size_t); weight.second[kNewDataSize].size = sizeof(size_t);
@ -146,10 +159,17 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
fl_id.set_fl_id(update_model_fl_id); fl_id.set_fl_id(update_model_fl_id);
PBMetadata comm_value; PBMetadata comm_value;
*comm_value.mutable_fl_id() = fl_id; *comm_value.mutable_fl_id() = fl_id;
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value); if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) {
std::string reason = "Updating metadata of UpdateModelClientList failed.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_SystemError, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready", BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
return true; return true;
} }
@ -174,8 +194,9 @@ bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fb
const schema::RequestUpdateModel *update_model_req) { const schema::RequestUpdateModel *update_model_req) {
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) { if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
std::string reason = "UpdateModel counting failed."; std::string reason = "UpdateModel counting failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, BuildUpdateModelRsp(
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }

View File

@ -30,6 +30,9 @@ namespace mindspore {
namespace ps { namespace ps {
namespace server { namespace server {
namespace kernel { namespace kernel {
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
constexpr uint64_t kInitialDataSizeSum = 0;
class UpdateModelKernel : public RoundKernel { class UpdateModelKernel : public RoundKernel {
public: public:
UpdateModelKernel() = default; UpdateModelKernel() = default;

View File

@ -30,7 +30,8 @@ void ModelStore::Initialize(uint32_t max_count) {
} }
max_model_count_ = max_count; max_model_count_ = max_count;
iteration_to_model_[kInitIterationNum] = AssignNewModelMemory(); initial_model_ = AssignNewModelMemory();
iteration_to_model_[kInitIterationNum] = initial_model_;
model_size_ = ComputeModelSize(); model_size_ = ComputeModelSize();
} }
@ -52,7 +53,6 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
MS_LOG(ERROR) << "Memory for the new model is nullptr."; MS_LOG(ERROR) << "Memory for the new model is nullptr.";
return false; return false;
} }
iteration_to_model_[iteration] = memory_register; iteration_to_model_[iteration] = memory_register;
} else { } else {
// If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model. // If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
@ -97,6 +97,12 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
return model; return model;
} }
void ModelStore::Reset() {
initial_model_ = iteration_to_model_.rbegin()->second;
iteration_to_model_.clear();
iteration_to_model_[kInitIterationNum] = initial_model_;
}
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const { const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
return iteration_to_model_; return iteration_to_model_;
} }
@ -121,6 +127,14 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
return nullptr; return nullptr;
} }
auto src_data_size = weight_size;
auto dst_data_size = weight_size;
int ret = memcpy_s(weight_data.get(), dst_data_size, weight.second->addr, src_data_size);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return nullptr;
}
memory_register->RegisterArray(weight_name, &weight_data, weight_size); memory_register->RegisterArray(weight_name, &weight_data, weight_size);
} }
return memory_register; return memory_register;

View File

@ -49,6 +49,9 @@ class ModelStore {
// Get model of the given iteration. // Get model of the given iteration.
std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration); std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);
// Reset the stored models. Called when federated learning job finishes.
void Reset();
// Returns all models stored in ModelStore. // Returns all models stored in ModelStore.
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const; const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
@ -70,6 +73,11 @@ class ModelStore {
size_t max_model_count_; size_t max_model_count_;
size_t model_size_; size_t model_size_;
// Initial model which is the model of iteration 0.
std::shared_ptr<MemoryRegister> initial_model_;
// The number of all models stpred is max_model_count_.
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_; std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
}; };
} // namespace server } // namespace server

View File

@ -38,9 +38,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); }); name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
// Callback when the iteration is finished. // Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void { finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid) -> void {
MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration."; MS_LOG(INFO) << "Round " << name_ << " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(); finish_iteration_cb(is_iteration_valid);
}; };
// Callback for finalizing the server. This can only be called once. // Callback for finalizing the server. This can only be called once.
@ -50,9 +50,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
iter_timer_ = std::make_shared<IterationTimer>(); iter_timer_ = std::make_shared<IterationTimer>();
// 1.Set the timeout callback for the timer. // 1.Set the timeout callback for the timer.
iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void { iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid) -> void {
MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration."; MS_LOG(INFO) << "Round " << name_ << " timeout! This iteration is invalid. Proceed to next iteration.";
timeout_cb(); timeout_cb(is_iteration_valid);
}); });
// 2.Stopping timer callback which will be set to the round kernel. // 2.Stopping timer callback which will be set to the round kernel.
@ -112,14 +112,19 @@ const std::string &Round::name() const { return name_; }
size_t Round::threshold_count() const { return threshold_count_; } size_t Round::threshold_count() const { return threshold_count_; }
bool Round::check_timeout() const { return check_timeout_; }
size_t Round::time_window() const { return time_window_; } size_t Round::time_window() const { return time_window_; }
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) { void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered."; MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
// The timer starts only after the first count event is triggered by DistributedCountService. // The timer starts only after the first count event is triggered by DistributedCountService.
if (check_timeout_) { if (check_timeout_) {
iter_timer_->Start(std::chrono::milliseconds(time_window_)); iter_timer_->Start(std::chrono::milliseconds(time_window_));
} }
// Some kernels override the OnFirstCountEvent method.
kernel_->OnFirstCountEvent(message);
return; return;
} }

View File

@ -52,6 +52,7 @@ class Round {
const std::string &name() const; const std::string &name() const;
size_t threshold_count() const; size_t threshold_count() const;
bool check_timeout() const;
size_t time_window() const; size_t time_window() const;
private: private:

View File

@ -174,21 +174,22 @@ bool Server::InitCommunicatorWithWorker() {
} }
void Server::InitIteration() { void Server::InitIteration() {
iteration_ = std::make_shared<Iteration>(); iteration_ = &Iteration::GetInstance();
MS_EXCEPTION_IF_NULL(iteration_); MS_EXCEPTION_IF_NULL(iteration_);
// 1.Add rounds to the iteration according to the server mode. // 1.Add rounds to the iteration according to the server mode.
for (const RoundConfig &config : rounds_config_) { for (const RoundConfig &config : rounds_config_) {
std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window, std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window,
config.check_count, config.threshold_count); config.check_count, config.threshold_count);
MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
<< ", threshold:" << config.threshold_count; << ", time window: " << config.time_window << ", check_count: " << config.check_count
<< ", threshold: " << config.threshold_count;
iteration_->AddRound(round); iteration_->AddRound(round);
} }
// 2.Initialize all the rounds. // 2.Initialize all the rounds.
TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_); TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_); FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb); iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
return; return;
} }

View File

@ -117,7 +117,7 @@ class Server {
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_; std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
// Iteration consists of multiple kinds of rounds. // Iteration consists of multiple kinds of rounds.
std::shared_ptr<Iteration> iteration_; Iteration *iteration_;
// Variables set by ps context. // Variables set by ps context.
std::string scheduler_ip_; std::string scheduler_ip_;

View File

@ -787,3 +787,59 @@ def reset_ps_context():
- enable_ps: False. - enable_ps: False.
""" """
_reset_ps_context() _reset_ps_context()
def set_fl_context(**kwargs):
"""
Set federated learning training mode context.
Args:
enable_fl (bool): Whether to enable federated learning training mode.
Default: False.
server_mode (string): Describe the server mode, which must one of 'FEDERATED_LEARNING' and 'HYBRID_TRAINING'.
Default: 'FEDERATED_LEARNING'.
ms_role (string): The process's role in the federated learning mode,
which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
Default: 'MS_NOT_PS'.
worker_num (int): The number of workers. Default: 0.
server_num (int): The number of federated learning servers. Default: 0.
scheduler_ip (string): The scheduler IP. Default: ''.
scheduler_port (int): The scheduler port. Default: 0.
fl_server_port (int): The http port of the federated learning server.
Normally for each server this should be set to the same value. Default: 0.
enable_fl_client (bool): Whether this process is federated learning client. Default: False.
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0.
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
update_model_ratio (float): The ratio for computing the threshold count of updateModel
which will be multiplied by start_fl_job_threshold. Default: 1.0.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
fl_name (string): The federated learning job name. Default: ''.
fl_iteration_num (int): Iteration number of federeated learning,
which is the number of interactions between client and server. Default: 20.
client_epoch_num (int): Client training epoch number. Default: 25.
client_batch_size (int): Client training data batch size. Default: 32.
client_learning_rate (float): Client training learning rate. Default: 0.001.
secure_aggregation (bool): Whether to use secure aggregation algorithm. Default: False.
Raises:
ValueError: If input key is not the attribute in federated learning mode context.
Examples:
>>> context.set_fl_context(enable_fl=True, server_mode='FEDERATED_LEARNING')
"""
_set_ps_context(**kwargs)
def get_fl_context(attr_key):
"""
Get federated learning mode context attribute value according to the key.
Args:
attr_key (str): The key of the attribute.
Returns:
Returns attribute value according to the key.
Raises:
ValueError: If input key is not attribute in federated learning mode context.
"""
return _get_ps_context(attr_key)

View File

@ -19,6 +19,8 @@
#include <unistd.h> #include <unistd.h>
#include <sys/time.h> #include <sys/time.h>
#include <map> #include <map>
#include <iomanip>
#include <thread>
// namespace to support utils module definition // namespace to support utils module definition
namespace mindspore { namespace mindspore {
@ -117,8 +119,8 @@ void LogWriter::OutputLog(const std::ostringstream &msg) const {
#define google mindspore_private #define google mindspore_private
auto submodule_name = GetSubModuleName(submodule_); auto submodule_name = GetSubModuleName(submodule_);
google::LogMessage("", 0, GetGlogLevel(log_level_)).stream() google::LogMessage("", 0, GetGlogLevel(log_level_)).stream()
<< "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << GetProcName() << "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << std::hex
<< "):" << GetTimeString() << " " << std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " "
<< "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl; << "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl;
#undef google #undef google
#else #else

View File

@ -36,6 +36,7 @@ _set_ps_context_func_map = {
"server_mode": ps_context().set_server_mode, "server_mode": ps_context().set_server_mode,
"ms_role": ps_context().set_ms_role, "ms_role": ps_context().set_ms_role,
"enable_ps": ps_context().set_ps_enable, "enable_ps": ps_context().set_ps_enable,
"enable_fl": ps_context().set_ps_enable,
"worker_num": ps_context().set_worker_num, "worker_num": ps_context().set_worker_num,
"server_num": ps_context().set_server_num, "server_num": ps_context().set_server_num,
"scheduler_ip": ps_context().set_scheduler_ip, "scheduler_ip": ps_context().set_scheduler_ip,
@ -43,10 +44,14 @@ _set_ps_context_func_map = {
"fl_server_port": ps_context().set_fl_server_port, "fl_server_port": ps_context().set_fl_server_port,
"enable_fl_client": ps_context().set_fl_client_enable, "enable_fl_client": ps_context().set_fl_client_enable,
"start_fl_job_threshold": ps_context().set_start_fl_job_threshold, "start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
"start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
"update_model_ratio": ps_context().set_update_model_ratio,
"update_model_time_window": ps_context().set_update_model_time_window,
"fl_name": ps_context().set_fl_name, "fl_name": ps_context().set_fl_name,
"fl_iteration_num": ps_context().set_fl_iteration_num, "fl_iteration_num": ps_context().set_fl_iteration_num,
"client_epoch_num": ps_context().set_client_epoch_num, "client_epoch_num": ps_context().set_client_epoch_num,
"client_batch_size": ps_context().set_client_batch_size, "client_batch_size": ps_context().set_client_batch_size,
"client_learning_rate": ps_context().set_client_learning_rate,
"secure_aggregation": ps_context().set_secure_aggregation, "secure_aggregation": ps_context().set_secure_aggregation,
"enable_ps_ssl": ps_context().set_enable_ssl "enable_ps_ssl": ps_context().set_enable_ssl
} }

View File

@ -69,6 +69,7 @@ table ResponseFLJob {
} }
table FLPlan { table FLPlan {
server_mode:string;
fl_name:string; fl_name:string;
iterations:int; iterations:int;
epochs:int; epochs:int;

View File

@ -26,10 +26,14 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113) parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666) parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1) parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet") parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20) parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32) parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False) parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
parser.add_argument("--local_server_num", type=int, default=-1) parser.add_argument("--local_server_num", type=int, default=-1)
@ -43,10 +47,14 @@ if __name__ == "__main__":
scheduler_port = args.scheduler_port scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
secure_aggregation = args.secure_aggregation secure_aggregation = args.secure_aggregation
local_server_num = args.local_server_num local_server_num = args.local_server_num
@ -70,10 +78,14 @@ if __name__ == "__main__":
cmd_server += " --scheduler_port=" + str(scheduler_port) cmd_server += " --scheduler_port=" + str(scheduler_port)
cmd_server += " --fl_server_port=" + str(fl_server_port + i) cmd_server += " --fl_server_port=" + str(fl_server_port + i)
cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold) cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
cmd_server += " --fl_name=" + fl_name cmd_server += " --fl_name=" + fl_name
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num) cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_server += " --client_epoch_num=" + str(client_epoch_num) cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size) cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --secure_aggregation=" + str(secure_aggregation) cmd_server += " --secure_aggregation=" + str(secure_aggregation)
cmd_server += " > server.log 2>&1 &" cmd_server += " > server.log 2>&1 &"

View File

@ -15,6 +15,7 @@
import argparse import argparse
import time import time
import datetime
import random import random
import sys import sys
import requests import requests
@ -129,7 +130,15 @@ def build_get_model(iteration):
buf = builder_get_model.Output() buf = builder_get_model.Output()
return buf return buf
weight_name_to_idx = { def datetime_to_timestamp(datetime_obj):
"""将本地(local) datetime 格式的时间 (含毫秒) 转为毫秒时间戳
:param datetime_obj: {datetime}2016-02-25 20:21:04.242000
:return: 13 位的毫秒时间戳 1456402864242
"""
local_timestamp = time.mktime(datetime_obj.timetuple()) * 1000.0 + datetime_obj.microsecond // 1000.0
return local_timestamp
weight_to_idx = {
"conv1.weight": 0, "conv1.weight": 0,
"conv2.weight": 1, "conv2.weight": 1,
"fc1.weight": 2, "fc1.weight": 2,
@ -149,11 +158,12 @@ while True:
print("start url is ", url1) print("start url is ", url1)
x = requests.post(url1, data=build_start_fl_job(current_iteration)) x = requests.post(url1, data=build_start_fl_job(current_iteration))
rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0) rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
print("start fl job iteration:", current_iteration, ", id:", args.pid)
while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED: while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
x = requests.post(url1, data=build_start_fl_job(current_iteration)) x = requests.post(url1, data=build_start_fl_job(current_iteration))
rsp_fl_job = rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0) rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
print("epoch is", rsp_fl_job.FlPlanConfig().Epochs()) print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
print("iteration is", rsp_fl_job.Iteration())
current_iteration = rsp_fl_job.Iteration()
sys.stdout.flush() sys.stdout.flush()
url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel' url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
@ -170,22 +180,40 @@ while True:
print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode()) print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
sys.stdout.flush() sys.stdout.flush()
repeat_time = 0 next_req_timestamp = 0
while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady: if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
time.sleep(0.1) next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
x = session.post(url3, data=build_get_model(current_iteration)) print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
repeat_time += 1
if repeat_time > 1000:
print("GetModel try timeout ", args.pid)
sys.exit(0)
for i in range(0, 1):
print(rsp_get_model.FeatureMap(i).WeightFullname())
origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
print("Before update model", args.pid, origin[0:10])
print("After get model", args.pid, after[0:10])
sys.stdout.flush() sys.stdout.flush()
assert np.allclose(origin, after, rtol=1e-05, atol=1e-05) elif rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
current_iteration += 1 repeat_time = 0
while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
time.sleep(0.2)
x = session.post(url3, data=build_get_model(current_iteration))
rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
sys.stdout.flush()
break
repeat_time += 1
if repeat_time > 1000:
print("GetModel try timeout ", args.pid)
sys.exit(0)
else:
pass
if next_req_timestamp == 0:
for i in range(0, 1):
print(rsp_get_model.FeatureMap(i).WeightFullname())
origin = update_model_np_data[weight_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
print("Before update model", args.pid, origin[0:10])
print("After get model", args.pid, after[0:10])
sys.stdout.flush()
assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
else:
# Sleep to the next request timestamp
current_ts = datetime_to_timestamp(datetime.datetime.now())
duration = next_req_timestamp - current_ts
time.sleep(duration / 1000)

View File

@ -34,10 +34,14 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113) parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666) parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1) parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet") parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20) parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32) parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False) parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
@ -50,14 +54,18 @@ scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
secure_aggregation = args.secure_aggregation secure_aggregation = args.secure_aggregation
ctx = { ctx = {
"enable_ps": False, "enable_fl": True,
"server_mode": server_mode, "server_mode": server_mode,
"ms_role": ms_role, "ms_role": ms_role,
"worker_num": worker_num, "worker_num": worker_num,
@ -66,15 +74,19 @@ ctx = {
"scheduler_port": scheduler_port, "scheduler_port": scheduler_port,
"fl_server_port": fl_server_port, "fl_server_port": fl_server_port,
"start_fl_job_threshold": start_fl_job_threshold, "start_fl_job_threshold": start_fl_job_threshold,
"start_fl_job_time_window": start_fl_job_time_window,
"update_model_ratio": update_model_ratio,
"update_model_time_window": update_model_time_window,
"fl_name": fl_name, "fl_name": fl_name,
"fl_iteration_num": fl_iteration_num, "fl_iteration_num": fl_iteration_num,
"client_epoch_num": client_epoch_num, "client_epoch_num": client_epoch_num,
"client_batch_size": client_batch_size, "client_batch_size": client_batch_size,
"client_learning_rate": client_learning_rate,
"secure_aggregation": secure_aggregation "secure_aggregation": secure_aggregation
} }
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_ps_context(**ctx) context.set_fl_context(**ctx)
if __name__ == "__main__": if __name__ == "__main__":
epoch = 5 epoch = 5