Add Server part 3

parent 0397f8ac7e
commit a6f9814552
@@ -47,6 +47,7 @@
 #include "ps/parameter_server.h"
 #include "ps/scheduler.h"
 #include "ps/worker.h"
+#include "ps/server/server.h"
 #endif

 namespace mindspore {
@@ -619,6 +620,47 @@ bool StartPSServerAction(const ResourcePtr &res) {
   return true;
 }

+bool StartServerAction(const ResourcePtr &res) {
+  FuncGraphPtr func_graph = res->func_graph();
+  const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
+  size_t worker_num = ps::PSContext::instance()->initial_worker_num();
+  size_t server_num = ps::PSContext::instance()->initial_server_num();
+  uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();
+
+  // The update_model threshold is a certain ratio of the start_fl_job threshold:
+  // update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_.
+  size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
+  float percent_for_update_model = 1;
+  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));
+
+  std::vector<ps::server::RoundConfig> rounds_config = {
+    {"startFLJob", false, 3000, false, start_fl_job_threshold},
+    {"updateModel", false, 3000, false, update_model_threshold},
+    {"getModel", false, 3000},
+    {"asyncUpdateModel"},
+    {"asyncGetModel"},
+    {"push", false, 3000, true, worker_num},
+    {"pull", false, 3000, true, worker_num},
+    {"getWeightsByKey", false, 3000, true, 1},
+    {"overwriteWeightsByKey", false, 3000, true, server_num},
+  };
+
+  size_t executor_threshold = 0;
+  if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
+    executor_threshold = update_model_threshold;
+    ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, func_graph,
+                                                 executor_threshold);
+  } else if (server_mode_ == ps::kServerModePS) {
+    executor_threshold = worker_num;
+    ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, func_graph, executor_threshold);
+  } else {
+    MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
+    return false;
+  }
+  ps::server::Server::GetInstance().Run();
+  return true;
+}
+
 bool StartPSSchedulerAction(const ResourcePtr &res) {
   ps::Scheduler::GetInstance().Run();
   return true;
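The update_model threshold above reduces to a single ceil over the startFLJob threshold; a minimal standalone sketch of that arithmetic (values made up; the commit itself hard-codes `percent_for_update_model` to 1):

```cpp
#include <cmath>
#include <cstddef>
#include <iostream>

int main() {
  // Hypothetical values: 100 clients must join startFLJob; updateModel then
  // opens once 50 of them (a 0.5 ratio) have uploaded.
  size_t start_fl_job_threshold = 100;
  float percent_for_update_model = 0.5f;
  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));
  std::cout << update_model_threshold << "\n";  // prints 50
  return 0;
}
```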
@@ -797,6 +839,14 @@ std::vector<ActionItem> VmPipeline() {
 }

 #if (ENABLE_CPU && !_WIN32)
+std::vector<ActionItem> ServerPipeline() {
+  auto actions = CommonPipeline();
+  actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
+  actions.emplace_back(std::make_pair("validate", ValidateAction));
+  actions.emplace_back(std::make_pair("server", StartServerAction));
+  return actions;
+}
+
 std::vector<ActionItem> PServerPipeline() {
   auto actions = CommonPipeline();
   actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
@@ -43,10 +43,14 @@ bool ExecuteAction(const ResourcePtr &res);
 bool StartPSWorkerAction(const ResourcePtr &res);
 bool StartPSServerAction(const ResourcePtr &res);
 bool StartPSSchedulerAction(const ResourcePtr &res);
+// This action is for federated learning only. In a later version, parameter server mode and federated learning will
+// use the same action.
+bool StartServerAction(const ResourcePtr &res);

 std::vector<ActionItem> GePipeline();
 std::vector<ActionItem> VmPipeline();
 std::vector<ActionItem> PServerPipeline();
+std::vector<ActionItem> ServerPipeline();
 std::vector<ActionItem> PSchedulerPipeline();
 abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                                          const abstract::AbstractBasePtrList &args_spec, bool clear = false);
@@ -326,7 +326,24 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
     .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
     .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
-    .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.");
+    .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.")
+    .def("set_server_mode", &PSContext::set_server_mode, "Set server mode.")
+    .def("server_mode", &PSContext::server_mode, "Get server mode.")
+    .def("set_ms_role", &PSContext::set_ms_role, "Set role for this process.")
+    .def("ms_role", &PSContext::ms_role, "Get role for this process.")
+    .def("set_worker_num", &PSContext::set_worker_num, "Set worker number.")
+    .def("set_server_num", &PSContext::set_server_num, "Set server number.")
+    .def("set_scheduler_ip", &PSContext::set_scheduler_ip, "Set scheduler ip.")
+    .def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
+    .def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
+    .def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
+    .def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.")
+    .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
+    .def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
+    .def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
+    .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
+    .def("set_secure_aggregation", &PSContext::set_secure_aggregation,
+         "Set federated learning client using secure aggregation.");

   (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
     .def(py::init())
@@ -55,6 +55,7 @@
 #include "ps/worker.h"
 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
 #include "ps/ps_cache/ps_cache_manager.h"
+#include "ps/server/server.h"
 #endif

 #if (ENABLE_GE || ENABLE_D)
@@ -529,6 +530,11 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
   std::string backend = MsContext::GetInstance()->backend_policy();

 #if (ENABLE_CPU && !_WIN32)
+  const std::string &server_mode = ps::PSContext::instance()->server_mode();
+  if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
+      ps::PSContext::instance()->is_server()) {
+    return ServerPipeline();
+  }
   if (ps::PSContext::instance()->is_server()) {
     resource->results()[kBackend] = compile::CreateBackend();
     return PServerPipeline();
@@ -50,10 +50,13 @@ if(NOT ENABLE_CPU OR WIN32)
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/fed_avg_kernel.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/update_model_kernel.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/get_model_kernel.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
@@ -67,6 +70,7 @@ if(NOT ENABLE_CPU OR WIN32)
     list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc")
     list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/server.cc")
 endif()

 list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
@@ -61,7 +61,12 @@ void PSContext::SetPSEnable(bool enabled) {
   }
 }

-bool PSContext::is_ps_mode() const { return ps_enabled_; }
+bool PSContext::is_ps_mode() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return true;
+  }
+  return ps_enabled_;
+}

 void PSContext::Reset() {
   ps_enabled_ = false;
@@ -77,6 +82,9 @@ void PSContext::Reset() {
 }

 std::string PSContext::ms_role() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_;
+  }
   if (is_worker_) {
     return kEnvRoleOfWorker;
   } else if (is_pserver_) {
@@ -88,11 +96,26 @@ std::string PSContext::ms_role() const {
   }
 }

-bool PSContext::is_worker() const { return is_worker_; }
+bool PSContext::is_worker() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kEnvRoleOfWorker;
+  }
+  return is_worker_;
+}

-bool PSContext::is_server() const { return is_pserver_; }
+bool PSContext::is_server() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kEnvRoleOfServer;
+  }
+  return is_pserver_;
+}

-bool PSContext::is_scheduler() const { return is_sched_; }
+bool PSContext::is_scheduler() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kEnvRoleOfScheduler;
+  }
+  return is_sched_;
+}

 uint32_t PSContext::initial_worker_num() { return worker_num_; }

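In federated learning and hybrid mode, the role predicates above switch from the legacy `is_worker_`/`is_pserver_`/`is_sched_` flags to the `role_` string configured through `set_ms_role`. A condensed standalone mock of the dual dispatch (not the real PSContext; the constants are copied from ps_context.h):

```cpp
#include <iostream>
#include <string>

constexpr char kServerModeFL[] = "FEDERATED_LEARNING";
constexpr char kEnvRoleOfServer[] = "MS_SERVER";

// Mock with just enough state to show the two code paths of is_server().
struct MockContext {
  std::string server_mode_;
  std::string role_;
  bool is_pserver_ = false;

  bool is_server() const {
    if (server_mode_ == kServerModeFL) {  // FL/hybrid: trust the configured role string.
      return role_ == kEnvRoleOfServer;
    }
    return is_pserver_;  // Legacy PS mode: trust the MS_ROLE-derived flag.
  }
};

int main() {
  MockContext ctx{kServerModeFL, kEnvRoleOfServer};
  std::cout << std::boolalpha << ctx.is_server() << "\n";  // true
  return 0;
}
```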
@@ -150,6 +173,94 @@ void PSContext::set_rank_id(int rank_id) const {
 #endif
 }

+void PSContext::set_server_mode(const std::string &server_mode) {
+  if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
+    MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
+                      << " or " << kServerModeHybrid;
+    return;
+  }
+  server_mode_ = server_mode;
+}
+
+const std::string &PSContext::server_mode() const { return server_mode_; }
+
+void PSContext::set_ms_role(const std::string &role) {
+  if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
+    MS_LOG(EXCEPTION) << "Only federated learning mode supports setting the role through the ps context.";
+    return;
+  }
+  if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
+    MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
+    return;
+  }
+  role_ = role;
+}
+
+void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
+uint32_t PSContext::worker_num() const { return worker_num_; }
+
+void PSContext::set_server_num(uint32_t server_num) {
+  if (server_num == 0) {
+    MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
+    return;
+  }
+  server_num_ = server_num;
+}
+uint32_t PSContext::server_num() const { return server_num_; }
+
+void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
+
+std::string PSContext::scheduler_ip() const { return scheduler_host_; }
+
+void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
+
+uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
+
+void PSContext::GenerateResetterRound() {
+  uint32_t binary_server_context = 0;
+  bool is_parameter_server_mode = false;
+  bool is_federated_learning_mode = false;
+  bool is_mixed_training_mode = false;
+
+  if (server_mode_ == kServerModePS) {
+    is_parameter_server_mode = true;
+  } else if (server_mode_ == kServerModeFL) {
+    is_federated_learning_mode = true;
+  } else if (server_mode_ == kServerModeHybrid) {
+    is_mixed_training_mode = true;
+  } else {
+    MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
+                      << " or " << kServerModeHybrid;
+    return;
+  }
+
+  binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
+                          (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4);
+  if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
+    resetter_round_ = ResetterRound::kNoNeedToReset;
+  } else {
+    resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
+  }
+  MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
+  return;
+}
+
+ResetterRound PSContext::resetter_round() const { return resetter_round_; }
+
+void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }
+
+uint16_t PSContext::fl_server_port() const { return fl_server_port_; }
+
+void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }
+
+bool PSContext::fl_client_enable() { return fl_client_enable_; }
+
+void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) {
+  start_fl_job_threshold_ = start_fl_job_threshold;
+}
+
+size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
+
 void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }

 const std::string &PSContext::fl_name() const { return fl_name_; }
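The bitmask assembled by `GenerateResetterRound` can be checked by hand; a standalone sketch of the encoding with hypothetical flags (the shifts mirror the hunk above, and the lookup table lives in the ps_context.h hunk below):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical hybrid-training server with secure aggregation off and
  // worker weight overwrite on: bits 2 and 4 set, i.e. 0b10100.
  bool is_parameter_server_mode = false;    // bit 0
  bool is_federated_learning_mode = false;  // bit 1
  bool is_mixed_training_mode = true;       // bit 2
  bool secure_aggregation = false;          // bit 3
  bool worker_overwrite_weights = true;     // bit 4

  uint32_t binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
                                   (is_mixed_training_mode << 2) | (secure_aggregation << 3) |
                                   (worker_overwrite_weights << 4);
  // Prints 20 == 0b10100, which kServerContextToResetRoundMap maps to kWorkerOverwriteWeights.
  std::cout << binary_server_context << "\n";
  return 0;
}
```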
@@ -165,5 +276,15 @@ uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
 void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }

 uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
+
+void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) {
+  worker_overwrite_weights_ = worker_overwrite_weights;
+}
+
+uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; }
+
+void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }
+
+bool PSContext::secure_aggregation() const { return secure_aggregation_; }
 } // namespace ps
 } // namespace mindspore
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_
 #define MINDSPORE_CCSRC_PS_CONTEXT_H_

+#include <map>
 #include <string>
 #include <memory>
 #include "ps/constants.h"
@@ -24,12 +25,32 @@

 namespace mindspore {
 namespace ps {
+constexpr char kServerModePS[] = "PARAMETER_SERVER";
+constexpr char kServerModeFL[] = "FEDERATED_LEARNING";
+constexpr char kServerModeHybrid[] = "HYBRID_TRAINING";
 constexpr char kEnvRole[] = "MS_ROLE";
 constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
+constexpr char kEnvRoleOfServer[] = "MS_SERVER";
 constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
 constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
 constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
+
+// Use binary data to represent the federated learning server's context so that we can judge which round resets the
+// iteration. From right to left, each bit stands for:
+// 0: Server is in parameter server mode.
+// 1: Server is in federated learning mode.
+// 2: Server is in mixed training mode.
+// 3: Server enables secure aggregation.
+// 4: Server needs workers to overwrite weights.
+// For example, 0b01010 means the server is in federated learning mode and secure aggregation is enabled.
+enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerOverwriteWeights };
+const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {
+  {0b00010, ResetterRound::kUpdateModel},
+  {0b01010, ResetterRound::kReconstructSeccrets},
+  {0b11100, ResetterRound::kWorkerOverwriteWeights},
+  {0b10100, ResetterRound::kWorkerOverwriteWeights},
+  {0b00100, ResetterRound::kUpdateModel}};

 class PSContext {
  public:
  ~PSContext() = default;
@@ -60,19 +81,64 @@ class PSContext {
   void set_cache_enable(bool cache_enable) const;
   void set_rank_id(int rank_id) const;

-  // Setter and getter for federated learning.
+  // In the new server framework, the process role, worker number, server number, scheduler ip and scheduler port
+  // should be set through ps_context.
+  void set_server_mode(const std::string &server_mode);
+  const std::string &server_mode() const;
+
+  void set_ms_role(const std::string &role);
+
+  void set_worker_num(uint32_t worker_num);
+  uint32_t worker_num() const;
+
+  void set_server_num(uint32_t server_num);
+  uint32_t server_num() const;
+
+  void set_scheduler_ip(const std::string &sched_ip);
+  std::string scheduler_ip() const;
+
+  void set_scheduler_port(uint16_t sched_port);
+  uint16_t scheduler_port() const;
+
+  // Methods for federated learning.
+
+  // Generate which round should reset the iteration.
+  void GenerateResetterRound();
+  ResetterRound resetter_round() const;
+
+  void set_fl_server_port(uint16_t fl_server_port);
+  uint16_t fl_server_port() const;
+
+  // Set true if this process is a federated learning worker in the cross-silo scenario.
+  void set_fl_client_enable(bool enabled);
+  bool fl_client_enable();
+
+  void set_start_fl_job_threshold(size_t start_fl_job_threshold);
+  size_t start_fl_job_threshold() const;
+
   void set_fl_name(const std::string &fl_name);
   const std::string &fl_name() const;

+  // Set the iteration number of the federated learning.
   void set_fl_iteration_num(uint64_t fl_iteration_num);
   uint64_t fl_iteration_num() const;

+  // Set the training epoch number of the client.
   void set_client_epoch_num(uint64_t client_epoch_num);
   uint64_t client_epoch_num() const;

+  // Set the data batch size of the client.
   void set_client_batch_size(uint64_t client_batch_size);
   uint64_t client_batch_size() const;
+
+  // Set true if the worker will overwrite weights on the server. Used in hybrid training.
+  void set_worker_overwrite_weights(uint64_t worker_overwrite_weights);
+  uint64_t worker_overwrite_weights() const;
+
+  // Set true if secure aggregation is used for federated learning.
+  void set_secure_aggregation(bool secure_aggregation);
+  bool secure_aggregation() const;

  private:
   PSContext()
     : ps_enabled_(false),
@@ -94,11 +160,22 @@ class PSContext {
   std::string scheduler_host_;
   uint16_t scheduler_port_;
+
+  std::string role_;

   // Members for federated learning.
+  std::string server_mode_;
+  ResetterRound resetter_round_;
+  uint16_t fl_server_port_;
+  bool fl_client_enable_;
   std::string fl_name_;
+  size_t start_fl_job_threshold_;
   uint64_t fl_iteration_num_;
   uint64_t client_epoch_num_;
   uint64_t client_batch_size_;
+  bool worker_overwrite_weights_;
+
+  // Federated learning security.
+  bool secure_aggregation_;
 };
 } // namespace ps
 } // namespace mindspore
@@ -20,6 +20,9 @@ namespace mindspore {
 namespace ps {
 void Scheduler::Run() {
   MS_LOG(INFO) << "Start scheduler.";
+  core::ClusterMetadata::instance()->Init(
+    PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
+    PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
   scheduler_node_.Start();
   scheduler_node_.Finish();
   scheduler_node_.Stop();
@@ -44,6 +44,14 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
 enum CommType { HTTP = 0, TCP };
 enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };

+struct RoundConfig {
+  std::string name;
+  bool check_timeout = false;
+  size_t time_window = 3000;
+  bool check_count = false;
+  size_t threshold_count = 0;
+};
+
 using mindspore::kernel::Address;
 using mindspore::kernel::AddressPtr;
 using mindspore::kernel::CPUKernel;
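Because every `RoundConfig` field after `name` carries a default member initializer, the braced entries in `StartServerAction` can stop early. A sketch of the fallback behavior (the struct is copied from the hunk above; the driver values are hypothetical):

```cpp
#include <cstddef>
#include <iostream>
#include <string>

// Copied from the common.h hunk above.
struct RoundConfig {
  std::string name;
  bool check_timeout = false;
  size_t time_window = 3000;
  bool check_count = false;
  size_t threshold_count = 0;
};

int main() {
  RoundConfig async_round{"asyncUpdateModel"};     // everything but the name is defaulted
  RoundConfig push{"push", false, 3000, true, 8};  // a count-checked round with threshold 8
  std::cout << async_round.name << " window=" << async_round.time_window
            << " check_count=" << async_round.check_count << "\n"
            << push.name << " threshold=" << push.threshold_count << "\n";
  return 0;
}
```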
@@ -73,6 +81,7 @@ using ReuseKernelNodeInfo = std::map<std::string, size_t>;
 using UploadData = std::map<std::string, Address>;

 constexpr auto kWeight = "weight";
+constexpr auto kNewWeight = "new_weight";
 constexpr auto kAccumulation = "accum";
 constexpr auto kLearningRate = "lr";
 constexpr auto kGradient = "grad";
@@ -87,6 +96,8 @@ constexpr auto kAdamBeta1 = "beta1";
 constexpr auto kAdamBeta2 = "beta2";
 constexpr auto kAdamEps = "eps";
 constexpr auto kFtrlLinear = "linear";
+constexpr auto kDataSize = "data_size";
+constexpr auto kNewDataSize = "new_data_size";

 // OptimParamNameToIndex represents each input/workspace/output parameter's offset when an optimizer kernel is
 // launched.
@@ -137,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32;
 constexpr int kHttpSuccess = 200;
 constexpr auto kPBProtocol = "PB";
 constexpr auto kFBSProtocol = "FBS";
+constexpr auto kFedAvg = "FedAvg";
 constexpr auto kAggregationKernelType = "Aggregation";
 constexpr auto kOptimizerKernelType = "Optimizer";
 constexpr auto kCtxFuncGraph = "FuncGraph";
@@ -145,6 +157,8 @@ constexpr auto kCtxDeviceMetas = "device_metas";
 constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
 constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
 constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
+constexpr auto kCtxUpdateModelThld = "update_model_threshold";
+constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";

 // This macro returns the current timestamp in milliseconds.
 #define CURRENT_TIME_MILLI \
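The body of `CURRENT_TIME_MILLI` is cut off in this hunk. A plausible definition, assuming it wraps std::chrono so that call sites can invoke `.count()` as get_model_kernel.cc does further down; this is a guess at the shape, not the committed macro:

```cpp
#include <chrono>

// Assumed shape only: yields a std::chrono::milliseconds value, so call
// sites can invoke .count() on the result.
#define CURRENT_TIME_MILLI \
  std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
```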
@@ -112,19 +112,19 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
     std::unique_lock<std::mutex> lock(mutex_[name]);
     return global_current_count_[name].size() == global_threshold_count_[name];
   } else {
-    CountReachThresholdRequest count_reach_threashold_req;
-    count_reach_threashold_req.set_name(name);
+    CountReachThresholdRequest count_reach_threshold_req;
+    count_reach_threshold_req.set_name(name);

     std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
-    if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_,
+    if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
                                       core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
       MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
       return false;
     }

-    CountReachThresholdResponse count_reach_threashold_rsp;
-    count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
-    return count_reach_threashold_rsp.is_enough();
+    CountReachThresholdResponse count_reach_threshold_rsp;
+    count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
+    return count_reach_threshold_rsp.is_enough();
   }
 }

@@ -200,9 +200,9 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
     return;
   }

-  CountReachThresholdRequest count_reach_threashold_req;
-  count_reach_threashold_req.ParseFromArray(message->data(), message->len());
-  const std::string &name = count_reach_threashold_req.name();
+  CountReachThresholdRequest count_reach_threshold_req;
+  count_reach_threshold_req.ParseFromArray(message->data(), message->len());
+  const std::string &name = count_reach_threshold_req.name();

   std::unique_lock<std::mutex> lock(mutex_[name]);
   if (global_threshold_count_.count(name) == 0) {
@@ -210,10 +210,10 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
     return;
   }

-  CountReachThresholdResponse count_reach_threashold_rsp;
-  count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
-  communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(),
-                              count_reach_threashold_rsp.SerializeAsString().size(), message);
+  CountReachThresholdResponse count_reach_threshold_rsp;
+  count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
+  communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(),
+                              count_reach_threshold_rsp.SerializeAsString().size(), message);
   return;
 }

@@ -193,7 +193,29 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<co

 bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
   std::unique_lock<std::mutex> lock(mutex_[name]);
-  metadata_[name] = meta;
+  if (meta.has_device_meta()) {
+    auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
+    auto &fl_id = meta.device_meta().fl_id();
+    auto &device_meta = meta.device_meta();
+    fl_id_to_meta_map[fl_id] = device_meta;
+  } else if (meta.has_fl_id()) {
+    auto client_list = metadata_[name].mutable_client_list();
+    auto &fl_id = meta.fl_id().fl_id();
+    // Check whether the new item already exists.
+    bool add_flag = true;
+    for (int i = 0; i < client_list->fl_id_size(); i++) {
+      if (fl_id == client_list->fl_id(i)) {
+        add_flag = false;
+        break;
+      }
+    }
+    if (add_flag) {
+      client_list->add_fl_id(fl_id);
+    }
+  } else if (meta.has_update_model_threshold()) {
+    auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
+    *update_model_threshold = meta.update_model_threshold();
+  }
   return true;
 }
 } // namespace server
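The `has_fl_id()` branch gives the repeated field set semantics with a manual scan. A self-contained analog of the same dedup logic over a plain vector (the protobuf types above are not reproduced here):

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Analog of the client-list branch: append the id only if it is absent.
  std::vector<std::string> client_list = {"fl_1", "fl_2"};
  std::string fl_id = "fl_2";
  if (std::find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) {
    client_list.push_back(fl_id);
  }
  std::cout << client_list.size() << "\n";  // still 2: the duplicate was skipped
  return 0;
}
```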
@@ -23,7 +23,7 @@
 namespace mindspore {
 namespace ps {
 namespace server {
-void Executor::Init(const FuncGraphPtr &func_graph, size_t aggregation_count) {
+void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
   MS_EXCEPTION_IF_NULL(func_graph);
   if (aggregation_count == 0) {
     MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
@@ -43,7 +43,7 @@ class Executor {
   // be used for aggregators.
   // As noted in the header file parameter_aggregator.h, we create aggregators from trainable parameters, which are
   // the optimizer cnode's inputs. So we need to initialize the server executor using func_graph.
-  void Init(const FuncGraphPtr &func_graph, size_t aggregation_count);
+  void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);

   // Called in parameter server training mode to do the Push operation.
   // For the same trainable parameter, the HandlePush method must be called aggregation_count_ times before it's considered
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/kernel/fed_avg_kernel.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+REG_AGGREGATION_KERNEL_TWO(FedAvg,
+                           ParamsInfo()
+                             .AddInputNameType(kWeight, kNumberTypeFloat32)
+                             .AddInputNameType(kDataSize, kNumberTypeUInt64)
+                             .AddInputNameType(kNewWeight, kNumberTypeFloat32)
+                             .AddInputNameType(kNewDataSize, kNumberTypeUInt64),
+                           FedAvgKernel, float, size_t)
+} // namespace kernel
+} // namespace server
+} // namespace ps
+} // namespace mindspore
@@ -0,0 +1,179 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include <functional>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "ps/server/common.h"
+#include "ps/server/collective_ops_impl.h"
+#include "ps/server/distributed_count_service.h"
+#include "ps/server/local_meta_store.h"
+#include "ps/server/kernel/aggregation_kernel.h"
+#include "ps/server/kernel/aggregation_kernel_factory.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+// The implementation of federated averaging. We compute a weighted average of the weights. The weights uploaded by
+// the FL clients are already multiplied by their data sizes, so only summation and division are done in this kernel.
+
+// Note that this kernel is the distributed version of federated averaging, which means each server node in the
+// cluster is involved in the aggregation process. So the DistributedCountService and CollectiveOpsImpl are called.
+template <typename T, typename S>
+class FedAvgKernel : public AggregationKernel {
+ public:
+  FedAvgKernel() : participated_(false) {}
+  ~FedAvgKernel() override = default;
+  void InitKernel(const CNodePtr &kernel_node) override {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
+        kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
+      MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
+      return;
+    }
+    cnode_weight_idx_ = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
+    std::vector<size_t> weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_weight_idx_);
+    size_t weight_size =
+      std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(T), std::multiplies<size_t>());
+    size_t new_weight_size = weight_size;
+
+    input_size_list_.push_back(weight_size);
+    input_size_list_.push_back(sizeof(size_t));
+    input_size_list_.push_back(new_weight_size);
+    input_size_list_.push_back(sizeof(size_t));
+
+    auto weight_node =
+      AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
+    MS_EXCEPTION_IF_NULL(weight_node);
+    name_ = cnode_name + "." + weight_node->fullname_with_scope();
+    MS_LOG(INFO) << "Register counter for " << name_;
+    auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
+      std::unique_lock<std::mutex> lock(weight_mutex_);
+      if (!participated_) {
+        ClearWeightAndDataSize();
+      }
+    };
+    auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
+      T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
+      size_t weight_size = weight_addr_->size;
+      S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
+      if (!CollectiveOpsImpl::GetInstance().AllReduce<T>(weight_addr, weight_addr, weight_size / sizeof(T))) {
+        MS_LOG(ERROR) << "Federated average allreduce failed.";
+        return;
+      }
+      if (!CollectiveOpsImpl::GetInstance().AllReduce<S>(data_size_addr, data_size_addr, 1)) {
+        MS_LOG(ERROR) << "Federated average allreduce failed.";
+        return;
+      }
+      LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size_addr[0]);
+      for (size_t i = 0; i < weight_size / sizeof(T); i++) {
+        weight_addr[i] /= data_size_addr[0];
+      }
+      done_ = true;
+      DistributedCountService::GetInstance().ResetCounter(name_);
+      return;
+    };
+    DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
+    return;
+  }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override {
+    std::unique_lock<std::mutex> lock(weight_mutex_);
+    // The weight and new_weight values should already be multiplied by the clients, so we don't need to do the
+    // multiplication again.
+    T *weight_addr = reinterpret_cast<T *>(inputs[0]->addr);
+    S *data_size_addr = reinterpret_cast<S *>(inputs[1]->addr);
+    T *new_weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
+    S *new_data_size_addr = reinterpret_cast<S *>(inputs[3]->addr);
+    if (accum_count_ == 0) {
+      ClearWeightAndDataSize();
+    }
+
+    MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for "
+                  << name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is "
+                  << data_size_addr[0];
+    for (size_t i = 0; i < inputs[2]->size / sizeof(T); i++) {
+      weight_addr[i] += new_weight_addr[i];
+    }
+    data_size_addr[0] += new_data_size_addr[0];
+    lock.unlock();
+
+    accum_count_++;
+    participated_ = true;
+    DistributedCountService::GetInstance().Count(
+      name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
+    return true;
+  }
+  void Reset() {
+    accum_count_ = 0;
+    done_ = false;
+    participated_ = false;
+    DistributedCountService::GetInstance().ResetCounter(name_);
+    return;
+  }
+  bool IsAggregationDone() { return done_; }
+
+ private:
+  void GenerateReuseKernelNodeInfo() override {
+    // Only the trainable parameter is reused for federated averaging.
+    reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
+    return;
+  }
+
+  // In some cases, the Launch method is not called, and the weights involved in AllReduce should be set to 0.
+  void ClearWeightAndDataSize() {
+    int ret = memset_s(weight_addr_->addr, weight_addr_->size, 0x00, weight_addr_->size);
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memset_s error, errno(" << ret << ")";
+      return;
+    }
+    ret = memset_s(data_size_addr_->addr, data_size_addr_->size, 0x00, data_size_addr_->size);
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memset_s error, errno(" << ret << ")";
+      return;
+    }
+    return;
+  }
+
+  // The trainable parameter index of the kernel node, parsed from the frontend func_graph.
+  size_t cnode_weight_idx_;
+
+  // The address pointers of the inputs.
+  AddressPtr weight_addr_;
+  AddressPtr data_size_addr_;
+  AddressPtr new_weight_addr_;
+  AddressPtr new_data_size_addr_;
+
+  // Whether the kernel's Launch method has been called.
+  bool participated_;
+
+  // The kernel can be called concurrently, so we need a lock to ensure thread safety.
+  std::mutex weight_mutex_;
+};
+} // namespace kernel
+} // namespace server
+} // namespace ps
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
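The comment at the top of FedAvgKernel describes classic federated averaging, w = sum_i(n_i * w_i) / sum_i(n_i), with the n_i * w_i products computed on the clients so the server only sums and divides. A single-process sketch of that arithmetic (no AllReduce, made-up numbers):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Two clients; uploads arrive already multiplied by each client's data size,
  // exactly as the kernel expects.
  std::vector<float> upload_a = {2.0f, 4.0f};  // weights {1.0, 2.0} scaled by n_a = 2
  std::vector<float> upload_b = {3.0f, 3.0f};  // weights {1.0, 1.0} scaled by n_b = 3
  size_t total_data_size = 2 + 3;

  std::vector<float> aggregated(upload_a.size());
  for (size_t i = 0; i < aggregated.size(); ++i) {
    aggregated[i] = (upload_a[i] + upload_b[i]) / total_data_size;  // sum, then divide once
  }
  std::cout << aggregated[0] << " " << aggregated[1] << "\n";  // prints: 1 1.4
  return 0;
}
```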
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/kernel/round/get_model_kernel.h"
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "ps/server/model_store.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+void GetModelKernel::InitKernel(size_t) {
+  if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
+    iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
+  }
+
+  executor_ = &Executor::GetInstance();
+  MS_EXCEPTION_IF_NULL(executor_);
+  if (!executor_->initialized()) {
+    MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
+    return;
+  }
+}
+
+bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                            const std::vector<AddressPtr> &outputs) {
+  MS_LOG(INFO) << "Launching GetModelKernel kernel.";
+  void *req_data = inputs[0]->addr;
+  std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
+  if (fbb == nullptr || req_data == nullptr) {
+    MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
+    return false;
+  }
+
+  const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
+  GetModel(get_model_req, fbb);
+  GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
+  return true;
+}
+
+bool GetModelKernel::Reset() {
+  MS_LOG(INFO) << "Get model kernel reset!";
+  StopTimer();
+  return true;
+}
+
+void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) {
+  std::map<std::string, AddressPtr> feature_maps;
+  size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
+  size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
+  const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
+  size_t latest_iter_num = iter_to_model.rbegin()->first;
+
+  if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
+    std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
+    BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
+                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    MS_LOG(WARNING) << reason;
+    return;
+  }
+
+  if (iter_to_model.count(get_model_iter) == 0) {
+    std::string reason = "The iteration of the GetModel request " + std::to_string(get_model_iter) +
+                         " is invalid. Current iteration is " + std::to_string(current_iter);
+    BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps,
+                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+    MS_LOG(ERROR) << reason;
+    return;
+  }
+
+  feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
+  BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED,
+                   "Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter,
+                   feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
+  return;
+}
+
+void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
+                                      const std::string &reason, const size_t iter,
+                                      const std::map<std::string, AddressPtr> &feature_maps,
+                                      const std::string &timestamp) {
+  auto fbs_reason = fbb->CreateString(reason);
+  auto fbs_timestamp = fbb->CreateString(timestamp);
+  std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
+  for (const auto &feature_map : feature_maps) {
+    auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
+    auto fbs_weight_data =
+      fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
+    auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
+    fbs_feature_maps.push_back(fbs_feature_map);
+  }
+  auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
+
+  schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
+  rsp_get_model_builder.add_retcode(retcode);
+  rsp_get_model_builder.add_reason(fbs_reason);
+  rsp_get_model_builder.add_iteration(static_cast<int>(iter));
+  rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
+  rsp_get_model_builder.add_timestamp(fbs_timestamp);
+  auto rsp_get_model = rsp_get_model_builder.Finish();
+  fbb->Finish(rsp_get_model);
+  return;
+}
+
+REG_ROUND_KERNEL(getModel, GetModelKernel)
+} // namespace kernel
+} // namespace server
+} // namespace ps
+} // namespace mindspore
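The readiness test in `GetModel` is easy to misread; a self-contained sketch enumerating its cases (the predicate is copied from the kernel, the scenario values are hypothetical):

```cpp
#include <cstddef>
#include <iostream>

// True when the model for get_model_iter is not yet stored.
bool NotReady(size_t current_iter, size_t get_model_iter, size_t latest_iter_num) {
  return (current_iter == get_model_iter && latest_iter_num != current_iter) ||
         current_iter == get_model_iter - 1;
}

int main() {
  std::cout << std::boolalpha
            << NotReady(5, 5, 4) << "\n"   // true: iteration 5 is still aggregating
            << NotReady(5, 6, 4) << "\n"   // true: the client asks one iteration ahead
            << NotReady(5, 5, 5) << "\n";  // false: the model for iteration 5 is stored
  return 0;
}
```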
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "ps/server/common.h"
+#include "ps/server/executor.h"
+#include "ps/server/kernel/round/round_kernel.h"
+#include "ps/server/kernel/round/round_kernel_factory.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+class GetModelKernel : public RoundKernel {
+ public:
+  GetModelKernel() = default;
+  ~GetModelKernel() override = default;
+
+  void InitKernel(size_t) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs);
+  bool Reset() override;
+
+ private:
+  void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb);
+  void BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
+                        const std::string &reason, const size_t iter,
+                        const std::map<std::string, AddressPtr> &feature_maps, const std::string &timestamp);
+
+  // The executor for fetching the model requested by a getModel request.
+  Executor *executor_;
+
+  // The time window of one iteration.
+  size_t iteration_time_window_;
+};
+} // namespace kernel
+} // namespace server
+} // namespace ps
+} // namespace mindspore
+#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
@@ -49,7 +49,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
     return false;
   }
   void *req_data = inputs[0]->addr;
-  const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>();
+  std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
   if (fbb == nullptr || req_data == nullptr) {
     MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
     return false;
@@ -0,0 +1,203 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/kernel/round/update_model_kernel.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void UpdateModelKernel::InitKernel(size_t threshold_count) {
  if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
    iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
  }

  executor_ = &Executor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor_);
  if (!executor_->initialized()) {
    MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
    return;
  }

  PBMetadata client_list;
  DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
  LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
}

bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                               const std::vector<AddressPtr> &outputs) {
  if (inputs.size() != 1 || outputs.size() != 1) {
    MS_LOG(ERROR) << "inputs or outputs size is invalid.";
    return false;
  }
  void *req_data = inputs[0]->addr;
  std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
  if (fbb == nullptr || req_data == nullptr) {
    MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
    return false;
  }

  MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
  if (!ReachThresholdForUpdateModel(fbb)) {
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return false;
  }

  const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
  if (!UpdateModel(update_model_req, fbb)) {
    MS_LOG(ERROR) << "Updating model failed.";
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return false;
  }

  if (!CountForUpdateModel(fbb, update_model_req)) {
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return false;
  }
  GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
  return true;
}

bool UpdateModelKernel::Reset() {
  MS_LOG(INFO) << "Update model kernel reset!";
  StopTimer();
  DistributedCountService::GetInstance().ResetCounter(name_);
  executor_->ResetAggregationStatus();
  DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList);
  size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value<size_t>(kCtxFedAvgTotalDataSize);
  total_data_size = 0;
  return true;
}

void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
  if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) {
    while (!executor_->IsAllWeightAggregationDone()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(5));
    }

    size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
    MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
                 << total_data_size;

    FinishIterCb();
  }
}

bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
  if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
    std::string reason = "Current amount for updateModel is enough.";
    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return false;
  }
  return true;
}

bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
                                    const std::shared_ptr<FBBuilder> &fbb) {
  size_t iteration = static_cast<size_t>(update_model_req->iteration());
  if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
    std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
                         ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return false;
  }

  PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
  FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
  std::string update_model_fl_id = update_model_req->fl_id()->str();
  if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
    std::string reason = "devices_meta for " + update_model_fl_id + " is not set.";
    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return false;
  }

  size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
  auto feature_map = ParseFeatureMap(update_model_req);
  for (auto weight : feature_map) {
    weight.second[kNewDataSize].addr = &data_size;
    weight.second[kNewDataSize].size = sizeof(size_t);
    executor_->HandleModelUpdate(weight.first, weight.second);
  }

  FLId fl_id;
  fl_id.set_fl_id(update_model_fl_id);
  PBMetadata comm_value;
  *comm_value.mutable_fl_id() = fl_id;
  DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value);

  BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready",
                      std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
  return true;
}

std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
  const schema::RequestUpdateModel *update_model_req) {
  RETURN_IF_NULL(update_model_req, {});
  std::map<std::string, UploadData> feature_map;
  auto fbs_feature_map = update_model_req->feature_map();
  for (size_t i = 0; i < fbs_feature_map->size(); i++) {
    std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
    float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
    size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
    UploadData upload_data;
    upload_data[kNewWeight].addr = weight_data;
    upload_data[kNewWeight].size = weight_size;
    feature_map[weight_full_name] = upload_data;
  }
  return feature_map;
}

bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
                                            const schema::RequestUpdateModel *update_model_req) {
  if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
    std::string reason = "UpdateModel counting failed.";
    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return false;
  }
  return true;
}

void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                                            const std::string &reason, const std::string &next_req_time) {
  auto fbs_reason = fbb->CreateString(reason);
  auto fbs_next_req_time = fbb->CreateString(next_req_time);

  schema::ResponseUpdateModelBuilder rsp_update_model_builder(*(fbb.get()));
  rsp_update_model_builder.add_retcode(retcode);
  rsp_update_model_builder.add_reason(fbs_reason);
  rsp_update_model_builder.add_next_req_time(fbs_next_req_time);
  auto rsp_update_model = rsp_update_model_builder.Finish();
  fbb->Finish(rsp_update_model);
  return;
}

REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
}  // namespace kernel
}  // namespace server
}  // namespace ps
}  // namespace mindspore
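UpdateModelKernel delegates admission control to DistributedCountService: each client's updateModel request is counted once, requests beyond the threshold are rejected with ResponseCode_OutOfTime, and the last count fires OnLastCountEvent. A minimal single-process sketch of that count-to-threshold pattern; CountService and its methods are hypothetical, not the MindSpore API:

# Illustrative count-to-threshold pattern, mirroring CountForUpdateModel,
# ReachThresholdForUpdateModel and OnLastCountEvent above.
class CountService:
    def __init__(self, threshold, on_last_count):
        self.threshold = threshold
        self.counted_ids = set()
        self.on_last_count = on_last_count

    def count_reach_threshold(self):
        return len(self.counted_ids) >= self.threshold

    def count(self, fl_id):
        if self.count_reach_threshold():
            return False  # late client: rejected, like ResponseCode_OutOfTime
        self.counted_ids.add(fl_id)
        if self.count_reach_threshold():
            self.on_last_count()  # the last counted client finishes the iteration
        return True

svc = CountService(threshold=2, on_last_count=lambda: print("iteration finished"))
assert svc.count("fl_lenet_0") and svc.count("fl_lenet_1")
assert not svc.count("fl_lenet_2")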
@@ -0,0 +1,64 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "ps/server/executor.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class UpdateModelKernel : public RoundKernel {
 public:
  UpdateModelKernel() = default;
  ~UpdateModelKernel() override = default;

  void InitKernel(size_t threshold_count) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs);
  bool Reset() override;

  // In some cases, the last updateModel message means this server iteration is finished.
  void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;

 private:
  bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
  bool UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb);
  std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
  bool CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestUpdateModel *update_model_req);
  void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                           const std::string &reason, const std::string &next_req_time);

  // The executor is used to update the model for the updateModel request.
  Executor *executor_;

  // The time window of one iteration.
  size_t iteration_time_window_;
};
}  // namespace kernel
}  // namespace server
}  // namespace ps
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
@@ -23,7 +23,7 @@
 namespace mindspore {
 namespace ps {
 namespace server {
-void ModelStore::Init(uint32_t max_count) {
+void ModelStore::Initialize(uint32_t max_count) {
   if (!Executor::GetInstance().initialized()) {
     MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage.";
     return;
@@ -40,7 +40,7 @@ class ModelStore {
   }

   // Initialize ModelStore with max count of models need to be stored.
-  void Init(uint32_t max_count = 3);
+  void Initialize(uint32_t max_count = 3);

   // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
   // max_model_count_, the earliest model will be replaced.
@@ -302,7 +302,7 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
 }

 std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) {
-  std::vector<std::string> aggregation_algorithm = {};
+  std::vector<std::string> aggregation_algorithm = {kFedAvg};
   MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
   return aggregation_algorithm;
 }
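With kFedAvg selected unconditionally, each weight is aggregated as a data-size-weighted average; the per-client data_size is what UpdateModelKernel attaches as kNewDataSize. A minimal numpy sketch of that rule, with fed_avg as a hypothetical helper:

import numpy as np

# Illustrative FedAvg: average client updates weighted by their data sizes.
def fed_avg(updates, data_sizes):
    total = float(sum(data_sizes))
    return sum(w * (n / total) for w, n in zip(updates, data_sizes))

w0, w1 = np.array([1.0, 2.0]), np.array([3.0, 6.0])
print(fed_avg([w0, w1], data_sizes=[32, 96]))  # -> [2.5 5. ]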
@@ -0,0 +1,251 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/server/server.h"
#include <memory>
#include <string>
#include <csignal>
#include "ps/server/round.h"
#include "ps/server/model_store.h"
#include "ps/server/iteration.h"
#include "ps/server/collective_ops_impl.h"
#include "ps/server/distributed_metadata_store.h"
#include "ps/server/distributed_count_service.h"
#include "ps/server/kernel/round/round_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace server {
static std::vector<std::shared_ptr<core::CommunicatorBase>> global_worker_server_comms = {};
// This function handles the exit of the server process when an interrupt signal is captured.
void SignalHandler(int signal) {
  MS_LOG(INFO) << "Interrupt signal captured: " << signal;
  std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(),
                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
  return;
}

void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
                        const FuncGraphPtr &func_graph, size_t executor_threshold) {
  MS_EXCEPTION_IF_NULL(func_graph);
  func_graph_ = func_graph;

  if (rounds_config.empty()) {
    MS_LOG(EXCEPTION) << "Rounds are empty.";
    return;
  }
  rounds_config_ = rounds_config;

  use_tcp_ = use_tcp;
  use_http_ = use_http;
  http_port_ = http_port;
  executor_threshold_ = executor_threshold;
  return;
}

// Each step of the server pipeline may depend on other steps. In the notation below, "A---->B" means A relies on B:

// InitServerContext must be the first step to set contexts for later steps.

// Server running relies on URL or message type registration:
// StartCommunicator---->InitIteration

// Metadata registration relies on the hash ring of servers, which relies on network building completion:
// RegisterRoundKernel---->StartCommunicator

// Kernel initialization relies on executor initialization:
// RegisterRoundKernel---->InitExecutor

// Getting the model size relies on ModelStorage initialization, which relies on executor initialization:
// InitCipher---->InitExecutor
void Server::Run() {
  signal(SIGINT, SignalHandler);
  InitServerContext();
  InitCluster();
  InitIteration();
  StartCommunicator();
  InitExecutor();
  RegisterRoundKernel();
  MS_LOG(INFO) << "Server started successfully.";

  // Wait for the communicators to stop so that the main thread blocks.
  std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); });
  communicator_with_server_->Join();
  MsException::Instance().CheckException();
  return;
}

void Server::InitServerContext() {
  PSContext::instance()->GenerateResetterRound();
  scheduler_ip_ = PSContext::instance()->scheduler_host();
  scheduler_port_ = PSContext::instance()->scheduler_port();
  worker_num_ = PSContext::instance()->initial_worker_num();
  server_num_ = PSContext::instance()->initial_server_num();
  return;
}

void Server::InitCluster() {
  server_node_ = std::make_shared<core::ServerNode>();
  MS_EXCEPTION_IF_NULL(server_node_);
  task_executor_ = std::make_shared<core::TaskExecutor>(32);
  MS_EXCEPTION_IF_NULL(task_executor_);
  if (!InitCommunicatorWithServer()) {
    MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
    return;
  }
  if (!InitCommunicatorWithWorker()) {
    MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
    return;
  }
  global_worker_server_comms = communicators_with_worker_;
  return;
}

bool Server::InitCommunicatorWithServer() {
  MS_EXCEPTION_IF_NULL(task_executor_);
  MS_EXCEPTION_IF_NULL(server_node_);

  communicator_with_server_ =
    server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_);
  MS_EXCEPTION_IF_NULL(communicator_with_server_);

  // Set exception event callbacks for server.
  auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
  MS_EXCEPTION_IF_NULL(tcp_comm);

  tcp_comm->RegisterEventCallback(core::CLUSTER_TIMEOUT, [&]() {
    MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not "
                     "started during network building phase.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });

  tcp_comm->RegisterEventCallback(core::SCHEDULER_TIMEOUT, [&]() {
    MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });

  tcp_comm->RegisterEventCallback(core::NODE_TIMEOUT, [&]() {
    MS_LOG(ERROR)
      << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
         "network building phase.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });
  return true;
}

bool Server::InitCommunicatorWithWorker() {
  MS_EXCEPTION_IF_NULL(server_node_);
  MS_EXCEPTION_IF_NULL(task_executor_);
  if (!use_tcp_ && !use_http_) {
    MS_LOG(EXCEPTION) << "At least one type of protocol should be set.";
    return false;
  }
  if (use_tcp_) {
    auto tcp_comm = communicator_with_server_;
    MS_EXCEPTION_IF_NULL(tcp_comm);
    communicators_with_worker_.push_back(tcp_comm);
  }
  if (use_http_) {
    auto http_comm = server_node_->GetOrCreateHttpComm("0.0.0.0", http_port_, task_executor_);
    MS_EXCEPTION_IF_NULL(http_comm);
    communicators_with_worker_.push_back(http_comm);
  }
  return true;
}

void Server::InitIteration() {
  iteration_ = std::make_shared<Iteration>();
  MS_EXCEPTION_IF_NULL(iteration_);

  // 1. Add rounds to the iteration according to the server mode.
  for (const RoundConfig &config : rounds_config_) {
    std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window,
                                                           config.check_count, config.threshold_count);
    MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count
                 << ", threshold:" << config.threshold_count;
    iteration_->AddRound(round);
  }

  // 2. Initialize all the rounds.
  TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
  FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
  iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
  return;
}

void Server::InitExecutor() {
  if (executor_threshold_ == 0) {
    MS_LOG(EXCEPTION) << "The executor's threshold should be greater than 0.";
    return;
  }
  // The train engine instance is used in both push-type and pull-type kernels,
  // so the required_cnt of these kernels must be the same as update_model_threshold_.
  MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
  Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
  ModelStore::GetInstance().Initialize();
  return;
}

void Server::RegisterRoundKernel() {
  MS_EXCEPTION_IF_NULL(iteration_);
  auto &rounds = iteration_->rounds();
  if (rounds.empty()) {
    MS_LOG(EXCEPTION) << "Server has no round registered.";
    return;
  }

  for (auto &round : rounds) {
    const std::string &name = round->name();
    std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
    if (round_kernel == nullptr) {
      MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered.";
      return;
    }

    // For some round kernels, the threshold count should be set.
    round_kernel->InitKernel(round->threshold_count());
    round->BindRoundKernel(round_kernel);
  }
  return;
}

void Server::StartCommunicator() {
  MS_EXCEPTION_IF_NULL(communicator_with_server_);
  if (communicators_with_worker_.empty()) {
    MS_LOG(EXCEPTION) << "Communicators for communication with workers are empty.";
    return;
  }

  MS_LOG(INFO) << "Start communicator with server.";
  communicator_with_server_->Start();
  DistributedMetadataStore::GetInstance().Initialize(server_node_);
  CollectiveOpsImpl::GetInstance().Initialize(server_node_);
  DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);

  MS_LOG(INFO) << "Start communicator with worker.";
  std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); });
}
}  // namespace server
}  // namespace ps
}  // namespace mindspore
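Reading the dependency comments as "A---->B means B must run before A", the call order in Run() can be checked mechanically; a tiny illustrative sketch, not MindSpore code:

# Check that Run()'s call order respects the documented dependencies.
run_order = ["InitServerContext", "InitCluster", "InitIteration",
             "StartCommunicator", "InitExecutor", "RegisterRoundKernel"]
relies_on = {
    "StartCommunicator": ["InitIteration"],
    "RegisterRoundKernel": ["StartCommunicator", "InitExecutor"],
}
pos = {step: i for i, step in enumerate(run_order)}
for step, deps in relies_on.items():
    for dep in deps:
        assert pos[dep] < pos[step], step + " must run after " + dep
print("Run() order satisfies the documented dependencies")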
@@ -0,0 +1,131 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
#define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_

#include <memory>
#include <string>
#include <vector>
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/iteration.h"

namespace mindspore {
namespace ps {
namespace server {
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
class Server {
 public:
  static Server &GetInstance() {
    static Server instance;
    return instance;
  }

  void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
                  const FuncGraphPtr &func_graph, size_t executor_threshold);

  // According to the current MindSpore framework, method Run is a step of the server pipeline. This method blocks
  // until the server is finalized.
  // func_graph is the frontend graph that will be parsed in the server's executor and aggregator.
  void Run();

 private:
  Server()
      : server_node_(nullptr),
        task_executor_(nullptr),
        use_tcp_(false),
        use_http_(false),
        http_port_(0),
        func_graph_(nullptr),
        executor_threshold_(0),
        communicator_with_server_(nullptr),
        communicators_with_worker_({}),
        iteration_(nullptr),
        scheduler_ip_(""),
        scheduler_port_(0),
        server_num_(0),
        worker_num_(0) {}
  ~Server() = default;
  Server(const Server &) = delete;
  Server &operator=(const Server &) = delete;

  // Load variables which are set by ps_context.
  void InitServerContext();

  // Initialize the server cluster, server node and communicators.
  void InitCluster();
  bool InitCommunicatorWithServer();
  bool InitCommunicatorWithWorker();

  // Initialize iteration with rounds. Which rounds to use could be set by ps_context as well.
  void InitIteration();

  // Initialize executor according to the server mode.
  void InitExecutor();

  // Create round kernels and bind these kernels with corresponding Round.
  void RegisterRoundKernel();

  // The communicators should be started after all initializations are completed.
  void StartCommunicator();

  // The server node is initialized in Server.
  std::shared_ptr<core::ServerNode> server_node_;

  // The task executor of the communicators. This helps the server handle network messages concurrently. The tasks
  // submitted to this task executor are asynchronous.
  std::shared_ptr<core::TaskExecutor> task_executor_;

  // Which protocol the communicators should use.
  bool use_tcp_;
  bool use_http_;
  uint64_t http_port_;

  // The configuration of all rounds.
  std::vector<RoundConfig> rounds_config_;

  // The graph passed by the frontend without backend optimizing.
  FuncGraphPtr func_graph_;

  // The threshold count for the executor to do aggregation or optimizing.
  size_t executor_threshold_;

  // The server needs a TCP communicator to communicate with other servers for counting, metadata storing, collective
  // operations, etc.
  std::shared_ptr<core::CommunicatorBase> communicator_with_server_;

  // The communication with workers (including mobile devices) has multiple protocol types: HTTP and TCP.
  // In some cases, both types should be supported in one distributed training job. So here we may have multiple
  // communicators.
  std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;

  // Iteration consists of multiple kinds of rounds.
  std::shared_ptr<Iteration> iteration_;

  // Variables set by ps context.
  std::string scheduler_ip_;
  uint16_t scheduler_port_;
  uint32_t server_num_;
  uint32_t worker_num_;
};
}  // namespace server
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
@@ -33,7 +33,21 @@ def ps_context():
     return _ps_context

 _set_ps_context_func_map = {
-    "enable_ps": ps_context().set_ps_enable
+    "server_mode": ps_context().set_server_mode,
+    "ms_role": ps_context().set_ms_role,
+    "enable_ps": ps_context().set_ps_enable,
+    "worker_num": ps_context().set_worker_num,
+    "server_num": ps_context().set_server_num,
+    "scheduler_ip": ps_context().set_scheduler_ip,
+    "scheduler_port": ps_context().set_scheduler_port,
+    "fl_server_port": ps_context().set_fl_server_port,
+    "enable_fl_client": ps_context().set_fl_client_enable,
+    "start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
+    "fl_name": ps_context().set_fl_name,
+    "fl_iteration_num": ps_context().set_fl_iteration_num,
+    "client_epoch_num": ps_context().set_client_epoch_num,
+    "client_batch_size": ps_context().set_client_batch_size,
+    "secure_aggregation": ps_context().set_secure_aggregation
 }

 _get_ps_context_func_map = {
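These setters are normally reached through the public context API; a hedged usage sketch with the defaults used by the test scripts below, assuming mindspore.context.set_ps_context dispatches through _set_ps_context_func_map as the naming suggests:

# Hedged sketch: configuring a federated-learning server via the ps context.
import mindspore.context as context

context.set_ps_context(server_mode="FEDERATED_LEARNING",
                       ms_role="MS_SERVER",
                       worker_num=0,
                       server_num=2,
                       scheduler_ip="127.0.0.1",
                       scheduler_port=8113,
                       fl_server_port=6666,
                       start_fl_job_threshold=1,
                       fl_name="Lenet",
                       fl_iteration_num=25,
                       client_epoch_num=20,
                       client_batch_size=32,
                       secure_aggregation=False)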
@@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import subprocess

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Finish test_mobile_lenet.py case")
    parser.add_argument("--scheduler_port", type=int, default=8113)

    args, _ = parser.parse_known_args()
    scheduler_port = args.scheduler_port

    cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" "
    cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && "
    cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done"

    subprocess.call(['bash', '-c', cmd])
@@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import subprocess

parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)

if __name__ == "__main__":
    args, _ = parser.parse_known_args()
    device_target = args.device_target
    server_mode = args.server_mode
    worker_num = args.worker_num
    server_num = args.server_num
    scheduler_ip = args.scheduler_ip
    scheduler_port = args.scheduler_port
    fl_server_port = args.fl_server_port

    cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
    cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
    cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
    cmd_sched += "python ${self_path}/../test_mobile_lenet.py"
    cmd_sched += " --device_target=" + device_target
    cmd_sched += " --server_mode=" + server_mode
    cmd_sched += " --ms_role=MS_SCHED"
    cmd_sched += " --worker_num=" + str(worker_num)
    cmd_sched += " --server_num=" + str(server_num)
    cmd_sched += " --scheduler_ip=" + scheduler_ip
    cmd_sched += " --scheduler_port=" + str(scheduler_port)
    cmd_sched += " --fl_server_port=" + str(fl_server_port)
    cmd_sched += " > scheduler.log 2>&1 &"

    subprocess.call(['bash', '-c', cmd_sched])
@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import ast
import argparse
import subprocess
import time

parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
parser.add_argument("--local_server_num", type=int, default=-1)

if __name__ == "__main__":
    args, _ = parser.parse_known_args()
    device_target = args.device_target
    server_mode = args.server_mode
    worker_num = args.worker_num
    server_num = args.server_num
    scheduler_ip = args.scheduler_ip
    scheduler_port = args.scheduler_port
    fl_server_port = args.fl_server_port
    start_fl_job_threshold = args.start_fl_job_threshold
    fl_name = args.fl_name
    fl_iteration_num = args.fl_iteration_num
    client_epoch_num = args.client_epoch_num
    client_batch_size = args.client_batch_size
    secure_aggregation = args.secure_aggregation
    local_server_num = args.local_server_num

    if local_server_num == -1:
        local_server_num = server_num

    assert local_server_num <= server_num, "The local server number should not be bigger than total server number."

    for i in range(local_server_num):
        cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
        cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&"
        cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&"
        cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&"
        cmd_server += "python ${self_path}/../test_mobile_lenet.py"
        cmd_server += " --device_target=" + device_target
        cmd_server += " --server_mode=" + server_mode
        cmd_server += " --ms_role=MS_SERVER"
        cmd_server += " --worker_num=" + str(worker_num)
        cmd_server += " --server_num=" + str(server_num)
        cmd_server += " --scheduler_ip=" + scheduler_ip
        cmd_server += " --scheduler_port=" + str(scheduler_port)
        cmd_server += " --fl_server_port=" + str(fl_server_port + i)
        cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
        cmd_server += " --fl_name=" + fl_name
        cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
        cmd_server += " --client_epoch_num=" + str(client_epoch_num)
        cmd_server += " --client_batch_size=" + str(client_batch_size)
        cmd_server += " --secure_aggregation=" + str(secure_aggregation)
        cmd_server += " > server.log 2>&1 &"

        time.sleep(0.3)
        subprocess.call(['bash', '-c', cmd_server])
@@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export PYTHONPATH=../../../../:$PYTHONPATH
server_num=$1
worker_num=$2
ip=$3
port=$4

for((i=0;i<worker_num;i++));
do
    ofs=`expr $i % $server_num`
    real_port=`expr $port + $ofs`
    echo $real_port
    python simulator.py --pid=$i --http_ip=$ip --http_port=$port --use_elb=True --server_num=$server_num > simulator_$i.log 2>&1 &
done
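For reference, the worker-to-port mapping this script computes (ofs = i % server_num), which is the same spread simulator.py's generate_port draws from at random when --use_elb=True; worker_port is a hypothetical helper:

# Each of server_num servers listens on base_port + offset (see fl_server_port + i
# in the server launch script); worker i maps onto one of them round-robin.
def worker_port(i, base_port, server_num):
    return base_port + (i % server_num)

assert [worker_port(i, 6666, 2) for i in range(4)] == [6666, 6667, 6666, 6667]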
@@ -0,0 +1,191 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import time
import random
import sys
import requests
import flatbuffers
import numpy as np
from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
                              RequestUpdateModel, FeatureMap, RequestGetModel, ResponseGetModel)

parser = argparse.ArgumentParser()
parser.add_argument("--pid", type=int, default=0)
parser.add_argument("--http_ip", type=str, default="10.113.216.106")
parser.add_argument("--http_port", type=int, default=6666)
parser.add_argument("--use_elb", type=bool, default=False)
parser.add_argument("--server_num", type=int, default=1)

args, _ = parser.parse_known_args()
pid = args.pid
http_ip = args.http_ip
http_port = args.http_port
use_elb = args.use_elb
server_num = args.server_num

str_fl_id = 'fl_lenet_' + str(pid)


def generate_port():
    if not use_elb:
        return http_port
    port = random.randint(0, 100000) % server_num + http_port
    return port


def build_start_fl_job(iteration):
    start_fl_job_builder = flatbuffers.Builder(1024)

    fl_name = start_fl_job_builder.CreateString('fl_test_job')
    fl_id = start_fl_job_builder.CreateString(str_fl_id)
    data_size = 32
    timestamp = start_fl_job_builder.CreateString('2020/11/16/19/18')

    RequestFLJob.RequestFLJobStart(start_fl_job_builder)
    RequestFLJob.RequestFLJobAddFlName(start_fl_job_builder, fl_name)
    RequestFLJob.RequestFLJobAddFlId(start_fl_job_builder, fl_id)
    RequestFLJob.RequestFLJobAddIteration(start_fl_job_builder, iteration)
    RequestFLJob.RequestFLJobAddDataSize(start_fl_job_builder, data_size)
    RequestFLJob.RequestFLJobAddTimestamp(start_fl_job_builder, timestamp)
    fl_job_req = RequestFLJob.RequestFLJobEnd(start_fl_job_builder)

    start_fl_job_builder.Finish(fl_job_req)
    buf = start_fl_job_builder.Output()
    return buf


def build_feature_map(builder, names, lengths):
    if len(names) != len(lengths):
        return None
    feature_maps = []
    np_data = []
    for j, _ in enumerate(names):
        name = names[j]
        length = lengths[j]
        weight_full_name = builder.CreateString(name)
        FeatureMap.FeatureMapStartDataVector(builder, length)
        weight = np.random.rand(length) * 32
        np_data.append(weight)
        for idx in range(length - 1, -1, -1):
            builder.PrependFloat32(weight[idx])
        data = builder.EndVector(length)
        FeatureMap.FeatureMapStart(builder)
        FeatureMap.FeatureMapAddData(builder, data)
        FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name)
        feature_map = FeatureMap.FeatureMapEnd(builder)
        feature_maps.append(feature_map)
    return feature_maps, np_data


def build_update_model(iteration):
    builder_update_model = flatbuffers.Builder(1)
    fl_name = builder_update_model.CreateString('fl_test_job')
    fl_id = builder_update_model.CreateString(str_fl_id)
    timestamp = builder_update_model.CreateString('2020/11/16/19/18')

    feature_maps, np_data = build_feature_map(builder_update_model,
                                              ["conv1.weight", "conv2.weight", "fc1.weight",
                                               "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
                                              [450, 2400, 48000, 10080, 5208, 120, 84, 62])

    RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, 1)
    for single_feature_map in feature_maps:
        builder_update_model.PrependUOffsetTRelative(single_feature_map)
    feature_map = builder_update_model.EndVector(len(feature_maps))

    RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
    RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
    RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
    RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
    RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
    RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
    req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
    builder_update_model.Finish(req_update_model)
    buf = builder_update_model.Output()
    return buf, np_data


def build_get_model(iteration):
    builder_get_model = flatbuffers.Builder(1)
    fl_name = builder_get_model.CreateString('fl_test_job')
    timestamp = builder_get_model.CreateString('2020/12/16/19/18')

    RequestGetModel.RequestGetModelStart(builder_get_model)
    RequestGetModel.RequestGetModelAddFlName(builder_get_model, fl_name)
    RequestGetModel.RequestGetModelAddIteration(builder_get_model, iteration)
    RequestGetModel.RequestGetModelAddTimestamp(builder_get_model, timestamp)
    req_get_model = RequestGetModel.RequestGetModelEnd(builder_get_model)
    builder_get_model.Finish(req_get_model)
    buf = builder_get_model.Output()
    return buf


weight_name_to_idx = {
    "conv1.weight": 0,
    "conv2.weight": 1,
    "fc1.weight": 2,
    "fc2.weight": 3,
    "fc3.weight": 4,
    "fc1.bias": 5,
    "fc2.bias": 6,
    "fc3.bias": 7
}

session = requests.Session()
current_iteration = 1
url = "http://" + http_ip + ":" + str(generate_port())
np.random.seed(0)
while True:
    url1 = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
    print("start url is ", url1)
    x = requests.post(url1, data=build_start_fl_job(current_iteration))
    rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    print("start fl job iteration:", current_iteration, ", id:", args.pid)
    while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
        x = requests.post(url1, data=build_start_fl_job(current_iteration))
        rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
    sys.stdout.flush()

    url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
    print("req update model iteration:", current_iteration, ", id:", args.pid)
    update_model_buf, update_model_np_data = build_update_model(current_iteration)
    x = session.post(url2, data=update_model_buf)
    print("rsp update model iteration:", current_iteration, ", id:", args.pid)
    sys.stdout.flush()

    url3 = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
    print("req get model iteration:", current_iteration, ", id:", args.pid)
    x = session.post(url3, data=build_get_model(current_iteration))
    rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
    print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
    sys.stdout.flush()

    repeat_time = 0
    while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
        time.sleep(0.1)
        x = session.post(url3, data=build_get_model(current_iteration))
        rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
        repeat_time += 1
        if repeat_time > 1000:
            print("GetModel try timeout ", args.pid)
            sys.exit(0)

    for i in range(0, 1):
        print(rsp_get_model.FeatureMap(i).WeightFullname())
        origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
        after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
        print("Before update model", args.pid, origin[0:10])
        print("After get model", args.pid, after[0:10])
        sys.stdout.flush()
        assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
    current_iteration += 1
@ -0,0 +1,423 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecayForBert, a customized Adam for BERT. Inputs: gradients and an overflow flag."""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """Update parameters by the AdamWeightDecay op."""
    if optim_filter:
        adam = P.AdamWeightDecay()
        if decay_flags:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
        else:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
        return next_param
    return gradient

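# Editor's note: a NumPy sketch (not part of this commit) of the recurrence the
# fused P.AdamWeightDecay kernel is expected to implement -- plain AdamW without
# bias correction, with the decay term decoupled from the gradient update:
def _adamw_reference_step(param, m, v, grad, lr, beta1, beta2, eps, weight_decay):
    """Reference AdamW step in NumPy, mirroring the fused kernel above."""
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    update = m / (np.sqrt(v) + eps) + weight_decay * param
    return param - lr * update, m, v
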
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the updated parameter.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # Broadcast the scalar overflow flag to the shape of the moments.
        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
        next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,
                                                   op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))

        next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,
                                                   op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        # On overflow, the step applied to the parameter is replaced with zeros.
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient


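# Editor's note: the Select-based overflow guard above is easier to read in
# NumPy form (an illustration, not part of this commit): when the loss-scale
# overflow flag is set, the step applied to the parameter is zero.
def _guarded_step_reference(param, update_with_lr, overflow):
    """NumPy analogue of `op_select(cond, zeros, update_with_lr)` above."""
    cond = np.full(param.shape, bool(overflow))
    return param - np.where(cond, np.zeros_like(param), update_with_lr)
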
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        assign_m = F.assign(m, op_mul(beta1, m))
        assign_v = F.assign(v, op_mul(beta2, v))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            m_temp = next_m * _scaler_ten
            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)

            m_recover = F.assign(m, m_temp / _scaler_ten)

            F.control_depend(m_temp, assign_m_nesterov)
            F.control_depend(assign_m_nesterov, div_value)
            F.control_depend(param_update, m_recover)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        next_param = param - lr_t * param_update

        F.control_depend(assign_m, next_m)
        F.control_depend(assign_v, next_v)

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success


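# Editor's note: in the host ("target") branch above, the moments are first
# decayed with in-place assigns and the gradient rows are then scattered in.
# A NumPy illustration of the same sparse moment update (not part of this commit):
def _sparse_moments_reference(m, v, indices, values, beta1, beta2):
    """Decay every row, then scatter-add gradients into the touched rows only."""
    m *= beta1
    v *= beta2
    np.add.at(m, indices, (1.0 - beta1) * values)
    np.add.at(v, indices, (1.0 - beta2) * values * values)
    return m, v
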
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
                             beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
                             moment1, moment2, ps_parameter, cache_enable):
    """Apply the adam optimizer to the weight parameter using a dense Tensor gradient."""
    success = True
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(param), op_shape(moment1), op_shape(moment2))), param))
    else:
        success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success


@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """Apply the AdamOffload optimizer to the weight parameter using Tensor."""
    success = True
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    success = F.depend(success, F.assign_add(param, delta_param))
    return success


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class AdamWeightDecayForBert(Optimizer):
    """
    Implements the Adam algorithm with fixed (decoupled) weight decay.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There are no other keys in the `dict`, and the parameters
              listed in 'order_params' must be in one of the group parameter lists.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, the dynamic learning rate is used, and the i-th
            step will take the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule,
            the i-th learning rate will be calculated during training according to the formula of the
            LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, a fixed learning
            rate is used. Other cases are not supported. The float learning rate must be equal to or greater than 0.
            If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
        - **overflow** (Tensor) - The overflow flag produced by dynamic loss scaling.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayForBert(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayForBert(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        self.op_select = P.Select()
        self.op_cast = P.Cast()
        self.op_reshape = P.Reshape()
        self.op_shape = P.Shape()

    def construct(self, gradients, overflow):
        """AdamWeightDecayForBert"""
        lr = self.get_lr()
        # Broadcast the scalar overflow flag; on overflow, beta1 and beta2 are
        # replaced with 1.0 before being passed to the update function.
        cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *
                            self.op_reshape(overflow, (())), mstype.bool_)
        beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
        beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result

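# Editor's note: a hypothetical driver (not part of this commit) showing the
# two-input construct() signature; `network` stands for any Cell with trainable
# parameters, and the zero/one overflow Tensor is the flag a dynamic loss scale
# would produce. An overflow value of 1.0 suppresses the parameter update.
def _example_drive_adam_for_bert(network):
    """Run one hand-built optimizer step with a cleared overflow flag."""
    optim = AdamWeightDecayForBert(network.trainable_params(), learning_rate=1e-3, weight_decay=0.01)
    grads = tuple(Tensor(np.zeros(p.shape, np.float32)) for p in network.trainable_params())
    no_overflow = Tensor(0.0, mstype.float32)
    return optim(grads, no_overflow)
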
class AdamWeightDecayOp(Optimizer):
    """
    Implements the Adam algorithm with fixed (decoupled) weight decay. It is a complete operator,
    not a combination of other ops.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There are no other keys in the `dict`, and the parameters
              listed in 'order_params' must be in one of the group parameter lists.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, the dynamic learning rate is used, and the i-th
            step will take the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule,
            the i-th learning rate will be calculated during training according to the formula of the
            LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, a fixed learning
            rate is used. Other cases are not supported. The float learning rate must be equal to or greater than 0.
            If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        """AdamWeightDecayOp"""
        lr = self.get_lr()
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        weight_init=weight,
        has_bias=False,
        pad_mode="valid",
    )


def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    def __init__(self, num_class=10, channel=3):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
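A quick shape check for the network above (an editor's sketch, not part of the commit): with the 3x32x32 inputs used by the training script below, two valid 5x5 convolutions and two 2x2 poolings leave 16 feature maps of 5x5, matching the 16 * 5 * 5 input size of fc1:

import numpy as np
from mindspore import Tensor

net = LeNet5(num_class=62, channel=3)
out = net(Tensor(np.random.rand(4, 3, 32, 32).astype(np.float32)))
print(out.shape)  # expected: (4, 62)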
@ -0,0 +1,96 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import ast
import argparse
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from src.model import LeNet5
from src.adam import AdamWeightDecayOp

parser = argparse.ArgumentParser(description="test_fl_lenet")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=1)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)

args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
ms_role = args.ms_role
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
secure_aggregation = args.secure_aggregation

ctx = {
    "enable_ps": False,
    "server_mode": server_mode,
    "ms_role": ms_role,
    "worker_num": worker_num,
    "server_num": server_num,
    "scheduler_ip": scheduler_ip,
    "scheduler_port": scheduler_port,
    "fl_server_port": fl_server_port,
    "start_fl_job_threshold": start_fl_job_threshold,
    "fl_name": fl_name,
    "fl_iteration_num": fl_iteration_num,
    "client_epoch_num": client_epoch_num,
    "client_batch_size": client_batch_size,
    "secure_aggregation": secure_aggregation
}

context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_ps_context(**ctx)

if __name__ == "__main__":
    epoch = 5
    np.random.seed(0)
    network = LeNet5(62)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    net_adam_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1)
    net_with_criterion = WithLossCell(network, criterion)
    train_network = TrainOneStepCell(net_with_criterion, net_opt)
    train_network.set_train()
    losses = []

    for _ in range(epoch):
        data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
        label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
        loss = train_network(data, label).asnumpy()
        losses.append(loss)
    print(losses)
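The script itself only runs a local training loop; which federated role the process plays is decided by the context flags set above. A hypothetical launcher sketch (an assumption, not part of this commit; the file name test_fl_lenet.py and the MS_SCHED/MS_SERVER role strings are illustrative) that spawns the scheduler and server processes by overriding --ms_role per process:

import subprocess
import sys

def launch(role, extra=()):
    # Spawn one copy of this script with the given role (hypothetical helper).
    cmd = [sys.executable, "test_fl_lenet.py", "--ms_role", role] + list(extra)
    return subprocess.Popen(cmd)

procs = [launch("MS_SCHED")]
procs += [launch("MS_SERVER", ("--fl_server_port", str(6666 + i))) for i in range(1)]
for p in procs:
    p.wait()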