!16230 Optimize server reliability in multiple scenarios.

From: @zpac
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
This commit is contained in:
mindspore-ci-bot 2021-05-15 16:50:40 +08:00 committed by Gitee
commit 94ca479fbe
33 changed files with 435 additions and 142 deletions

View File

@ -629,14 +629,17 @@ bool StartServerAction(const ResourcePtr &res) {
uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();
// Update model threshold is a certain ratio of start_fl_job threshold.
// update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_.
// update_model_threshold = start_fl_job_threshold * update_model_ratio.
size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
float percent_for_update_model = 1;
size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));
float update_model_ratio = ps::PSContext::instance()->update_model_ratio();
size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
std::vector<ps::server::RoundConfig> rounds_config = {{"startFLJob", false, 3000, true, start_fl_job_threshold},
{"updateModel", false, 3000, true, update_model_threshold},
{"getModel", false, 3000}};
std::vector<ps::server::RoundConfig> rounds_config = {
{"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
{"updateModel", true, update_model_time_window, true, update_model_threshold},
{"getModel"}};
size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {

View File

@ -345,11 +345,20 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
.def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
.def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
.def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.")
.def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold,
"Set threshold count for startFLJob round.")
.def("set_start_fl_job_time_window", &PSContext::set_start_fl_job_time_window,
"Set time window for startFLJob round.")
.def("set_update_model_ratio", &PSContext::set_update_model_ratio,
"Set threshold count ratio for updateModel round.")
.def("set_update_model_time_window", &PSContext::set_update_model_time_window,
"Set time window for updateModel round.")
.def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
.def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
.def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
.def("set_client_learning_rate", &PSContext::set_client_learning_rate,
"Set federated learning client learning rate.")
.def("set_secure_aggregation", &PSContext::set_secure_aggregation,
"Set federated learning client using secure aggregation.")
.def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");

View File

@ -196,7 +196,14 @@ void PSContext::set_ms_role(const std::string &role) {
role_ = role;
}
void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
void PSContext::set_worker_num(uint32_t worker_num) {
// Hybrid training mode only supports one worker for now.
if (server_mode_ == kServerModeHybrid && worker_num != 1) {
MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
return;
}
worker_num_ = worker_num;
}
uint32_t PSContext::worker_num() const { return worker_num_; }
void PSContext::set_server_num(uint32_t server_num) {
@ -235,7 +242,7 @@ void PSContext::GenerateResetterRound() {
}
binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_upload_weights_ << 4);
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3);
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
resetter_round_ = ResetterRound::kNoNeedToReset;
} else {
@ -255,11 +262,27 @@ void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled
bool PSContext::fl_client_enable() { return fl_client_enable_; }
void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) {
void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
start_fl_job_threshold_ = start_fl_job_threshold;
}
size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
start_fl_job_time_window_ = start_fl_job_time_window;
}
uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
void PSContext::set_update_model_ratio(float update_model_ratio) { update_model_ratio_ = update_model_ratio; }
float PSContext::update_model_ratio() const { return update_model_ratio_; }
void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
update_model_time_window_ = update_model_time_window;
}
uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
@ -277,11 +300,9 @@ void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch
uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
void PSContext::set_worker_upload_weights(uint64_t worker_upload_weights) {
worker_upload_weights_ = worker_upload_weights;
}
void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }
uint64_t PSContext::worker_upload_weights() const { return worker_upload_weights_; }
float PSContext::client_learning_rate() const { return client_learning_rate_; }
void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }

View File

