forked from mindspore-Ecosystem/mindspore
!16230 Optimize server reliability in multiple scenarios.
From: @zpac Reviewed-by: @cristoval,@limingqi107 Signed-off-by: @limingqi107
commit 94ca479fbe
@@ -629,14 +629,17 @@ bool StartServerAction(const ResourcePtr &res) {
   uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();

   // Update model threshold is a certain ratio of start_fl_job threshold.
-  // update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_.
+  // update_model_threshold = start_fl_job_threshold * update_model_ratio.
   size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
-  float percent_for_update_model = 1;
-  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));
+  float update_model_ratio = ps::PSContext::instance()->update_model_ratio();
+  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
+  uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
+  uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();

-  std::vector<ps::server::RoundConfig> rounds_config = {{"startFLJob", false, 3000, true, start_fl_job_threshold},
-                                                        {"updateModel", false, 3000, true, update_model_threshold},
-                                                        {"getModel", false, 3000}};
+  std::vector<ps::server::RoundConfig> rounds_config = {
+    {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
+    {"updateModel", true, update_model_time_window, true, update_model_threshold},
+    {"getModel"}};

   size_t executor_threshold = 0;
   if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
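The update-model threshold is no longer hard-coded: it is derived from the startFLJob threshold by the configurable ratio. A standalone sketch of the formula (the sample values 100 and 0.75 are illustrative, not from this commit):

#include <cmath>
#include <cstddef>
#include <iostream>

int main() {
  // 100 clients may start a job; 75% of them must upload an update.
  size_t start_fl_job_threshold = 100;
  float update_model_ratio = 0.75f;
  // Same formula as in StartServerAction above: round up so that a non-zero
  // start threshold always yields a non-zero update threshold.
  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
  std::cout << update_model_threshold << std::endl;  // prints 75
  return 0;
}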
@@ -345,11 +345,20 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
     .def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
     .def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
-    .def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.")
+    .def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold,
+         "Set threshold count for startFLJob round.")
+    .def("set_start_fl_job_time_window", &PSContext::set_start_fl_job_time_window,
+         "Set time window for startFLJob round.")
+    .def("set_update_model_ratio", &PSContext::set_update_model_ratio,
+         "Set threshold count ratio for updateModel round.")
+    .def("set_update_model_time_window", &PSContext::set_update_model_time_window,
+         "Set time window for updateModel round.")
     .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
     .def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
     .def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
     .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
+    .def("set_client_learning_rate", &PSContext::set_client_learning_rate,
+         "Set federated learning client learning rate.")
     .def("set_secure_aggregation", &PSContext::set_secure_aggregation,
          "Set federated learning client using secure aggregation.")
     .def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");
@@ -196,7 +196,14 @@ void PSContext::set_ms_role(const std::string &role) {
   role_ = role;
 }

-void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
+void PSContext::set_worker_num(uint32_t worker_num) {
+  // Hybrid training mode only supports one worker for now.
+  if (server_mode_ == kServerModeHybrid && worker_num != 1) {
+    MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
+    return;
+  }
+  worker_num_ = worker_num;
+}
 uint32_t PSContext::worker_num() const { return worker_num_; }

 void PSContext::set_server_num(uint32_t server_num) {
@@ -235,7 +242,7 @@ void PSContext::GenerateResetterRound() {
   }

   binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
-                          (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_upload_weights_ << 4);
+                          (is_mixed_training_mode << 2) | (secure_aggregation_ << 3);
   if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
     resetter_round_ = ResetterRound::kNoNeedToReset;
   } else {
@@ -255,11 +262,27 @@ void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled

 bool PSContext::fl_client_enable() { return fl_client_enable_; }

-void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) {
+void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
   start_fl_job_threshold_ = start_fl_job_threshold;
 }

-size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
+uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }

+void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
+  start_fl_job_time_window_ = start_fl_job_time_window;
+}
+
+uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
+
+void PSContext::set_update_model_ratio(float update_model_ratio) { update_model_ratio_ = update_model_ratio; }
+
+float PSContext::update_model_ratio() const { return update_model_ratio_; }
+
+void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
+  update_model_time_window_ = update_model_time_window;
+}
+
+uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
+
 void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }

@@ -277,11 +300,9 @@ void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch

 uint64_t PSContext::client_batch_size() const { return client_batch_size_; }

-void PSContext::set_worker_upload_weights(uint64_t worker_upload_weights) {
-  worker_upload_weights_ = worker_upload_weights;
-}
+void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }

-uint64_t PSContext::worker_upload_weights() const { return worker_upload_weights_; }
+float PSContext::client_learning_rate() const { return client_learning_rate_; }

 void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }

@@ -41,15 +41,13 @@ constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
 // 1: Server is in federated learning mode.
 // 2: Server is in mixed training mode.
 // 3: Server enables sucure aggregation.
-// 4: Server needs worker to overwrite weights.
-// For example: 01010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
-enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerOverwriteWeights };
-const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {
-  {0b00010, ResetterRound::kUpdateModel},
-  {0b01010, ResetterRound::kReconstructSeccrets},
-  {0b11100, ResetterRound::kWorkerOverwriteWeights},
-  {0b10100, ResetterRound::kWorkerOverwriteWeights},
-  {0b00100, ResetterRound::kUpdateModel}};
+// For example: 1010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
+enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerUploadWeights };
+const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
+                                                                         {0b1010, ResetterRound::kReconstructSeccrets},
+                                                                         {0b1100, ResetterRound::kWorkerUploadWeights},
+                                                                         {0b0100, ResetterRound::kUpdateModel}};

 class PSContext {
  public:
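For reference, the lookup key is just the OR of the mode bits described in the comments above. A minimal sketch of how a key such as 0b1010 (federated learning + secure aggregation) is formed; the variable names mirror GenerateResetterRound but the snippet is self-contained:

#include <cstdint>
#include <iostream>

int main() {
  bool is_parameter_server_mode = false;
  bool is_federated_learning_mode = true;  // bit 1
  bool is_mixed_training_mode = false;
  bool secure_aggregation = true;          // bit 3
  uint32_t binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
                                   (is_mixed_training_mode << 2) | (secure_aggregation << 3);
  // 0b1010 == 10, which kServerContextToResetRoundMap maps to kReconstructSeccrets.
  std::cout << binary_server_context << std::endl;
  return 0;
}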
@@ -115,8 +113,17 @@ class PSContext {
   void set_fl_client_enable(bool enabled);
   bool fl_client_enable();

-  void set_start_fl_job_threshold(size_t start_fl_job_threshold);
-  size_t start_fl_job_threshold() const;
+  void set_start_fl_job_threshold(uint64_t start_fl_job_threshold);
+  uint64_t start_fl_job_threshold() const;

+  void set_start_fl_job_time_window(uint64_t start_fl_job_time_window);
+  uint64_t start_fl_job_time_window() const;
+
+  void set_update_model_ratio(float update_model_ratio);
+  float update_model_ratio() const;
+
+  void set_update_model_time_window(uint64_t update_model_time_window);
+  uint64_t update_model_time_window() const;
+
   void set_fl_name(const std::string &fl_name);
   const std::string &fl_name() const;
@@ -133,9 +140,8 @@ class PSContext {
   void set_client_batch_size(uint64_t client_batch_size);
   uint64_t client_batch_size() const;

-  // Set true if worker will overwrite weights on server. Used in hybrid training.
-  void set_worker_upload_weights(uint64_t worker_upload_weights);
-  uint64_t worker_upload_weights() const;
+  void set_client_learning_rate(float client_learning_rate);
+  float client_learning_rate() const;

   // Set true if using secure aggregation for federated learning.
   void set_secure_aggregation(bool secure_aggregation);
@@ -160,11 +166,14 @@ class PSContext {
       fl_client_enable_(false),
       fl_name_(""),
       start_fl_job_threshold_(0),
-      fl_iteration_num_(0),
-      client_epoch_num_(0),
-      client_batch_size_(0),
-      secure_aggregation_(false),
-      worker_upload_weights_(false) {}
+      start_fl_job_time_window_(3000),
+      update_model_ratio_(1.0),
+      update_model_time_window_(3000),
+      fl_iteration_num_(20),
+      client_epoch_num_(25),
+      client_batch_size_(32),
+      client_learning_rate_(0.001),
+      secure_aggregation_(false) {}
   bool ps_enabled_;
   bool is_worker_;
   bool is_pserver_;
@@ -195,7 +204,16 @@ class PSContext {
   std::string fl_name_;

   // The threshold count of startFLJob round. Used in federated learning for now.
-  size_t start_fl_job_threshold_;
+  uint64_t start_fl_job_threshold_;
+
+  // The time window of startFLJob round in millisecond.
+  uint64_t start_fl_job_time_window_;
+
+  // Update model threshold is a certain ratio of start_fl_job threshold which is set as update_model_ratio_.
+  float update_model_ratio_;
+
+  // The time window of updateModel round in millisecond.
+  uint64_t update_model_time_window_;

   // Iteration number of federeated learning, which is the number of interactions between client and server.
   uint64_t fl_iteration_num_;
@@ -206,12 +224,11 @@ class PSContext {
   // Client training data batch size. Used in federated learning for now.
   uint64_t client_batch_size_;

+  // Client training learning rate. Used in federated learning for now.
+  float client_learning_rate_;
+
   // Whether to use secure aggregation algorithm. Used in federated learning for now.
   bool secure_aggregation_;
-
-  // Whether there's a federated learning worker uploading weights to federated learning server. Used in hybrid training
-  // mode for now.
-  bool worker_upload_weights_;
 };
 } // namespace ps
 } // namespace mindspore
@@ -56,9 +56,9 @@ using mindspore::kernel::Address;
 using mindspore::kernel::AddressPtr;
 using mindspore::kernel::CPUKernel;
 using FBBuilder = flatbuffers::FlatBufferBuilder;
-using TimeOutCb = std::function<void(void)>;
+using TimeOutCb = std::function<void(bool)>;
 using StopTimerCb = std::function<void(void)>;
-using FinishIterCb = std::function<void(void)>;
+using FinishIterCb = std::function<void(bool)>;
 using FinalizeCb = std::function<void(void)>;
 using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;

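TimeOutCb and FinishIterCb now carry a bool, so the server can distinguish a normally finished iteration from one cut short by the timer. A hedged sketch of the two call sites (simplified wiring, not the actual server code):

#include <functional>
#include <iostream>

using FinishIterCb = std::function<void(bool)>;

int main() {
  // true: the last round finished normally; false: the iteration timer expired.
  FinishIterCb finish_iteration_cb = [](bool is_iteration_valid) {
    std::cout << (is_iteration_valid ? "valid iteration" : "invalid iteration") << std::endl;
  };
  finish_iteration_cb(true);   // as in RoundKernel::FinishIteration() below
  finish_iteration_cb(false);  // as in IterationTimer's timeout_callback_(false) below
  return 0;
}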
|
@ -148,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32;
|
||||||
constexpr int kHttpSuccess = 200;
|
constexpr int kHttpSuccess = 200;
|
||||||
constexpr auto kPBProtocol = "PB";
|
constexpr auto kPBProtocol = "PB";
|
||||||
constexpr auto kFBSProtocol = "FBS";
|
constexpr auto kFBSProtocol = "FBS";
|
||||||
|
constexpr auto kSuccess = "Success";
|
||||||
constexpr auto kFedAvg = "FedAvg";
|
constexpr auto kFedAvg = "FedAvg";
|
||||||
constexpr auto kAggregationKernelType = "Aggregation";
|
constexpr auto kAggregationKernelType = "Aggregation";
|
||||||
constexpr auto kOptimizerKernelType = "Optimizer";
|
constexpr auto kOptimizerKernelType = "Optimizer";
|
||||||
|
@@ -155,6 +156,7 @@ constexpr auto kCtxFuncGraph = "FuncGraph";
 constexpr auto kCtxIterNum = "iteration";
 constexpr auto kCtxDeviceMetas = "device_metas";
 constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
+constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
 constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
 constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
 constexpr auto kCtxUpdateModelThld = "update_model_threshold";
@@ -130,7 +130,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {

 void DistributedCountService::ResetCounter(const std::string &name) {
   if (local_rank_ == counting_server_rank_) {
-    MS_LOG(INFO) << "Leader server reset count for " << name;
+    MS_LOG(DEBUG) << "Leader server reset count for " << name;
     global_current_count_[name].clear();
   }
   return;
@@ -233,7 +233,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::Mes
   const auto &type = counter_event.type();
   const auto &name = counter_event.name();

-  MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
+  MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
   if (type == CounterEventType::FIRST_CNT) {
     counter_handlers_[name].first_count_handler(message);
   } else if (type == CounterEventType::LAST_CNT) {
@@ -259,7 +259,7 @@ void DistributedCountService::TriggerCounterEvent(const std::string &name) {
 }

 void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
-  MS_LOG(INFO) << "Activating first count event for " << name;
+  MS_LOG(DEBUG) << "Activating first count event for " << name;
   CounterEvent first_count_event;
   first_count_event.set_type(CounterEventType::FIRST_CNT);
   first_count_event.set_name(name);
@@ -79,10 +79,10 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) {
   return;
 }

-void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
+bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
   if (router_ == nullptr) {
     MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
-    return;
+    return false;
   }

   uint32_t stored_rank = router_->Find(name);
@@ -90,18 +90,26 @@ void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
   if (local_rank_ == stored_rank) {
     if (!DoUpdateMetadata(name, meta)) {
       MS_LOG(ERROR) << "Updating meta data failed.";
-      return;
+      return false;
     }
   } else {
     PBMetadataWithName metadata_with_name;
     metadata_with_name.set_name(name);
     *metadata_with_name.mutable_metadata() = meta;
-    if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) {
+    std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
+    if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata,
+                                      &update_meta_rsp_msg)) {
       MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
-      return;
+      return false;
+    }
+
+    std::string update_meta_rsp = reinterpret_cast<const char *>(update_meta_rsp_msg->data());
+    if (update_meta_rsp != kSuccess) {
+      MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed.";
+      return false;
     }
   }
-  return;
+  return true;
 }

 PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
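UpdateMetadata is now verified end to end: the remote server acknowledges with a status string and the sender treats anything other than kSuccess as failure. A toy model of that acknowledge protocol (the real transport is SendPbRequest; this stub only shows the control flow):

#include <iostream>
#include <string>

constexpr auto kSuccess = "Success";

// Stand-in for the remote HandleUpdateMetadataRequest handler.
std::string RemoteUpdate(bool ok) { return ok ? kSuccess : "Updating meta data failed."; }

bool UpdateMetadata(bool remote_ok) {
  std::string update_meta_rsp = RemoteUpdate(remote_ok);
  if (update_meta_rsp != kSuccess) {
    std::cerr << update_meta_rsp << std::endl;
    return false;
  }
  return true;
}

int main() {
  std::cout << UpdateMetadata(true) << " " << UpdateMetadata(false) << std::endl;  // prints 1 0
  return 0;
}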
@@ -166,6 +174,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
   std::string update_meta_rsp_msg;
   if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
     update_meta_rsp_msg = "Updating meta data failed.";
+    MS_LOG(ERROR) << update_meta_rsp_msg;
   } else {
     update_meta_rsp_msg = "Success";
   }
@@ -52,7 +52,7 @@ class DistributedMetadataStore {
   void ResetMetadata(const std::string &name);

   // Update the metadata for the name.
-  void UpdateMetadata(const std::string &name, const PBMetadata &meta);
+  bool UpdateMetadata(const std::string &name, const PBMetadata &meta);

   // Get the metadata for the name.
   PBMetadata GetMetadata(const std::string &name);
@@ -23,8 +23,6 @@
 namespace mindspore {
 namespace ps {
 namespace server {
-Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); }
-
 void Iteration::AddRound(const std::shared_ptr<Round> &round) {
   MS_EXCEPTION_IF_NULL(round);
   rounds_.push_back(round);
@@ -49,28 +47,48 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB

   // The time window for one iteration, which will be used in some round kernels.
   size_t iteration_time_window =
-    std::accumulate(rounds_.begin(), rounds_.end(), 0,
-                    [](size_t total, const std::shared_ptr<Round> &round) { return total + round->time_window(); });
+    std::accumulate(rounds_.begin(), rounds_.end(), 0, [](size_t total, const std::shared_ptr<Round> &round) {
+      return round->check_timeout() ? total + round->time_window() : total;
+    });
   LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
+  MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;
   return;
 }

-void Iteration::ProceedToNextIter() {
+void Iteration::ProceedToNextIter(bool is_iteration_valid) {
   iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
-  // Store the model for each iteration.
-  const auto &model = Executor::GetInstance().GetModel();
-  ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+  if (is_iteration_valid) {
+    // Store the model which is successfully aggregated for this iteration.
+    const auto &model = Executor::GetInstance().GetModel();
+    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+    MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
+  } else {
+    // Store last iteration's model because this iteration is considered as invalid.
+    const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
+    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+    MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid.";
+  }

   for (auto &round : rounds_) {
     round->Reset();
   }

   iteration_num_++;
+  // After the job is done, reset the iteration to the initial number and reset ModelStore.
+  if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
+    MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed.";
+    iteration_num_ = 1;
+    ModelStore::GetInstance().Reset();
+  }
+
+  is_last_iteration_valid_ = is_iteration_valid;
   LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
   MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
 }

 const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
+
+bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }
 } // namespace server
 } // namespace ps
 } // namespace mindspore
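The invalid-iteration branch keeps the stored model sequence gap-free: a failed iteration simply republishes the previous model under the new iteration number, so getModel never dangles. A toy version of that bookkeeping (std::map stands in for ModelStore):

#include <cstddef>
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<size_t, std::string> iteration_to_model = {{1, "model_v1"}};
  size_t iteration_num = 2;
  bool is_iteration_valid = false;  // e.g. the iteration timer expired
  if (is_iteration_valid) {
    iteration_to_model[iteration_num] = "model_v2";  // freshly aggregated model
  } else {
    // Roll the previous model forward so iteration 2 still has an entry.
    iteration_to_model[iteration_num] = iteration_to_model[iteration_num - 1];
  }
  std::cout << iteration_to_model[2] << std::endl;  // prints model_v1
  return 0;
}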
@@ -31,8 +31,10 @@ namespace server {
 // Rounds, only after all the rounds are finished, this iteration is considered as completed.
 class Iteration {
  public:
-  Iteration();
-  ~Iteration() = default;
+  static Iteration &GetInstance() {
+    static Iteration instance;
+    return instance;
+  }

   // Add a round for the iteration. This method will be called multiple times for each round.
   void AddRound(const std::shared_ptr<Round> &round);
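This is the classic Meyers singleton: the instance is constructed on first use and, since C++11, the initialization is thread-safe. The pattern in isolation (a hypothetical Foo, not the actual Iteration class):

class Foo {
 public:
  static Foo &GetInstance() {
    static Foo instance;  // constructed once, thread-safe since C++11
    return instance;
  }

 private:
  Foo() = default;
  ~Foo() = default;
  Foo(const Foo &) = delete;
  Foo &operator=(const Foo &) = delete;
};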
@@ -41,16 +43,29 @@ class Iteration {
   void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
                   const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);

-  // The server proceeds to the next iteration only after the last iteration finishes.
-  void ProceedToNextIter();
+  // The server proceeds to the next iteration only after the last round finishes or the timer expires.
+  // If the timer expires, we consider this iteration as invalid.
+  void ProceedToNextIter(bool is_iteration_valid);

   const std::vector<std::shared_ptr<Round>> &rounds();

+  bool is_last_iteration_valid() const;
+
  private:
+  Iteration() : iteration_num_(1), is_last_iteration_valid_(true) {
+    LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
+  }
+  ~Iteration() = default;
+  Iteration(const Iteration &) = delete;
+  Iteration &operator=(const Iteration &) = delete;
+
   std::vector<std::shared_ptr<Round>> rounds_;

   // Server's current iteration number.
   size_t iteration_num_;
+
+  // Last iteration is successfully finished.
+  bool is_last_iteration_valid_;
 };
 } // namespace server
 } // namespace ps
@@ -29,7 +29,7 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
   monitor_thread_ = std::thread([&]() {
     while (running_.load()) {
       if (CURRENT_TIME_MILLI > end_time_) {
-        timeout_callback_();
+        timeout_callback_(false);
         running_ = false;
       }
       // The time tick is 1 millisecond.
@@ -47,6 +47,7 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
   }

   void GenerateReuseKernelNodeInfo() override {
+    MS_LOG(INFO) << "FedAvg reuse 'weight', 'accumulation', 'learning rate' and 'momentum' of the kernel node.";
     reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0));
     reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 1));
     reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2));
@@ -92,7 +92,6 @@ class FedAvgKernel : public AggregationKernel {
         weight_addr[i] /= data_size_addr[0];
       }
       done_ = true;
-      DistributedCountService::GetInstance().ResetCounter(name_);
       return;
     };
     DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
@@ -125,6 +124,7 @@ class FedAvgKernel : public AggregationKernel {
     participated_ = true;
     DistributedCountService::GetInstance().Count(
       name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
+    GenerateReuseKernelNodeInfo();
     return true;
   }

@@ -149,6 +149,7 @@ class FedAvgKernel : public AggregationKernel {

  private:
   void GenerateReuseKernelNodeInfo() override {
+    MS_LOG(INFO) << "FedAvg reuse 'weight' of the kernel node.";
     // Only the trainable parameter is reused for federated average.
     reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
     return;
@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include "ps/server/iteration.h"
 #include "ps/server/model_store.h"

 namespace mindspore {
@@ -67,27 +68,31 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
   const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
   size_t latest_iter_num = iter_to_model.rbegin()->first;

+  // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
   if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
     std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
     BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
-                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+                     std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
     MS_LOG(WARNING) << reason;
     return;
   }

   if (iter_to_model.count(get_model_iter) == 0) {
-    std::string reason = "The iteration of GetModel request" + std::to_string(get_model_iter) +
-                         " is invalid. Current iteration is " + std::to_string(current_iter);
-    BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps,
-                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
-    MS_LOG(ERROR) << reason;
-    return;
+    // If the model of get_model_iter is not stored, return the latest version of model and current iteration number.
+    MS_LOG(WARNING) << "The iteration of GetModel request " << std::to_string(get_model_iter)
+                    << " is invalid. Current iteration is " << std::to_string(current_iter);
+    feature_maps = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
+  } else {
+    feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
   }

-  feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
-  BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED,
-                   "Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter,
-                   feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+  // If the iteration of this model is invalid, return ResponseCode_OutOfTime to the clients could startFLJob according
+  // to next_req_time.
+  auto response_code =
+    Iteration::GetInstance().is_last_iteration_valid() ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
+  BuildGetModelRsp(fbb, response_code, "Get model for iteration " + std::to_string(get_model_iter), current_iter,
+                   feature_maps,
+                   std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
   return;
 }

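getModel now degrades instead of erroring: an unknown iteration returns the latest stored model, and an invalid last iteration flips the response code to OutOfTime so clients go back to startFLJob at next_req_time. The response-code decision in isolation (names shortened from the schema):

#include <iostream>

enum class Code { SUCCEED, SucNotReady, OutOfTime };

// Mirrors the branch above: once the model is ready, validity of the last
// iteration alone chooses between SUCCEED and OutOfTime.
Code PickCode(bool model_ready, bool last_iteration_valid) {
  if (!model_ready) return Code::SucNotReady;
  return last_iteration_valid ? Code::SUCCEED : Code::OutOfTime;
}

int main() {
  std::cout << static_cast<int>(PickCode(true, false)) << std::endl;  // 2 == OutOfTime
  return 0;
}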
@@ -68,7 +68,7 @@ void RoundKernel::StopTimer() {

 void RoundKernel::FinishIteration() {
   if (finish_iteration_cb_) {
-    finish_iteration_cb_();
+    finish_iteration_cb_(true);
   }
   return;
 }
@@ -61,15 +61,12 @@ class RoundKernel : virtual public CPUKernel {
   virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                       const std::vector<AddressPtr> &outputs) = 0;

-  // The callbacks when first message and last message for this round kernel is received.
-  // These methods is called by class DistributedCountService and triggered by leader server(Rank 0).
-  // virtual void OnFirstCountEvent(std::shared_ptr<core::MessageHandler> message);
-  // virtual void OnLastCnt(std::shared_ptr<core::MessageHandler> message);
-
   // Some rounds could be stateful in a iteration. Reset method resets the status of this round.
   virtual bool Reset() = 0;

   // The counter event handlers for DistributedCountService.
+  // The callbacks when first message and last message for this round kernel is received.
+  // These methods is called by class DistributedCountService and triggered by counting server.
   virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
   virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);

@@ -25,9 +25,12 @@ namespace ps {
 namespace server {
 namespace kernel {
 void StartFLJobKernel::InitKernel(size_t) {
+  // The time window of one iteration should be started at the first message of startFLJob round.
   if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
     iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
   }
+  iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
+  LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);

   executor_ = &Executor::GetInstance();
   MS_EXCEPTION_IF_NULL(executor_);
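Every response now quotes one shared deadline instead of each kernel adding the window to its own clock reading, so all rounds of an iteration report the same next_req_time. A sketch of the computation (CURRENT_TIME_MILLI stands for the epoch-milliseconds helper used in the server):

#include <chrono>
#include <cstdint>
#include <iostream>

int main() {
  auto now = std::chrono::duration_cast<std::chrono::milliseconds>(
    std::chrono::system_clock::now().time_since_epoch());
  uint64_t iteration_time_window = 6000;  // illustrative: sum of the round time windows
  // Stored once under kCtxIterationNextRequestTimestamp and re-read by
  // startFLJob, updateModel and getModel when they build responses.
  uint64_t iter_next_req_timestamp = now.count() + iteration_time_window;
  std::cout << iter_next_req_timestamp << std::endl;
  return 0;
}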
@@ -85,11 +88,17 @@ bool StartFLJobKernel::Reset() {
   return true;
 }

+void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
+  iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
+  LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
+}
+
 bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
   if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
     std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
-    BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
-                       std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    BuildStartFLJobRsp(
+      fbb, schema::ResponseCode_OutOfTime, reason, false,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
     MS_LOG(ERROR) << reason;
     return true;
   }
@@ -117,8 +126,9 @@ bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
     ret = false;
   }
   if (!ret) {
-    BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false,
-                       std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    BuildStartFLJobRsp(
+      fbb, schema::ResponseCode_NotSelected, reason, false,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
     MS_LOG(ERROR) << reason;
   }
   return ret;
@@ -128,8 +138,9 @@ bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
                                           const schema::RequestFLJob *start_fl_job_req) {
   if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
     std::string reason = "startFLJob counting failed.";
-    BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
-                       std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    BuildStartFLJobRsp(
+      fbb, schema::ResponseCode_OutOfTime, reason, false,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
     MS_LOG(ERROR) << reason;
     return false;
   }
@@ -139,11 +150,18 @@ bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
 void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
   PBMetadata metadata;
   *metadata.mutable_device_meta() = device_meta;
-  DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata);
+  if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) {
+    std::string reason = "Updating device metadata failed.";
+    BuildStartFLJobRsp(fbb, schema::ResponseCode_SystemError, reason, false,
+                       std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
+                       {});
+    return;
+  }

   std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
   BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
-                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps);
+                     std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
+                     feature_maps);
   return;
 }

@@ -153,13 +171,16 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
                                           std::map<std::string, AddressPtr> feature_maps) {
   auto fbs_reason = fbb->CreateString(reason);
   auto fbs_next_req_time = fbb->CreateString(next_req_time);
+  auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
   auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());

   schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
   fl_plan_builder.add_fl_name(fbs_fl_name);
+  fl_plan_builder.add_server_mode(fbs_server_mode);
   fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
   fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
   fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
+  fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
   auto fbs_fl_plan = fl_plan_builder.Finish();

   std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
@@ -32,7 +32,7 @@ namespace server {
 namespace kernel {
 class StartFLJobKernel : public RoundKernel {
  public:
-  StartFLJobKernel() = default;
+  StartFLJobKernel() : executor_(nullptr), iteration_time_window_(0), iter_next_req_timestamp_(0) {}
   ~StartFLJobKernel() override = default;

   void InitKernel(size_t threshold_count) override;
@@ -40,6 +40,8 @@ class StartFLJobKernel : public RoundKernel {
                       const std::vector<AddressPtr> &outputs) override;
   bool Reset() override;

+  void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
+
  private:
   // Returns whether the startFLJob count of this iteration has reached the threshold.
   bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
|
||||||
|
|
||||||
// The time window of one iteration.
|
// The time window of one iteration.
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
|
|
||||||
|
// Timestamp of next request time for this iteration.
|
||||||
|
uint64_t iter_next_req_timestamp_;
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@@ -39,6 +39,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
   PBMetadata client_list;
   DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
   LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
+  LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
 }

 bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
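Seeding kCtxFedAvgTotalDataSize at 0 lets the updateModel round accumulate each client's data size, which FedAvg then uses to weight the aggregated updates. A toy, assumption-laden version of that data-size-weighted mean (the three clients and their sizes are invented for illustration):

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // (weight value, client data size) pairs from three hypothetical clients.
  std::vector<std::pair<float, size_t>> updates = {{0.5f, 100}, {0.7f, 300}, {0.2f, 600}};
  float weighted_sum = 0.0f;
  size_t total_data_size = 0;  // starts at kInitialDataSizeSum == 0
  for (const auto &u : updates) {
    weighted_sum += u.first * u.second;
    total_data_size += u.second;
  }
  std::cout << weighted_sum / total_data_size << std::endl;  // prints 0.38
  return 0;
}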
@@ -103,8 +104,9 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
 bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
   if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
     std::string reason = "Current amount for updateModel is enough.";
-    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
-                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    BuildUpdateModelRsp(
+      fbb, schema::ResponseCode_OutOfTime, reason,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
     MS_LOG(ERROR) << reason;
     return false;
   }
@@ -117,8 +119,9 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
   if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
     std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
                          ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
-    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
-                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    BuildUpdateModelRsp(
+      fbb, schema::ResponseCode_OutOfTime, reason,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
     MS_LOG(ERROR) << reason;
     return false;
   }
@@ -128,14 +131,24 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
   std::string update_model_fl_id = update_model_req->fl_id()->str();
   if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
     std::string reason = "devices_meta for " + update_model_fl_id + " is not set.";
-    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
-                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    BuildUpdateModelRsp(
+      fbb, schema::ResponseCode_OutOfTime, reason,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
     MS_LOG(ERROR) << reason;
     return false;
   }

   size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
   auto feature_map = ParseFeatureMap(update_model_req);
+  if (feature_map.empty()) {
+    std::string reason = "Feature map is empty.";
+    BuildUpdateModelRsp(
+      fbb, schema::ResponseCode_RequestError, reason,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
+    MS_LOG(ERROR) << reason;
+    return false;
+  }
+
   for (auto weight : feature_map) {
     weight.second[kNewDataSize].addr = &data_size;
     weight.second[kNewDataSize].size = sizeof(size_t);
@@ -146,10 +159,17 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
   fl_id.set_fl_id(update_model_fl_id);
   PBMetadata comm_value;
   *comm_value.mutable_fl_id() = fl_id;
-  DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value);
+  if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) {
+    std::string reason = "Updating metadata of UpdateModelClientList failed.";
+    BuildUpdateModelRsp(
+      fbb, schema::ResponseCode_SystemError, reason,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
+    MS_LOG(ERROR) << reason;
+    return false;
+  }

-  BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready",
-                      std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+  BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
+                      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
   return true;
 }

@@ -174,8 +194,9 @@ bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fb
                                             const schema::RequestUpdateModel *update_model_req) {
   if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
     std::string reason = "UpdateModel counting failed.";
-    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
-                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    BuildUpdateModelRsp(
+      fbb, schema::ResponseCode_OutOfTime, reason,
+      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
     MS_LOG(ERROR) << reason;
     return false;
   }
@ -30,6 +30,9 @@ namespace mindspore {
 namespace ps {
 namespace server {
 namespace kernel {
+// The initial data size sum of federated learning is 0, which will be accumulated in the updateModel round.
+constexpr uint64_t kInitialDataSizeSum = 0;
+
 class UpdateModelKernel : public RoundKernel {
  public:
   UpdateModelKernel() = default;
@ -30,7 +30,8 @@ void ModelStore::Initialize(uint32_t max_count) {
   }

   max_model_count_ = max_count;
-  iteration_to_model_[kInitIterationNum] = AssignNewModelMemory();
+  initial_model_ = AssignNewModelMemory();
+  iteration_to_model_[kInitIterationNum] = initial_model_;
   model_size_ = ComputeModelSize();
 }
@ -52,7 +53,6 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
       MS_LOG(ERROR) << "Memory for the new model is nullptr.";
       return false;
     }

     iteration_to_model_[iteration] = memory_register;
   } else {
     // If iteration_to_model_ size is already max_model_count_, we need to replace the earliest model with the newest one.
@ -97,6 +97,12 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
   return model;
 }

+void ModelStore::Reset() {
+  initial_model_ = iteration_to_model_.rbegin()->second;
+  iteration_to_model_.clear();
+  iteration_to_model_[kInitIterationNum] = initial_model_;
+}
+
 const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
   return iteration_to_model_;
 }
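The Reset() added above keeps only the newest entry of iteration_to_model_ and reinstalls it as the iteration-0 model, so the next federated learning job starts from the last aggregated weights rather than from scratch. A minimal Python sketch of that bookkeeping, using a plain dict in place of the C++ map (names here are illustrative, not the actual API):

kInitIterationNum = 0

def reset(iteration_to_model):
    # rbegin() on a std::map is the entry with the largest key, i.e. the newest iteration.
    initial_model = iteration_to_model[max(iteration_to_model)]
    # Drop everything else and reinstall the survivor as the iteration-0 model.
    return {kInitIterationNum: initial_model}

models = {0: "weights_v0", 1: "weights_v1", 2: "weights_v2"}
assert reset(models) == {0: "weights_v2"}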
@ -121,6 +127,14 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
       return nullptr;
     }

+    auto src_data_size = weight_size;
+    auto dst_data_size = weight_size;
+    int ret = memcpy_s(weight_data.get(), dst_data_size, weight.second->addr, src_data_size);
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+      return nullptr;
+    }
+
     memory_register->RegisterArray(weight_name, &weight_data, weight_size);
   }
   return memory_register;
@ -49,6 +49,9 @@ class ModelStore {
   // Get model of the given iteration.
   std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);

+  // Reset the stored models. Called when the federated learning job finishes.
+  void Reset();
+
   // Returns all models stored in ModelStore.
   const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
@ -70,6 +73,11 @@ class ModelStore {
   size_t max_model_count_;
   size_t model_size_;

+  // Initial model which is the model of iteration 0.
+  std::shared_ptr<MemoryRegister> initial_model_;
+
+  // The number of all models stored is max_model_count_.
   std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
 };
 } // namespace server
@ -38,9 +38,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
     name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });

   // Callback when the iteration is finished.
-  finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void {
-    MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration.";
-    finish_iteration_cb();
+  finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid) -> void {
+    MS_LOG(INFO) << "Round " << name_ << " finished! This iteration is valid. Proceed to next iteration.";
+    finish_iteration_cb(is_iteration_valid);
   };

   // Callback for finalizing the server. This can only be called once.
@ -50,9 +50,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
   iter_timer_ = std::make_shared<IterationTimer>();

   // 1.Set the timeout callback for the timer.
-  iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void {
-    MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration.";
-    timeout_cb();
+  iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid) -> void {
+    MS_LOG(INFO) << "Round " << name_ << " timeout! This iteration is invalid. Proceed to next iteration.";
+    timeout_cb(is_iteration_valid);
   });

   // 2.Stopping timer callback which will be set to the round kernel.
@ -112,14 +112,19 @@ const std::string &Round::name() const { return name_; }

 size_t Round::threshold_count() const { return threshold_count_; }

+bool Round::check_timeout() const { return check_timeout_; }
+
 size_t Round::time_window() const { return time_window_; }

-void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
+void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
   MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
   // The timer starts only after the first count event is triggered by DistributedCountService.
   if (check_timeout_) {
     iter_timer_->Start(std::chrono::milliseconds(time_window_));
   }
+
+  // Some kernels override the OnFirstCountEvent method.
+  kernel_->OnFirstCountEvent(message);
   return;
 }
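The two Round hunks above widen the callback signatures from (void) to (bool is_iteration_valid): finishing a round by reaching its threshold marks the iteration valid, while a timer expiry marks it invalid, and the flag is forwarded into the iteration switch. A rough Python sketch of that control flow (illustrative names, not the server's C++ API):

def make_round_callbacks(round_name, proceed_to_next_iter):
    def on_finish():
        # Threshold reached within the time window: the iteration counts.
        print("Round", round_name, "finished! This iteration is valid.")
        proceed_to_next_iter(True)

    def on_timeout():
        # Time window expired first: the iteration is discarded.
        print("Round", round_name, "timeout! This iteration is invalid.")
        proceed_to_next_iter(False)

    return on_finish, on_timeout

finish_cb, timeout_cb = make_round_callbacks("updateModel", lambda valid: print("proceed, valid =", valid))
timeout_cb()  # -> proceed, valid = False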
@ -52,6 +52,7 @@ class Round {
   const std::string &name() const;
   size_t threshold_count() const;
+  bool check_timeout() const;
   size_t time_window() const;

  private:
@ -174,21 +174,22 @@ bool Server::InitCommunicatorWithWorker() {
 }

 void Server::InitIteration() {
-  iteration_ = std::make_shared<Iteration>();
+  iteration_ = &Iteration::GetInstance();
   MS_EXCEPTION_IF_NULL(iteration_);

   // 1.Add rounds to the iteration according to the server mode.
   for (const RoundConfig &config : rounds_config_) {
     std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window,
                                                            config.check_count, config.threshold_count);
-    MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count
-                 << ", threshold:" << config.threshold_count;
+    MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
+                 << ", time window: " << config.time_window << ", check_count: " << config.check_count
+                 << ", threshold: " << config.threshold_count;
     iteration_->AddRound(round);
   }

   // 2.Initialize all the rounds.
-  TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
-  FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
+  TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
+  FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
   iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
   return;
 }
@ -117,7 +117,7 @@ class Server {
   std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;

   // Iteration consists of multiple kinds of rounds.
-  std::shared_ptr<Iteration> iteration_;
+  Iteration *iteration_;

   // Variables set by ps context.
   std::string scheduler_ip_;
@ -787,3 +787,59 @@ def reset_ps_context():
         - enable_ps: False.
     """
     _reset_ps_context()
+
+
+def set_fl_context(**kwargs):
+    """
+    Set federated learning training mode context.
+
+    Args:
+        enable_fl (bool): Whether to enable federated learning training mode.
+            Default: False.
+        server_mode (string): Describes the server mode, which must be one of 'FEDERATED_LEARNING'
+            and 'HYBRID_TRAINING'. Default: 'FEDERATED_LEARNING'.
+        ms_role (string): The process's role in the federated learning mode,
+            which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
+            Default: 'MS_NOT_PS'.
+        worker_num (int): The number of workers. Default: 0.
+        server_num (int): The number of federated learning servers. Default: 0.
+        scheduler_ip (string): The scheduler IP. Default: ''.
+        scheduler_port (int): The scheduler port. Default: 0.
+        fl_server_port (int): The http port of the federated learning server.
+            Normally this should be set to the same value for each server. Default: 0.
+        enable_fl_client (bool): Whether this process is a federated learning client. Default: False.
+        start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0.
+        start_fl_job_time_window (int): The time window duration for startFLJob in milliseconds. Default: 3000.
+        update_model_ratio (float): The ratio for computing the threshold count of updateModel,
+            which will be multiplied by start_fl_job_threshold. Default: 1.0.
+        update_model_time_window (int): The time window duration for updateModel in milliseconds. Default: 3000.
+        fl_name (string): The federated learning job name. Default: ''.
+        fl_iteration_num (int): Iteration number of federated learning,
+            which is the number of interactions between client and server. Default: 20.
+        client_epoch_num (int): Client training epoch number. Default: 25.
+        client_batch_size (int): Client training data batch size. Default: 32.
+        client_learning_rate (float): Client training learning rate. Default: 0.001.
+        secure_aggregation (bool): Whether to use the secure aggregation algorithm. Default: False.
+
+    Raises:
+        ValueError: If the input key is not an attribute of the federated learning mode context.
+
+    Examples:
+        >>> context.set_fl_context(enable_fl=True, server_mode='FEDERATED_LEARNING')
+    """
+    _set_ps_context(**kwargs)
+
+
+def get_fl_context(attr_key):
+    """
+    Get federated learning mode context attribute value according to the key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Returns the attribute value according to the key.
+
+    Raises:
+        ValueError: If the input key is not an attribute of the federated learning mode context.
+    """
+    return _get_ps_context(attr_key)
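To see how the new knobs fit together: per the docstring above, the updateModel threshold is derived from start_fl_job_threshold scaled by update_model_ratio (rounded up), and each round gets its own time window. A hypothetical configuration (values are illustrative):

import math
from mindspore import context

context.set_fl_context(enable_fl=True,
                       server_mode='FEDERATED_LEARNING',
                       start_fl_job_threshold=100,      # clients needed to start an iteration
                       start_fl_job_time_window=3000,   # ms allowed for the startFLJob round
                       update_model_ratio=0.8,          # fraction of starters that must upload
                       update_model_time_window=3000)   # ms allowed for the updateModel round

# Expected updateModel threshold under these settings:
update_model_threshold = math.ceil(100 * 0.8)  # -> 80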
@ -19,6 +19,8 @@
 #include <unistd.h>
 #include <sys/time.h>
 #include <map>
+#include <iomanip>
+#include <thread>

 // namespace to support utils module definition
 namespace mindspore {
@ -117,8 +119,8 @@ void LogWriter::OutputLog(const std::ostringstream &msg) const {
 #define google mindspore_private
   auto submodule_name = GetSubModuleName(submodule_);
   google::LogMessage("", 0, GetGlogLevel(log_level_)).stream()
-    << "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << GetProcName()
-    << "):" << GetTimeString() << " "
+    << "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << std::hex
+    << std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " "
     << "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl;
 #undef google
 #else
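The reworked log prefix above inserts the thread id, in hex, between the pid and the process name, so interleaved messages from the server's worker threads can be told apart. A loose Python analogue of the resulting format (the format string is illustrative, not MindSpore's):

import logging

handler = logging.StreamHandler()
# %(thread)x renders the thread id in hex, mirroring std::hex << std::this_thread::get_id().
handler.setFormatter(logging.Formatter(
    "[%(levelname)s] FL(%(process)d,%(thread)x):%(asctime)s [%(filename)s:%(lineno)d] %(message)s"))
logger = logging.getLogger("fl_server")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("Round updateModel finished!")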
@ -36,6 +36,7 @@ _set_ps_context_func_map = {
     "server_mode": ps_context().set_server_mode,
     "ms_role": ps_context().set_ms_role,
     "enable_ps": ps_context().set_ps_enable,
+    "enable_fl": ps_context().set_ps_enable,
     "worker_num": ps_context().set_worker_num,
     "server_num": ps_context().set_server_num,
     "scheduler_ip": ps_context().set_scheduler_ip,
@ -43,10 +44,14 @@ _set_ps_context_func_map = {
     "fl_server_port": ps_context().set_fl_server_port,
     "enable_fl_client": ps_context().set_fl_client_enable,
     "start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
+    "start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
+    "update_model_ratio": ps_context().set_update_model_ratio,
+    "update_model_time_window": ps_context().set_update_model_time_window,
     "fl_name": ps_context().set_fl_name,
     "fl_iteration_num": ps_context().set_fl_iteration_num,
     "client_epoch_num": ps_context().set_client_epoch_num,
     "client_batch_size": ps_context().set_client_batch_size,
+    "client_learning_rate": ps_context().set_client_learning_rate,
     "secure_aggregation": ps_context().set_secure_aggregation,
     "enable_ps_ssl": ps_context().set_enable_ssl
 }
@ -69,6 +69,7 @@ table ResponseFLJob {
 }

 table FLPlan {
+  server_mode:string;
   fl_name:string;
   iterations:int;
   epochs:int;
@ -26,10 +26,14 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
 parser.add_argument("--scheduler_port", type=int, default=8113)
 parser.add_argument("--fl_server_port", type=int, default=6666)
 parser.add_argument("--start_fl_job_threshold", type=int, default=1)
+parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
+parser.add_argument("--update_model_ratio", type=float, default=1.0)
+parser.add_argument("--update_model_time_window", type=int, default=3000)
 parser.add_argument("--fl_name", type=str, default="Lenet")
 parser.add_argument("--fl_iteration_num", type=int, default=25)
 parser.add_argument("--client_epoch_num", type=int, default=20)
 parser.add_argument("--client_batch_size", type=int, default=32)
+parser.add_argument("--client_learning_rate", type=float, default=0.1)
 parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
 parser.add_argument("--local_server_num", type=int, default=-1)
@ -43,10 +47,14 @@ if __name__ == "__main__":
     scheduler_port = args.scheduler_port
     fl_server_port = args.fl_server_port
     start_fl_job_threshold = args.start_fl_job_threshold
+    start_fl_job_time_window = args.start_fl_job_time_window
+    update_model_ratio = args.update_model_ratio
+    update_model_time_window = args.update_model_time_window
     fl_name = args.fl_name
     fl_iteration_num = args.fl_iteration_num
     client_epoch_num = args.client_epoch_num
     client_batch_size = args.client_batch_size
+    client_learning_rate = args.client_learning_rate
     secure_aggregation = args.secure_aggregation
     local_server_num = args.local_server_num
@ -70,10 +78,14 @@ if __name__ == "__main__":
         cmd_server += " --scheduler_port=" + str(scheduler_port)
         cmd_server += " --fl_server_port=" + str(fl_server_port + i)
         cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
+        cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
+        cmd_server += " --update_model_ratio=" + str(update_model_ratio)
+        cmd_server += " --update_model_time_window=" + str(update_model_time_window)
         cmd_server += " --fl_name=" + fl_name
         cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
         cmd_server += " --client_epoch_num=" + str(client_epoch_num)
         cmd_server += " --client_batch_size=" + str(client_batch_size)
+        cmd_server += " --client_learning_rate=" + str(client_learning_rate)
         cmd_server += " --secure_aggregation=" + str(secure_aggregation)
         cmd_server += " > server.log 2>&1 &"
@ -15,6 +15,7 @@
 import argparse
 import time
+import datetime
 import random
 import sys
 import requests
@ -129,7 +130,15 @@ def build_get_model(iteration):
     buf = builder_get_model.Output()
     return buf

-weight_name_to_idx = {
+def datetime_to_timestamp(datetime_obj):
+    """Convert a local datetime (with milliseconds) to a millisecond timestamp.
+    :param datetime_obj: {datetime} 2016-02-25 20:21:04.242000
+    :return: 13-digit millisecond timestamp, e.g. 1456402864242
+    """
+    local_timestamp = time.mktime(datetime_obj.timetuple()) * 1000.0 + datetime_obj.microsecond // 1000.0
+    return local_timestamp
+
+weight_to_idx = {
     "conv1.weight": 0,
     "conv2.weight": 1,
     "fc1.weight": 2,
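datetime_to_timestamp exists so the client can put datetime.datetime.now() on the same 13-digit millisecond scale as the next-request timestamp returned by the server, then sleep out the difference. A minimal sketch (the 1.5-second offset and the max() guard against an already-past timestamp are illustrative additions):

import datetime
import time

def datetime_to_timestamp(datetime_obj):
    return time.mktime(datetime_obj.timetuple()) * 1000.0 + datetime_obj.microsecond // 1000.0

next_req_timestamp = datetime_to_timestamp(datetime.datetime.now()) + 1500  # stand-in value
duration = next_req_timestamp - datetime_to_timestamp(datetime.datetime.now())
time.sleep(max(duration, 0) / 1000)  # wait until the server will accept requests again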
@ -149,11 +158,12 @@ while True:
     print("start url is ", url1)
     x = requests.post(url1, data=build_start_fl_job(current_iteration))
     rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
-    print("start fl job iteration:", current_iteration, ", id:", args.pid)
     while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
         x = requests.post(url1, data=build_start_fl_job(current_iteration))
-        rsp_fl_job = rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
+        rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
     print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
+    print("iteration is", rsp_fl_job.Iteration())
+    current_iteration = rsp_fl_job.Iteration()
     sys.stdout.flush()

     url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
@ -170,22 +180,40 @@ while True:
     print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
     sys.stdout.flush()

-    repeat_time = 0
-    while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
-        time.sleep(0.1)
-        x = session.post(url3, data=build_get_model(current_iteration))
-        rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
-        repeat_time += 1
-        if repeat_time > 1000:
-            print("GetModel try timeout ", args.pid)
-            sys.exit(0)
-
-    for i in range(0, 1):
-        print(rsp_get_model.FeatureMap(i).WeightFullname())
-        origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
-        after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
-        print("Before update model", args.pid, origin[0:10])
-        print("After get model", args.pid, after[0:10])
-        sys.stdout.flush()
-        assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
-    current_iteration += 1
+    next_req_timestamp = 0
+    if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
+        next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
+        print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
+        sys.stdout.flush()
+    elif rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
+        repeat_time = 0
+        while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
+            time.sleep(0.2)
+            x = session.post(url3, data=build_get_model(current_iteration))
+            rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
+            if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
+                next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
+                print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
+                sys.stdout.flush()
+                break
+            repeat_time += 1
+            if repeat_time > 1000:
+                print("GetModel try timeout ", args.pid)
+                sys.exit(0)
+    else:
+        pass
+
+    if next_req_timestamp == 0:
+        for i in range(0, 1):
+            print(rsp_get_model.FeatureMap(i).WeightFullname())
+            origin = update_model_np_data[weight_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
+            after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
+            print("Before update model", args.pid, origin[0:10])
+            print("After get model", args.pid, after[0:10])
+            sys.stdout.flush()
+            assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
+    else:
+        # Sleep to the next request timestamp
+        current_ts = datetime_to_timestamp(datetime.datetime.now())
+        duration = next_req_timestamp - current_ts
+        time.sleep(duration / 1000)
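Summed up, the client now distinguishes three getModel outcomes: OutOfTime carries a server-chosen timestamp for the next request, SucNotReady means aggregation is still running and the client should poll, and anything else delivers the model. A condensed sketch of that protocol; post_get_model is a hypothetical stand-in for the FlatBuffers round trip in the script above:

import time

OUT_OF_TIME, SUC_NOT_READY, SUCCEED = range(3)

def fetch_model(post_get_model, iteration, max_polls=1000):
    rsp = post_get_model(iteration)
    while rsp.retcode == SUC_NOT_READY and max_polls > 0:
        time.sleep(0.2)  # aggregation not finished; poll again
        rsp = post_get_model(iteration)
        max_polls -= 1
    if rsp.retcode == SUC_NOT_READY:
        raise TimeoutError("getModel retries exhausted")  # the script calls sys.exit(0) here
    if rsp.retcode == OUT_OF_TIME:
        # Iteration was invalidated; caller should sleep until this ms timestamp.
        return None, rsp.next_req_timestamp
    return rsp.model, 0  # SUCCEED: model is ready, no extra wait needed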
@ -34,10 +34,14 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
 parser.add_argument("--scheduler_port", type=int, default=8113)
 parser.add_argument("--fl_server_port", type=int, default=6666)
 parser.add_argument("--start_fl_job_threshold", type=int, default=1)
+parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
+parser.add_argument("--update_model_ratio", type=float, default=1.0)
+parser.add_argument("--update_model_time_window", type=int, default=3000)
 parser.add_argument("--fl_name", type=str, default="Lenet")
 parser.add_argument("--fl_iteration_num", type=int, default=25)
 parser.add_argument("--client_epoch_num", type=int, default=20)
 parser.add_argument("--client_batch_size", type=int, default=32)
+parser.add_argument("--client_learning_rate", type=float, default=0.1)
 parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)

 args, _ = parser.parse_known_args()
@ -50,14 +54,18 @@ scheduler_ip = args.scheduler_ip
 scheduler_port = args.scheduler_port
 fl_server_port = args.fl_server_port
 start_fl_job_threshold = args.start_fl_job_threshold
+start_fl_job_time_window = args.start_fl_job_time_window
+update_model_ratio = args.update_model_ratio
+update_model_time_window = args.update_model_time_window
 fl_name = args.fl_name
 fl_iteration_num = args.fl_iteration_num
 client_epoch_num = args.client_epoch_num
 client_batch_size = args.client_batch_size
+client_learning_rate = args.client_learning_rate
 secure_aggregation = args.secure_aggregation

 ctx = {
-    "enable_ps": False,
+    "enable_fl": True,
     "server_mode": server_mode,
     "ms_role": ms_role,
     "worker_num": worker_num,
@ -66,15 +74,19 @@ ctx = {
     "scheduler_port": scheduler_port,
     "fl_server_port": fl_server_port,
     "start_fl_job_threshold": start_fl_job_threshold,
+    "start_fl_job_time_window": start_fl_job_time_window,
+    "update_model_ratio": update_model_ratio,
+    "update_model_time_window": update_model_time_window,
     "fl_name": fl_name,
     "fl_iteration_num": fl_iteration_num,
     "client_epoch_num": client_epoch_num,
     "client_batch_size": client_batch_size,
+    "client_learning_rate": client_learning_rate,
     "secure_aggregation": secure_aggregation
 }

 context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
-context.set_ps_context(**ctx)
+context.set_fl_context(**ctx)

 if __name__ == "__main__":
     epoch = 5