@ -41,15 +41,13 @@ constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
// 1: Server is in federated learning mode.
// 2: Server is in mixed training mode.
// 3: Server enables sucure aggregation.
// 4: Server needs worker to overwrite weights.
// For example: 01010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerOverwriteWeights };
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {
{0b00010, ResetterRound::kUpdateModel},
{0b01010, ResetterRound::kReconstructSeccrets},
{0b11100, ResetterRound::kWorkerOverwriteWeights},
{0b10100, ResetterRound::kWorkerOverwriteWeights},
{0b00100, ResetterRound::kUpdateModel}};
// For example: 1010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerUploadWeights };
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
{0b1010, ResetterRound::kReconstructSeccrets},
{0b1100, ResetterRound::kWorkerUploadWeights},
{0b0100, ResetterRound::kWorkerUploadWeights},
{0b0100, ResetterRound::kUpdateModel}};
class PSContext {
public:
@ -115,8 +113,17 @@ class PSContext {
void set_fl_client_enable(bool enabled);
bool fl_client_enable();
void set_start_fl_job_threshold(size_t start_fl_job_threshold);
size_t start_fl_job_threshold() const;
void set_start_fl_job_threshold(uint64_t start_fl_job_threshold);
uint64_t start_fl_job_threshold() const;
void set_start_fl_job_time_window(uint64_t start_fl_job_time_window);
uint64_t start_fl_job_time_window() const;
void set_update_model_ratio(float update_model_ratio);
float update_model_ratio() const;
void set_update_model_time_window(uint64_t update_model_time_window);
uint64_t update_model_time_window() const;
void set_fl_name(const std::string &fl_name);
const std::string &fl_name() const;
@ -133,9 +140,8 @@ class PSContext {
void set_client_batch_size(uint64_t client_batch_size);
uint64_t client_batch_size() const;
// Set true if worker will overwrite weights on server. Used in hybrid training.
void set_worker_upload_weights(uint64_t worker_upload_weights);
uint64_t worker_upload_weights() const;
void set_client_learning_rate(float client_learning_rate);
float client_learning_rate() const;
// Set true if using secure aggregation for federated learning.
void set_secure_aggregation(bool secure_aggregation);
@ -160,11 +166,14 @@ class PSContext {
fl_client_enable_(false),
fl_name_(""),
start_fl_job_threshold_(0),
fl_iteration_num_(0),
client_epoch_num_(0),
client_batch_size_(0),
secure_aggregation_(false),
worker_upload_weights_(false) {}
start_fl_job_time_window_(3000),
update_model_ratio_(1.0),
update_model_time_window_(3000),
fl_iteration_num_(20),
client_epoch_num_(25),
client_batch_size_(32),
client_learning_rate_(0.001),
secure_aggregation_(false) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@ -195,7 +204,16 @@ class PSContext {
std::string fl_name_;
// The threshold count of startFLJob round. Used in federated learning for now.
size_t start_fl_job_threshold_;
uint64_t start_fl_job_threshold_;
// The time window of startFLJob round in millisecond.
uint64_t start_fl_job_time_window_;
// Update model threshold is a certain ratio of start_fl_job threshold which is set as update_model_ratio_.
float update_model_ratio_;
// The time window of updateModel round in millisecond.
uint64_t update_model_time_window_;
// Iteration number of federeated learning, which is the number of interactions between client and server.
uint64_t fl_iteration_num_;
@ -206,12 +224,11 @@ class PSContext {
// Client training data batch size. Used in federated learning for now.
uint64_t client_batch_size_;
// Client training learning rate. Used in federated learning for now.
float client_learning_rate_;
// Whether to use secure aggregation algorithm. Used in federated learning for now.
bool secure_aggregation_;
// Whether there's a federated learning worker uploading weights to federated learning server. Used in hybrid training
// mode for now.
bool worker_upload_weights_;
};
} // namespace ps
} // namespace mindspore

View File

@ -56,9 +56,9 @@ using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
using FBBuilder = flatbuffers::FlatBufferBuilder;
using TimeOutCb = std::function<void(void)>;
using TimeOutCb = std::function<void(bool)>;
using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(void)>;
using FinishIterCb = std::function<void(bool)>;
using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
@ -148,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32;
constexpr int kHttpSuccess = 200;
constexpr auto kPBProtocol = "PB";
constexpr auto kFBSProtocol = "FBS";
constexpr auto kSuccess = "Success";
constexpr auto kFedAvg = "FedAvg";
constexpr auto kAggregationKernelType = "Aggregation";
constexpr auto kOptimizerKernelType = "Optimizer";
@ -155,6 +156,7 @@ constexpr auto kCtxFuncGraph = "FuncGraph";
constexpr auto kCtxIterNum = "iteration";
constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
constexpr auto kCtxUpdateModelThld = "update_model_threshold";

View File

@ -130,7 +130,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
void DistributedCountService::ResetCounter(const std::string &name) {
if (local_rank_ == counting_server_rank_) {
MS_LOG(INFO) << "Leader server reset count for " << name;
MS_LOG(DEBUG) << "Leader server reset count for " << name;
global_current_count_[name].clear();
}
return;
@ -233,7 +233,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::Mes
const auto &type = counter_event.type();
const auto &name = counter_event.name();
MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
if (type == CounterEventType::FIRST_CNT) {
counter_handlers_[name].first_count_handler(message);
} else if (type == CounterEventType::LAST_CNT) {
@ -259,7 +259,7 @@ void DistributedCountService::TriggerCounterEvent(const std::string &name) {
}
void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
MS_LOG(INFO) << "Activating first count event for " << name;
MS_LOG(DEBUG) << "Activating first count event for " << name;
CounterEvent first_count_event;
first_count_event.set_type(CounterEventType::FIRST_CNT);
first_count_event.set_name(name);

View File

@ -79,10 +79,10 @@ void DistributedMetadataStore::ResetMetadata(const std::string &name) {
return;
}
void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return;
return false;
}
uint32_t stored_rank = router_->Find(name);
@ -90,18 +90,26 @@ void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
if (local_rank_ == stored_rank) {
if (!DoUpdateMetadata(name, meta)) {
MS_LOG(ERROR) << "Updating meta data failed.";
return;
return false;
}
} else {
PBMetadataWithName metadata_with_name;
metadata_with_name.set_name(name);
*metadata_with_name.mutable_metadata() = meta;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) {
std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata,
&update_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
return;
return false;
}
std::string update_meta_rsp = reinterpret_cast<const char *>(update_meta_rsp_msg->data());
if (update_meta_rsp != kSuccess) {
MS_LOG(ERROR) << "Updating metadata in server " << stored_rank << " failed.";
return false;
}
}
return;
return true;
}
PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
@ -166,6 +174,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
std::string update_meta_rsp_msg;
if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
update_meta_rsp_msg = "Updating meta data failed.";
MS_LOG(ERROR) << update_meta_rsp_msg;
} else {
update_meta_rsp_msg = "Success";
}

View File

@ -52,7 +52,7 @@ class DistributedMetadataStore {
void ResetMetadata(const std::string &name);
// Update the metadata for the name.
void UpdateMetadata(const std::string &name, const PBMetadata &meta);
bool UpdateMetadata(const std::string &name, const PBMetadata &meta);
// Get the metadata for the name.
PBMetadata GetMetadata(const std::string &name);

View File

@ -23,8 +23,6 @@
namespace mindspore {
namespace ps {
namespace server {
Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); }
void Iteration::AddRound(const std::shared_ptr<Round> &round) {
MS_EXCEPTION_IF_NULL(round);
rounds_.push_back(round);
@ -49,28 +47,48 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
// The time window for one iteration, which will be used in some round kernels.
size_t iteration_time_window =
std::accumulate(rounds_.begin(), rounds_.end(), 0,
[](size_t total, const std::shared_ptr<Round> &round) { return total + round->time_window(); });
std::accumulate(rounds_.begin(), rounds_.end(), 0, [](size_t total, const std::shared_ptr<Round> &round) {
return round->check_timeout() ? total + round->time_window() : total;
});
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;
return;
}
void Iteration::ProceedToNextIter() {
void Iteration::ProceedToNextIter(bool is_iteration_valid) {
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
// Store the model for each iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
if (is_iteration_valid) {
// Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else {
// Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid.";
}
for (auto &round : rounds_) {
round->Reset();
}
iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore.
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << PSContext::instance()->fl_iteration_num() << " iterations are completed.";
iteration_num_ = 1;
ModelStore::GetInstance().Reset();
}
is_last_iteration_valid_ = is_iteration_valid;
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
}
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -31,8 +31,10 @@ namespace server {
// Rounds, only after all the rounds are finished, this iteration is considered as completed.
class Iteration {
public:
Iteration();
~Iteration() = default;
static Iteration &GetInstance() {
static Iteration instance;
return instance;
}
// Add a round for the iteration. This method will be called multiple times for each round.
void AddRound(const std::shared_ptr<Round> &round);
@ -41,16 +43,29 @@ class Iteration {
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
// The server proceeds to the next iteration only after the last iteration finishes.
void ProceedToNextIter();
// The server proceeds to the next iteration only after the last round finishes or the timer expires.
// If the timer expires, we consider this iteration as invalid.
void ProceedToNextIter(bool is_iteration_valid);
const std::vector<std::shared_ptr<Round>> &rounds();
bool is_last_iteration_valid() const;
private:
Iteration() : iteration_num_(1), is_last_iteration_valid_(true) {
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
~Iteration() = default;
Iteration(const Iteration &) = delete;
Iteration &operator=(const Iteration &) = delete;
std::vector<std::shared_ptr<Round>> rounds_;
// Server's current iteration number.
size_t iteration_num_;
// Last iteration is successfully finished.
bool is_last_iteration_valid_;
};
} // namespace server
} // namespace ps

View File

@ -29,7 +29,7 @@ void IterationTimer::Start(const std::chrono::milliseconds &duration) {
monitor_thread_ = std::thread([&]() {
while (running_.load()) {
if (CURRENT_TIME_MILLI > end_time_) {
timeout_callback_();
timeout_callback_(false);
running_ = false;
}
// The time tick is 1 millisecond.

View File

@ -47,6 +47,7 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
}
void GenerateReuseKernelNodeInfo() override {
MS_LOG(INFO) << "FedAvg reuse 'weight', 'accumulation', 'learning rate' and 'momentum' of the kernel node.";
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0));
reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 1));
reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2));

View File

@ -92,7 +92,6 @@ class FedAvgKernel : public AggregationKernel {
weight_addr[i] /= data_size_addr[0];
}
done_ = true;
DistributedCountService::GetInstance().ResetCounter(name_);
return;
};
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
@ -125,6 +124,7 @@ class FedAvgKernel : public AggregationKernel {
participated_ = true;
DistributedCountService::GetInstance().Count(
name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
GenerateReuseKernelNodeInfo();
return true;
}
@ -149,6 +149,7 @@ class FedAvgKernel : public AggregationKernel {
private:
void GenerateReuseKernelNodeInfo() override {
MS_LOG(INFO) << "FedAvg reuse 'weight' of the kernel node.";
// Only the trainable parameter is reused for federated average.
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
return;

View File

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include "ps/server/iteration.h"
#include "ps/server/model_store.h"
namespace mindspore {
@ -67,27 +68,31 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
size_t latest_iter_num = iter_to_model.rbegin()->first;
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
return;
}
if (iter_to_model.count(get_model_iter) == 0) {
std::string reason = "The iteration of GetModel request" + std::to_string(get_model_iter) +
" is invalid. Current iteration is " + std::to_string(current_iter);
BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
return;
// If the model of get_model_iter is not stored, return the latest version of model and current iteration number.
MS_LOG(WARNING) << "The iteration of GetModel request " << std::to_string(get_model_iter)
<< " is invalid. Current iteration is " << std::to_string(current_iter);
feature_maps = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
} else {
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
}
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED,
"Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter,
feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
// If the iteration of this model is invalid, return ResponseCode_OutOfTime to the clients could startFLJob according
// to next_req_time.
auto response_code =
Iteration::GetInstance().is_last_iteration_valid() ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
BuildGetModelRsp(fbb, response_code, "Get model for iteration " + std::to_string(get_model_iter), current_iter,
feature_maps,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
return;
}

View File

@ -68,7 +68,7 @@ void RoundKernel::StopTimer() {
void RoundKernel::FinishIteration() {
if (finish_iteration_cb_) {
finish_iteration_cb_();
finish_iteration_cb_(true);
}
return;
}

View File

@ -61,15 +61,12 @@ class RoundKernel : virtual public CPUKernel {
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) = 0;
// The callbacks when first message and last message for this round kernel is received.
// These methods is called by class DistributedCountService and triggered by leader server(Rank 0).
// virtual void OnFirstCountEvent(std::shared_ptr<core::MessageHandler> message);
// virtual void OnLastCnt(std::shared_ptr<core::MessageHandler> message);
// Some rounds could be stateful in a iteration. Reset method resets the status of this round.
virtual bool Reset() = 0;
// The counter event handlers for DistributedCountService.
// The callbacks when first message and last message for this round kernel is received.
// These methods is called by class DistributedCountService and triggered by counting server.
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);

View File

@ -25,9 +25,12 @@ namespace ps {
namespace server {
namespace kernel {
void StartFLJobKernel::InitKernel(size_t) {
// The time window of one iteration should be started at the first message of startFLJob round.
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
@ -85,11 +88,17 @@ bool StartFLJobKernel::Reset() {
return true;
}
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
}
bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return true;
}
@ -117,8 +126,9 @@ bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
ret = false;
}
if (!ret) {
BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
BuildStartFLJobRsp(
fbb, schema::ResponseCode_NotSelected, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
}
return ret;
@ -128,8 +138,9 @@ bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
std::string reason = "startFLJob counting failed.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
@ -139,11 +150,18 @@ bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
PBMetadata metadata;
*metadata.mutable_device_meta() = device_meta;
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata);
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata)) {
std::string reason = "Updating device metadata failed.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_SystemError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
{});
return;
}
std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps);
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
feature_maps);
return;
}
@ -153,13 +171,16 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
std::map<std::string, AddressPtr> feature_maps) {
auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name);
fl_plan_builder.add_server_mode(fbs_server_mode);
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
auto fbs_fl_plan = fl_plan_builder.Finish();
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;

View File

@ -32,7 +32,7 @@ namespace server {
namespace kernel {
class StartFLJobKernel : public RoundKernel {
public:
StartFLJobKernel() = default;
StartFLJobKernel() : executor_(nullptr), iteration_time_window_(0), iter_next_req_timestamp_(0) {}
~StartFLJobKernel() override = default;
void InitKernel(size_t threshold_count) override;
@ -40,6 +40,8 @@ class StartFLJobKernel : public RoundKernel {
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
private:
// Returns whether the startFLJob count of this iteration has reached the threshold.
bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
@ -66,6 +68,9 @@ class StartFLJobKernel : public RoundKernel {
// The time window of one iteration.
size_t iteration_time_window_;
// Timestamp of next request time for this iteration.
uint64_t iter_next_req_timestamp_;
};
} // namespace kernel
} // namespace server

View File

@ -39,6 +39,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
PBMetadata client_list;
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
}
bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@ -103,8 +104,9 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for updateModel is enough.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
@ -117,8 +119,9 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
@ -128,14 +131,24 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
std::string update_model_fl_id = update_model_req->fl_id()->str();
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
auto feature_map = ParseFeatureMap(update_model_req);
if (feature_map.empty()) {
std::string reason = "Feature map is empty.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_RequestError, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
for (auto weight : feature_map) {
weight.second[kNewDataSize].addr = &data_size;
weight.second[kNewDataSize].size = sizeof(size_t);
@ -146,10 +159,17 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
fl_id.set_fl_id(update_model_fl_id);
PBMetadata comm_value;
*comm_value.mutable_fl_id() = fl_id;
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value);
if (!DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value)) {
std::string reason = "Updating metadata of UpdateModelClientList failed.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_SystemError, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready",
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
BuildUpdateModelRsp(fbb, schema::ResponseCode_SUCCEED, "success not ready",
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
return true;
}
@ -174,8 +194,9 @@ bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fb
const schema::RequestUpdateModel *update_model_req) {
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
std::string reason = "UpdateModel counting failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}

View File

@ -30,6 +30,9 @@ namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
constexpr uint64_t kInitialDataSizeSum = 0;
class UpdateModelKernel : public RoundKernel {
public:
UpdateModelKernel() = default;

View File

@ -30,7 +30,8 @@ void ModelStore::Initialize(uint32_t max_count) {
}
max_model_count_ = max_count;
iteration_to_model_[kInitIterationNum] = AssignNewModelMemory();
initial_model_ = AssignNewModelMemory();
iteration_to_model_[kInitIterationNum] = initial_model_;
model_size_ = ComputeModelSize();
}
@ -52,7 +53,6 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
MS_LOG(ERROR) << "Memory for the new model is nullptr.";
return false;
}
iteration_to_model_[iteration] = memory_register;
} else {
// If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
@ -97,6 +97,12 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
return model;
}
void ModelStore::Reset() {
initial_model_ = iteration_to_model_.rbegin()->second;
iteration_to_model_.clear();
iteration_to_model_[kInitIterationNum] = initial_model_;
}
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
return iteration_to_model_;
}
@ -121,6 +127,14 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
return nullptr;
}
auto src_data_size = weight_size;
auto dst_data_size = weight_size;
int ret = memcpy_s(weight_data.get(), dst_data_size, weight.second->addr, src_data_size);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return nullptr;
}
memory_register->RegisterArray(weight_name, &weight_data, weight_size);
}
return memory_register;

View File

@ -49,6 +49,9 @@ class ModelStore {
// Get model of the given iteration.
std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);
// Reset the stored models. Called when federated learning job finishes.
void Reset();
// Returns all models stored in ModelStore.
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
@ -70,6 +73,11 @@ class ModelStore {
size_t max_model_count_;
size_t model_size_;
// Initial model which is the model of iteration 0.
std::shared_ptr<MemoryRegister> initial_model_;
// The number of all models stpred is max_model_count_.
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
};
} // namespace server

View File

@ -38,9 +38,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
// Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void {
MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration.";
finish_iteration_cb();
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid) -> void {
MS_LOG(INFO) << "Round " << name_ << " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(is_iteration_valid);
};
// Callback for finalizing the server. This can only be called once.
@ -50,9 +50,9 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
iter_timer_ = std::make_shared<IterationTimer>();
// 1.Set the timeout callback for the timer.
iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void {
MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration.";
timeout_cb();
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid) -> void {
MS_LOG(INFO) << "Round " << name_ << " timeout! This iteration is invalid. Proceed to next iteration.";
timeout_cb(is_iteration_valid);
});
// 2.Stopping timer callback which will be set to the round kernel.
@ -112,14 +112,19 @@ const std::string &Round::name() const { return name_; }
size_t Round::threshold_count() const { return threshold_count_; }
bool Round::check_timeout() const { return check_timeout_; }
size_t Round::time_window() const { return time_window_; }
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
// The timer starts only after the first count event is triggered by DistributedCountService.
if (check_timeout_) {
iter_timer_->Start(std::chrono::milliseconds(time_window_));
}
// Some kernels override the OnFirstCountEvent method.
kernel_->OnFirstCountEvent(message);
return;
}

View File

@ -52,6 +52,7 @@ class Round {
const std::string &name() const;
size_t threshold_count() const;
bool check_timeout() const;
size_t time_window() const;
private:

View File

@ -174,21 +174,22 @@ bool Server::InitCommunicatorWithWorker() {
}
void Server::InitIteration() {
iteration_ = std::make_shared<Iteration>();
iteration_ = &Iteration::GetInstance();
MS_EXCEPTION_IF_NULL(iteration_);
// 1.Add rounds to the iteration according to the server mode.
for (const RoundConfig &config : rounds_config_) {
std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window,
config.check_count, config.threshold_count);
MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count
<< ", threshold:" << config.threshold_count;
MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
<< ", time window: " << config.time_window << ", check_count: " << config.check_count
<< ", threshold: " << config.threshold_count;
iteration_->AddRound(round);
}
// 2.Initialize all the rounds.
TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_, std::placeholders::_1);
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
return;
}

View File

@ -117,7 +117,7 @@ class Server {
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
// Iteration consists of multiple kinds of rounds.
std::shared_ptr<Iteration> iteration_;
Iteration *iteration_;
// Variables set by ps context.
std::string scheduler_ip_;

View File

@ -787,3 +787,59 @@ def reset_ps_context():
- enable_ps: False.
"""
_reset_ps_context()
def set_fl_context(**kwargs):
"""
Set federated learning training mode context.
Args:
enable_fl (bool): Whether to enable federated learning training mode.
Default: False.
server_mode (string): Describe the server mode, which must one of 'FEDERATED_LEARNING' and 'HYBRID_TRAINING'.
Default: 'FEDERATED_LEARNING'.
ms_role (string): The process's role in the federated learning mode,
which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
Default: 'MS_NOT_PS'.
worker_num (int): The number of workers. Default: 0.
server_num (int): The number of federated learning servers. Default: 0.
scheduler_ip (string): The scheduler IP. Default: ''.
scheduler_port (int): The scheduler port. Default: 0.
fl_server_port (int): The http port of the federated learning server.
Normally for each server this should be set to the same value. Default: 0.
enable_fl_client (bool): Whether this process is federated learning client. Default: False.
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0.
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
update_model_ratio (float): The ratio for computing the threshold count of updateModel
which will be multiplied by start_fl_job_threshold. Default: 1.0.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
fl_name (string): The federated learning job name. Default: ''.
fl_iteration_num (int): Iteration number of federeated learning,
which is the number of interactions between client and server. Default: 20.
client_epoch_num (int): Client training epoch number. Default: 25.
client_batch_size (int): Client training data batch size. Default: 32.
client_learning_rate (float): Client training learning rate. Default: 0.001.
secure_aggregation (bool): Whether to use secure aggregation algorithm. Default: False.
Raises:
ValueError: If input key is not the attribute in federated learning mode context.
Examples:
>>> context.set_fl_context(enable_fl=True, server_mode='FEDERATED_LEARNING')
"""
_set_ps_context(**kwargs)
def get_fl_context(attr_key):
"""
Get federated learning mode context attribute value according to the key.
Args:
attr_key (str): The key of the attribute.
Returns:
Returns attribute value according to the key.
Raises:
ValueError: If input key is not attribute in federated learning mode context.
"""
return _get_ps_context(attr_key)

View File

@ -19,6 +19,8 @@
#include <unistd.h>
#include <sys/time.h>
#include <map>
#include <iomanip>
#include <thread>
// namespace to support utils module definition
namespace mindspore {
@ -117,8 +119,8 @@ void LogWriter::OutputLog(const std::ostringstream &msg) const {
#define google mindspore_private
auto submodule_name = GetSubModuleName(submodule_);
google::LogMessage("", 0, GetGlogLevel(log_level_)).stream()
<< "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << GetProcName()
<< "):" << GetTimeString() << " "
<< "[" << GetLogLevel(log_level_) << "] " << submodule_name << "(" << getpid() << "," << std::hex
<< std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " "
<< "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl;
#undef google
#else

View File

@ -36,6 +36,7 @@ _set_ps_context_func_map = {
"server_mode": ps_context().set_server_mode,
"ms_role": ps_context().set_ms_role,
"enable_ps": ps_context().set_ps_enable,
"enable_fl": ps_context().set_ps_enable,
"worker_num": ps_context().set_worker_num,
"server_num": ps_context().set_server_num,
"scheduler_ip": ps_context().set_scheduler_ip,
@ -43,10 +44,14 @@ _set_ps_context_func_map = {
"fl_server_port": ps_context().set_fl_server_port,
"enable_fl_client": ps_context().set_fl_client_enable,
"start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
"start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
"update_model_ratio": ps_context().set_update_model_ratio,
"update_model_time_window": ps_context().set_update_model_time_window,
"fl_name": ps_context().set_fl_name,
"fl_iteration_num": ps_context().set_fl_iteration_num,
"client_epoch_num": ps_context().set_client_epoch_num,
"client_batch_size": ps_context().set_client_batch_size,
"client_learning_rate": ps_context().set_client_learning_rate,
"secure_aggregation": ps_context().set_secure_aggregation,
"enable_ps_ssl": ps_context().set_enable_ssl
}

View File

@ -69,6 +69,7 @@ table ResponseFLJob {
}
table FLPlan {
server_mode:string;
fl_name:string;
iterations:int;
epochs:int;

View File

@ -26,10 +26,14 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
parser.add_argument("--local_server_num", type=int, default=-1)
@ -43,10 +47,14 @@ if __name__ == "__main__":
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
secure_aggregation = args.secure_aggregation
local_server_num = args.local_server_num
@ -70,10 +78,14 @@ if __name__ == "__main__":
cmd_server += " --scheduler_port=" + str(scheduler_port)
cmd_server += " --fl_server_port=" + str(fl_server_port + i)
cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
cmd_server += " --fl_name=" + fl_name
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --secure_aggregation=" + str(secure_aggregation)
cmd_server += " > server.log 2>&1 &"

View File

@ -15,6 +15,7 @@
import argparse
import time
import datetime
import random
import sys
import requests
@ -129,7 +130,15 @@ def build_get_model(iteration):
buf = builder_get_model.Output()
return buf
weight_name_to_idx = {
def datetime_to_timestamp(datetime_obj):
"""将本地(local) datetime 格式的时间 (含毫秒) 转为毫秒时间戳
:param datetime_obj: {datetime}2016-02-25 20:21:04.242000
:return: 13 位的毫秒时间戳 1456402864242
"""
local_timestamp = time.mktime(datetime_obj.timetuple()) * 1000.0 + datetime_obj.microsecond // 1000.0
return local_timestamp
weight_to_idx = {
"conv1.weight": 0,
"conv2.weight": 1,
"fc1.weight": 2,
@ -149,11 +158,12 @@ while True:
print("start url is ", url1)
x = requests.post(url1, data=build_start_fl_job(current_iteration))
rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
print("start fl job iteration:", current_iteration, ", id:", args.pid)
while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
x = requests.post(url1, data=build_start_fl_job(current_iteration))
rsp_fl_job = rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
print("iteration is", rsp_fl_job.Iteration())
current_iteration = rsp_fl_job.Iteration()
sys.stdout.flush()
url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
@ -170,22 +180,40 @@ while True:
print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
sys.stdout.flush()
repeat_time = 0
while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
time.sleep(0.1)
x = session.post(url3, data=build_get_model(current_iteration))
rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
repeat_time += 1
if repeat_time > 1000:
print("GetModel try timeout ", args.pid)
sys.exit(0)
for i in range(0, 1):
print(rsp_get_model.FeatureMap(i).WeightFullname())
origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
print("Before update model", args.pid, origin[0:10])
print("After get model", args.pid, after[0:10])
next_req_timestamp = 0
if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
sys.stdout.flush()
assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
current_iteration += 1
elif rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
repeat_time = 0
while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
time.sleep(0.2)
x = session.post(url3, data=build_get_model(current_iteration))
rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))
print("Last iteration is invalid, next request timestamp:", next_req_timestamp)
sys.stdout.flush()
break
repeat_time += 1
if repeat_time > 1000:
print("GetModel try timeout ", args.pid)
sys.exit(0)
else:
pass
if next_req_timestamp == 0:
for i in range(0, 1):
print(rsp_get_model.FeatureMap(i).WeightFullname())
origin = update_model_np_data[weight_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
print("Before update model", args.pid, origin[0:10])
print("After get model", args.pid, after[0:10])
sys.stdout.flush()
assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
else:
# Sleep to the next request timestamp
current_ts = datetime_to_timestamp(datetime.datetime.now())
duration = next_req_timestamp - current_ts
time.sleep(duration / 1000)

View File

@ -34,10 +34,14 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
args, _ = parser.parse_known_args()
@ -50,14 +54,18 @@ scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
secure_aggregation = args.secure_aggregation
ctx = {
"enable_ps": False,
"enable_fl": True,
"server_mode": server_mode,
"ms_role": ms_role,
"worker_num": worker_num,
@ -66,15 +74,19 @@ ctx = {
"scheduler_port": scheduler_port,
"fl_server_port": fl_server_port,
"start_fl_job_threshold": start_fl_job_threshold,
"start_fl_job_time_window": start_fl_job_time_window,
"update_model_ratio": update_model_ratio,
"update_model_time_window": update_model_time_window,
"fl_name": fl_name,
"fl_iteration_num": fl_iteration_num,
"client_epoch_num": client_epoch_num,
"client_batch_size": client_batch_size,
"client_learning_rate": client_learning_rate,
"secure_aggregation": secure_aggregation
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_ps_context(**ctx)
context.set_fl_context(**ctx)
if __name__ == "__main__":
epoch = 5