From a6f981455224ca1d8e2c3c402f3515d4fbdea2c5 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Mon, 3 May 2021 21:49:42 +0800 Subject: [PATCH] Add Server part 3 --- mindspore/ccsrc/pipeline/jit/action.cc | 50 +++ mindspore/ccsrc/pipeline/jit/action.h | 4 + mindspore/ccsrc/pipeline/jit/init.cc | 19 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 6 + mindspore/ccsrc/ps/CMakeLists.txt | 4 + mindspore/ccsrc/ps/ps_context.cc | 129 +++++- mindspore/ccsrc/ps/ps_context.h | 79 +++- mindspore/ccsrc/ps/scheduler.cc | 3 + mindspore/ccsrc/ps/server/common.h | 14 + .../ps/server/distributed_count_service.cc | 26 +- .../ps/server/distributed_metadata_store.cc | 24 +- mindspore/ccsrc/ps/server/executor.cc | 2 +- mindspore/ccsrc/ps/server/executor.h | 2 +- .../ccsrc/ps/server/kernel/fed_avg_kernel.cc | 33 ++ .../ccsrc/ps/server/kernel/fed_avg_kernel.h | 179 ++++++++ .../server/kernel/round/get_model_kernel.cc | 125 ++++++ .../ps/server/kernel/round/get_model_kernel.h | 59 +++ .../kernel/round/start_fl_job_kernel.cc | 2 +- .../kernel/round/update_model_kernel.cc | 203 +++++++++ .../server/kernel/round/update_model_kernel.h | 64 +++ mindspore/ccsrc/ps/server/model_store.cc | 2 +- mindspore/ccsrc/ps/server/model_store.h | 2 +- .../ccsrc/ps/server/parameter_aggregator.cc | 2 +- mindspore/ccsrc/ps/server/server.cc | 251 +++++++++++ mindspore/ccsrc/ps/server/server.h | 131 ++++++ mindspore/parallel/_ps_context.py | 16 +- tests/st/fl/mobile/finish_mobile.py | 30 ++ tests/st/fl/mobile/run_mobile_sched.py | 52 +++ tests/st/fl/mobile/run_mobile_server.py | 82 ++++ tests/st/fl/mobile/run_smlt.sh | 29 ++ tests/st/fl/mobile/simulator.py | 191 ++++++++ tests/st/fl/mobile/src/adam.py | 423 ++++++++++++++++++ tests/st/fl/mobile/src/model.py | 72 +++ tests/st/fl/mobile/test_mobile_lenet.py | 96 ++++ 34 files changed, 2379 insertions(+), 27 deletions(-) create mode 100644 mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.cc create mode 100644 mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h create mode 100644 mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.cc create mode 100644 mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.h create mode 100644 mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc create mode 100644 mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.h create mode 100644 mindspore/ccsrc/ps/server/server.cc create mode 100644 mindspore/ccsrc/ps/server/server.h create mode 100644 tests/st/fl/mobile/finish_mobile.py create mode 100644 tests/st/fl/mobile/run_mobile_sched.py create mode 100644 tests/st/fl/mobile/run_mobile_server.py create mode 100644 tests/st/fl/mobile/run_smlt.sh create mode 100644 tests/st/fl/mobile/simulator.py create mode 100644 tests/st/fl/mobile/src/adam.py create mode 100644 tests/st/fl/mobile/src/model.py create mode 100644 tests/st/fl/mobile/test_mobile_lenet.py diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index dd56ac35331..44db02baa9d 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -47,6 +47,7 @@ #include "ps/parameter_server.h" #include "ps/scheduler.h" #include "ps/worker.h" +#include "ps/server/server.h" #endif namespace mindspore { @@ -619,6 +620,47 @@ bool StartPSServerAction(const ResourcePtr &res) { return true; } +bool StartServerAction(const ResourcePtr &res) { + FuncGraphPtr func_graph = res->func_graph(); + const std::string &server_mode_ = ps::PSContext::instance()->server_mode(); + size_t worker_num = ps::PSContext::instance()->initial_worker_num(); 
+ size_t server_num = ps::PSContext::instance()->initial_server_num(); + uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port(); + + // Update model threshold is a certain ratio of start_fl_job threshold. + // update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_. + size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold(); + float percent_for_update_model = 1; + size_t update_model_threshold = static_cast(std::ceil(start_fl_job_threshold * percent_for_update_model)); + + std::vector rounds_config = { + {"startFLJob", false, 3000, false, start_fl_job_threshold}, + {"updateModel", false, 3000, false, update_model_threshold}, + {"getModel", false, 3000}, + {"asyncUpdateModel"}, + {"asyncGetModel"}, + {"push", false, 3000, true, worker_num}, + {"pull", false, 3000, true, worker_num}, + {"getWeightsByKey", false, 3000, true, 1}, + {"overwriteWeightsByKey", false, 3000, true, server_num}, + }; + + size_t executor_threshold = 0; + if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { + executor_threshold = update_model_threshold; + ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, func_graph, + executor_threshold); + } else if (server_mode_ == ps::kServerModePS) { + executor_threshold = worker_num; + ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, func_graph, executor_threshold); + } else { + MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported."; + return false; + } + ps::server::Server::GetInstance().Run(); + return true; +} + bool StartPSSchedulerAction(const ResourcePtr &res) { ps::Scheduler::GetInstance().Run(); return true; @@ -797,6 +839,14 @@ std::vector VmPipeline() { } #if (ENABLE_CPU && !_WIN32) +std::vector ServerPipeline() { + auto actions = CommonPipeline(); + actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); + actions.emplace_back(std::make_pair("validate", ValidateAction)); + actions.emplace_back(std::make_pair("server", StartServerAction)); + return actions; +} + std::vector PServerPipeline() { auto actions = CommonPipeline(); actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); diff --git a/mindspore/ccsrc/pipeline/jit/action.h b/mindspore/ccsrc/pipeline/jit/action.h index e00abae37a9..4329fedef5d 100644 --- a/mindspore/ccsrc/pipeline/jit/action.h +++ b/mindspore/ccsrc/pipeline/jit/action.h @@ -43,10 +43,14 @@ bool ExecuteAction(const ResourcePtr &res); bool StartPSWorkerAction(const ResourcePtr &res); bool StartPSServerAction(const ResourcePtr &res); bool StartPSSchedulerAction(const ResourcePtr &res); +// This action is only for federated learning only. In later version, parameter server mode and federated learning will +// use the same action. 
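The rounds_config built in StartServerAction above leans on RoundConfig's default member values (the struct is declared later in this patch, in ps/server/common.h), so an entry such as {"getModel", false, 3000} leaves check_count and threshold_count at their defaults. A minimal standalone sketch of that aggregate initialization and of the updateModel-threshold arithmetic, using an illustrative struct rather than the MindSpore type and an example threshold of 100:

#include <cmath>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Mirrors the RoundConfig fields used by StartServerAction; illustrative only.
struct RoundConfigSketch {
  std::string name;
  bool check_timeout = false;
  size_t time_window = 3000;
  bool check_count = false;
  size_t threshold_count = 0;
};

int main() {
  size_t start_fl_job_threshold = 100;    // example value; set via the ps context in practice
  float percent_for_update_model = 1.0f;  // hard-coded to 1 in this patch
  // The updateModel threshold is a ratio of the startFLJob threshold, rounded up.
  size_t update_model_threshold =
      static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));

  std::vector<RoundConfigSketch> rounds = {
      {"startFLJob", false, 3000, false, start_fl_job_threshold},
      {"updateModel", false, 3000, false, update_model_threshold},
      {"getModel", false, 3000},  // trailing fields fall back to their defaults
  };
  for (const auto &r : rounds) {
    std::cout << r.name << ": threshold_count=" << r.threshold_count << "\n";
  }
  return 0;
}

With percent_for_update_model hard-coded to 1, the two thresholds coincide; a smaller ratio would let aggregation proceed once a subset of the joined clients has uploaded.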
+bool StartServerAction(const ResourcePtr &res); std::vector GePipeline(); std::vector VmPipeline(); std::vector PServerPipeline(); +std::vector ServerPipeline(); std::vector PSchedulerPipeline(); abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, const abstract::AbstractBasePtrList &args_spec, bool clear = false); diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index a6f62c74811..95a3f936097 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -326,7 +326,24 @@ PYBIND11_MODULE(_c_expression, m) { .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.") .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.") .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.") - .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode."); + .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.") + .def("set_server_mode", &PSContext::set_server_mode, "Set server mode.") + .def("server_mode", &PSContext::server_mode, "Get server mode.") + .def("set_ms_role", &PSContext::set_ms_role, "Set role for this process.") + .def("ms_role", &PSContext::ms_role, "Get role for this process.") + .def("set_worker_num", &PSContext::set_worker_num, "Set worker number.") + .def("set_server_num", &PSContext::set_server_num, "Set server number.") + .def("set_scheduler_ip", &PSContext::set_scheduler_ip, "Set scheduler ip.") + .def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.") + .def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.") + .def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.") + .def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.") + .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.") + .def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.") + .def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.") + .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.") + .def("set_secure_aggregation", &PSContext::set_secure_aggregation, + "Set federated learning client using secure aggregation."); (void)py::class_>(m, "OpInfoLoaderPy") .def(py::init()) diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 3636e92a543..af71bc3f07b 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -55,6 +55,7 @@ #include "ps/worker.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" #include "ps/ps_cache/ps_cache_manager.h" +#include "ps/server/server.h" #endif #if (ENABLE_GE || ENABLE_D) @@ -529,6 +530,11 @@ std::vector GetPipeline(const ResourcePtr &resource, const std::stri std::string backend = MsContext::GetInstance()->backend_policy(); #if (ENABLE_CPU && !_WIN32) + const std::string &server_mode = ps::PSContext::instance()->server_mode(); + if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) && + ps::PSContext::instance()->is_server()) { + return ServerPipeline(); + } if (ps::PSContext::instance()->is_server()) { resource->results()[kBackend] = compile::CreateBackend(); 
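+    // Only parameter-server-mode servers reach this branch; federated learning and
+    // hybrid-training servers were already routed to ServerPipeline() above.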
return PServerPipeline(); diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index 2245b793246..2b3855d52ff 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -50,10 +50,13 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/fed_avg_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/update_model_kernel.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/get_model_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc") @@ -67,6 +70,7 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/server.cc") endif() list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 222a3f31933..8b77ff8b6c4 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -61,7 +61,12 @@ void PSContext::SetPSEnable(bool enabled) { } } -bool PSContext::is_ps_mode() const { return ps_enabled_; } +bool PSContext::is_ps_mode() const { + if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + return true; + } + return ps_enabled_; +} void PSContext::Reset() { ps_enabled_ = false; @@ -77,6 +82,9 @@ void PSContext::Reset() { } std::string PSContext::ms_role() const { + if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + return role_; + } if (is_worker_) { return kEnvRoleOfWorker; } else if (is_pserver_) { @@ -88,11 +96,26 @@ std::string PSContext::ms_role() const { } } -bool PSContext::is_worker() const { return is_worker_; } +bool PSContext::is_worker() const { + if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + return role_ == kRoleOfWorker; + } + return is_worker_; +} -bool PSContext::is_server() const { return is_pserver_; } +bool PSContext::is_server() const { + if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + return role_ == kEnvRoleOfServer; + } + return is_pserver_; +} -bool PSContext::is_scheduler() const { return is_sched_; } +bool PSContext::is_scheduler() const { + if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) { + return role_ == kEnvRoleOfScheduler; + } + return is_sched_; +} uint32_t PSContext::initial_worker_num() { return worker_num_; } @@ -150,6 +173,94 @@ void PSContext::set_rank_id(int rank_id) const { #endif } +void PSContext::set_server_mode(const std::string &server_mode) { + if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) { + MS_LOG(EXCEPTION) << server_mode << " is invalid. 
Server mode must be " << kServerModePS << " or " << kServerModeFL + << " or " << kServerModeHybrid; + return; + } + server_mode_ = server_mode; +} + +const std::string &PSContext::server_mode() const { return server_mode_; } + +void PSContext::set_ms_role(const std::string &role) { + if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) { + MS_LOG(EXCEPTION) << "Only federated learning supports to set role by ps context."; + return; + } + if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) { + MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid."; + return; + } + role_ = role; +} + +void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; } +uint32_t PSContext::worker_num() const { return worker_num_; } + +void PSContext::set_server_num(uint32_t server_num) { + if (server_num == 0) { + MS_LOG(EXCEPTION) << "Server number must be greater than 0."; + return; + } + server_num_ = server_num; +} +uint32_t PSContext::server_num() const { return server_num_; } + +void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; } + +std::string PSContext::scheduler_ip() const { return scheduler_host_; } + +void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; } + +uint16_t PSContext::scheduler_port() const { return scheduler_port_; } + +void PSContext::GenerateResetterRound() { + uint32_t binary_server_context = 0; + bool is_parameter_server_mode = false; + bool is_federated_learning_mode = false; + bool is_mixed_training_mode = false; + + if (server_mode_ == kServerModePS) { + is_parameter_server_mode = true; + } else if (server_mode_ == kServerModeFL) { + is_federated_learning_mode = true; + } else if (server_mode_ == kServerModeHybrid) { + is_mixed_training_mode = true; + } else { + MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL + << " or " << kServerModeHybrid; + return; + } + + binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) | + (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4); + if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { + resetter_round_ = ResetterRound::kNoNeedToReset; + } else { + resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context); + } + MS_LOG(INFO) << "Server context is " << binary_server_context << ". 
Resetter round is " << resetter_round_; + return; +} + +ResetterRound PSContext::resetter_round() const { return resetter_round_; } + +void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; } + +uint16_t PSContext::fl_server_port() const { return fl_server_port_; } + +void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; } + +bool PSContext::fl_client_enable() { return fl_client_enable_; } + +void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) { + start_fl_job_threshold_ = start_fl_job_threshold; +} + +size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; } + void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; } const std::string &PSContext::fl_name() const { return fl_name_; } @@ -165,5 +276,15 @@ uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; } void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; } uint64_t PSContext::client_batch_size() const { return client_batch_size_; } + +void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) { + worker_overwrite_weights_ = worker_overwrite_weights; +} + +uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; } + +void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; } + +bool PSContext::secure_aggregation() const { return secure_aggregation_; } } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 6a506959359..68e45db5db8 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_ #define MINDSPORE_CCSRC_PS_CONTEXT_H_ +#include #include #include #include "ps/constants.h" @@ -24,12 +25,32 @@ namespace mindspore { namespace ps { +constexpr char kServerModePS[] = "PARAMETER_SERVER"; +constexpr char kServerModeFL[] = "FEDERATED_LEARNING"; +constexpr char kServerModeHybrid[] = "HYBRID_TRAINING"; constexpr char kEnvRole[] = "MS_ROLE"; constexpr char kEnvRoleOfPServer[] = "MS_PSERVER"; +constexpr char kEnvRoleOfServer[] = "MS_SERVER"; constexpr char kEnvRoleOfWorker[] = "MS_WORKER"; constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS"; +// Use binary data to represent federated learning server's context so that we can judge which round resets the +// iteration. From right to left, each bit stands for: +// 0: Server is in parameter server mode. +// 1: Server is in federated learning mode. +// 2: Server is in mixed training mode. +// 3: Server enables sucure aggregation. +// 4: Server needs worker to overwrite weights. +// For example: 01010 stands for that the server is in federated learning mode and sucure aggregation is enabled. 
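A self-contained sketch of the encoding described above (illustrative names; the real enum and map follow below). Packing the five booleans into a bitmask keeps the mode-to-resetter-round mapping a plain table lookup:

#include <cstdint>
#include <iostream>
#include <map>

// Bit 0: PS mode, bit 1: FL mode, bit 2: hybrid mode,
// bit 3: secure aggregation, bit 4: worker overwrites weights.
enum class ResetterRoundSketch { kNoNeedToReset, kUpdateModel, kReconstructSecrets, kWorkerOverwriteWeights };

int main() {
  bool is_ps = false, is_fl = true, is_hybrid = false;
  bool secure_agg = true, worker_overwrite = false;
  uint32_t ctx = (is_ps << 0) | (is_fl << 1) | (is_hybrid << 2) |
                 (secure_agg << 3) | (worker_overwrite << 4);  // 0b01010

  const std::map<uint32_t, ResetterRoundSketch> ctx_to_round = {
      {0b00010, ResetterRoundSketch::kUpdateModel},
      {0b01010, ResetterRoundSketch::kReconstructSecrets},
      {0b10100, ResetterRoundSketch::kWorkerOverwriteWeights}};

  auto it = ctx_to_round.find(ctx);
  if (it == ctx_to_round.end()) {
    std::cout << "no round needs to reset the iteration\n";
  } else {
    std::cout << "resetter round enum value: " << static_cast<int>(it->second) << "\n";
  }
  return 0;
}

In the table below, both 0b00010 (federated learning only) and 0b00100 (hybrid training only) reset on updateModel, while enabling secure aggregation or worker weight overwriting moves the reset to a later round.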
+enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerOverwriteWeights }; +const std::map kServerContextToResetRoundMap = { + {0b00010, ResetterRound::kUpdateModel}, + {0b01010, ResetterRound::kReconstructSeccrets}, + {0b11100, ResetterRound::kWorkerOverwriteWeights}, + {0b10100, ResetterRound::kWorkerOverwriteWeights}, + {0b00100, ResetterRound::kUpdateModel}}; + class PSContext { public: ~PSContext() = default; @@ -60,19 +81,64 @@ class PSContext { void set_cache_enable(bool cache_enable) const; void set_rank_id(int rank_id) const; - // Setter and getter for federated learning. + // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set + // by ps_context. + void set_server_mode(const std::string &server_mode); + const std::string &server_mode() const; + + void set_ms_role(const std::string &role); + + void set_worker_num(uint32_t worker_num); + uint32_t worker_num() const; + + void set_server_num(uint32_t server_num); + uint32_t server_num() const; + + void set_scheduler_ip(const std::string &sched_ip); + std::string scheduler_ip() const; + + void set_scheduler_port(uint16_t sched_port); + uint16_t scheduler_port() const; + + // Methods federated learning. + + // Generate which round should reset the iteration. + void GenerateResetterRound(); + ResetterRound resetter_round() const; + + void set_fl_server_port(uint16_t fl_server_port); + uint16_t fl_server_port() const; + + // Set true if this process is a federated learning worker in cross-silo scenario. + void set_fl_client_enable(bool enabled); + bool fl_client_enable(); + + void set_start_fl_job_threshold(size_t start_fl_job_threshold); + size_t start_fl_job_threshold() const; + void set_fl_name(const std::string &fl_name); const std::string &fl_name() const; + // Set the iteration number of the federated learning. void set_fl_iteration_num(uint64_t fl_iteration_num); uint64_t fl_iteration_num() const; + // Set the training epoch number of the client. void set_client_epoch_num(uint64_t client_epoch_num); uint64_t client_epoch_num() const; + // Set the data batch size of the client. void set_client_batch_size(uint64_t client_batch_size); uint64_t client_batch_size() const; + // Set true if worker will overwrite weights on server. Used in hybrid training. + void set_worker_overwrite_weights(uint64_t worker_overwrite_weights); + uint64_t worker_overwrite_weights() const; + + // Set true if using secure aggregation for federated learning. + void set_secure_aggregation(bool secure_aggregation); + bool secure_aggregation() const; + private: PSContext() : ps_enabled_(false), @@ -94,11 +160,22 @@ class PSContext { std::string scheduler_host_; uint16_t scheduler_port_; + std::string role_; + // Members for federated learning. + std::string server_mode_; + ResetterRound resetter_round_; + uint16_t fl_server_port_; + bool fl_client_enable_; std::string fl_name_; + size_t start_fl_job_threshold_; uint64_t fl_iteration_num_; uint64_t client_epoch_num_; uint64_t client_batch_size_; + bool worker_overwrite_weights_; + + // Federated learning security. 
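+  // True when clients run secure aggregation; together with the mode bits and the
+  // worker-overwrite flag it forms the bitmask built in GenerateResetterRound().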
+ bool secure_aggregation_; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/scheduler.cc b/mindspore/ccsrc/ps/scheduler.cc index acb1ca785cc..08726a66e58 100755 --- a/mindspore/ccsrc/ps/scheduler.cc +++ b/mindspore/ccsrc/ps/scheduler.cc @@ -20,6 +20,9 @@ namespace mindspore { namespace ps { void Scheduler::Run() { MS_LOG(INFO) << "Start scheduler."; + core::ClusterMetadata::instance()->Init( + PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), + PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); scheduler_node_.Start(); scheduler_node_.Finish(); scheduler_node_.Stop(); diff --git a/mindspore/ccsrc/ps/server/common.h b/mindspore/ccsrc/ps/server/common.h index cf931b52013..1e646732e06 100644 --- a/mindspore/ccsrc/ps/server/common.h +++ b/mindspore/ccsrc/ps/server/common.h @@ -44,6 +44,14 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER }; enum CommType { HTTP = 0, TCP }; enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum }; +struct RoundConfig { + std::string name; + bool check_timeout = false; + size_t time_window = 3000; + bool check_count = false; + size_t threshold_count = 0; +}; + using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; using mindspore::kernel::CPUKernel; @@ -73,6 +81,7 @@ using ReuseKernelNodeInfo = std::map; using UploadData = std::map; constexpr auto kWeight = "weight"; +constexpr auto kNewWeight = "new_weight"; constexpr auto kAccumulation = "accum"; constexpr auto kLearningRate = "lr"; constexpr auto kGradient = "grad"; @@ -87,6 +96,8 @@ constexpr auto kAdamBeta1 = "beta1"; constexpr auto kAdamBeta2 = "beta2"; constexpr auto kAdamEps = "eps"; constexpr auto kFtrlLinear = "linear"; +constexpr auto kDataSize = "data_size"; +constexpr auto kNewDataSize = "new_data_size"; // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is // launched. @@ -137,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32; constexpr int kHttpSuccess = 200; constexpr auto kPBProtocol = "PB"; constexpr auto kFBSProtocol = "FBS"; +constexpr auto kFedAvg = "FedAvg"; constexpr auto kAggregationKernelType = "Aggregation"; constexpr auto kOptimizerKernelType = "Optimizer"; constexpr auto kCtxFuncGraph = "FuncGraph"; @@ -145,6 +157,8 @@ constexpr auto kCtxDeviceMetas = "device_metas"; constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration"; constexpr auto kCtxUpdateModelClientList = "update_model_client_list"; constexpr auto kCtxUpdateModelClientNum = "update_model_client_num"; +constexpr auto kCtxUpdateModelThld = "update_model_threshold"; +constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size"; // This macro the current timestamp in milliseconds. 
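The body of the macro defined just below falls outside this hunk. Judging by the CURRENT_TIME_MILLI.count() calls in the round kernels, it evaluates to a std::chrono milliseconds duration; a standalone equivalent (an assumption, not the actual definition) is:

#include <chrono>
#include <iostream>

int main() {
  // Milliseconds since the Unix epoch, comparable to the next_req_time and
  // timestamp strings built in the round kernels.
  auto current_time_milli = std::chrono::duration_cast<std::chrono::milliseconds>(
      std::chrono::system_clock::now().time_since_epoch());
  std::cout << current_time_milli.count() << "\n";
  return 0;
}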
#define CURRENT_TIME_MILLI \ diff --git a/mindspore/ccsrc/ps/server/distributed_count_service.cc b/mindspore/ccsrc/ps/server/distributed_count_service.cc index 749b0a5eef9..91e8cf2bcd6 100644 --- a/mindspore/ccsrc/ps/server/distributed_count_service.cc +++ b/mindspore/ccsrc/ps/server/distributed_count_service.cc @@ -112,19 +112,19 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) { std::unique_lock lock(mutex_[name]); return global_current_count_[name].size() == global_threshold_count_[name]; } else { - CountReachThresholdRequest count_reach_threashold_req; - count_reach_threashold_req.set_name(name); + CountReachThresholdRequest count_reach_threshold_req; + count_reach_threshold_req.set_name(name); std::shared_ptr> query_cnt_enough_rsp_msg = nullptr; - if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_, + if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_, core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) { MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name; return false; } - CountReachThresholdResponse count_reach_threashold_rsp; - count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size()); - return count_reach_threashold_rsp.is_enough(); + CountReachThresholdResponse count_reach_threshold_rsp; + count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size()); + return count_reach_threshold_rsp.is_enough(); } } @@ -200,9 +200,9 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared return; } - CountReachThresholdRequest count_reach_threashold_req; - count_reach_threashold_req.ParseFromArray(message->data(), message->len()); - const std::string &name = count_reach_threashold_req.name(); + CountReachThresholdRequest count_reach_threshold_req; + count_reach_threshold_req.ParseFromArray(message->data(), message->len()); + const std::string &name = count_reach_threshold_req.name(); std::unique_lock lock(mutex_[name]); if (global_threshold_count_.count(name) == 0) { @@ -210,10 +210,10 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared return; } - CountReachThresholdResponse count_reach_threashold_rsp; - count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]); - communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(), - count_reach_threashold_rsp.SerializeAsString().size(), message); + CountReachThresholdResponse count_reach_threshold_rsp; + count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]); + communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(), + count_reach_threshold_rsp.SerializeAsString().size(), message); return; } diff --git a/mindspore/ccsrc/ps/server/distributed_metadata_store.cc b/mindspore/ccsrc/ps/server/distributed_metadata_store.cc index b4e49c64ea3..5e1d637ffbc 100644 --- a/mindspore/ccsrc/ps/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/ps/server/distributed_metadata_store.cc @@ -193,7 +193,29 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr lock(mutex_[name]); - metadata_[name] = meta; + if (meta.has_device_meta()) { + auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta(); + auto &fl_id = 
meta.device_meta().fl_id(); + auto &device_meta = meta.device_meta(); + fl_id_to_meta_map[fl_id] = device_meta; + } else if (meta.has_fl_id()) { + auto client_list = metadata_[name].mutable_client_list(); + auto &fl_id = meta.fl_id().fl_id(); + // Check whether the new item already exists. + bool add_flag = true; + for (int i = 0; i < client_list->fl_id_size(); i++) { + if (fl_id == client_list->fl_id(i)) { + add_flag = false; + break; + } + } + if (add_flag) { + client_list->add_fl_id(fl_id); + } + } else if (meta.has_update_model_threshold()) { + auto update_model_threshold = metadata_[name].mutable_update_model_threshold(); + *update_model_threshold = meta.update_model_threshold(); + } return true; } } // namespace server diff --git a/mindspore/ccsrc/ps/server/executor.cc b/mindspore/ccsrc/ps/server/executor.cc index 949f4f0f131..dfe10d3f869 100644 --- a/mindspore/ccsrc/ps/server/executor.cc +++ b/mindspore/ccsrc/ps/server/executor.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace ps { namespace server { -void Executor::Init(const FuncGraphPtr &func_graph, size_t aggregation_count) { +void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) { MS_EXCEPTION_IF_NULL(func_graph); if (aggregation_count == 0) { MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0"; diff --git a/mindspore/ccsrc/ps/server/executor.h b/mindspore/ccsrc/ps/server/executor.h index ef4eafb6af6..5c777d0ad70 100644 --- a/mindspore/ccsrc/ps/server/executor.h +++ b/mindspore/ccsrc/ps/server/executor.h @@ -43,7 +43,7 @@ class Executor { // be used for aggregators. // As noted in header file parameter_aggregator.h, we create aggregators by trainable parameters, which is the // optimizer cnode's input. So we need to initialize server executor using func_graph. - void Init(const FuncGraphPtr &func_graph, size_t aggregation_count); + void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count); // Called in parameter server training mode to do Push operation. // For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered diff --git a/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.cc b/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.cc new file mode 100644 index 00000000000..30a8d806c77 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ps/server/kernel/fed_avg_kernel.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +REG_AGGREGATION_KERNEL_TWO(FedAvg, + ParamsInfo() + .AddInputNameType(kWeight, kNumberTypeFloat32) + .AddInputNameType(kDataSize, kNumberTypeUInt64) + .AddInputNameType(kNewWeight, kNumberTypeFloat32) + .AddInputNameType(kNewDataSize, kNumberTypeUInt64), + FedAvgKernel, float, size_t) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h b/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h new file mode 100644 index 00000000000..2a313affc62 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h @@ -0,0 +1,179 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_ +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "ps/server/common.h" +#include "ps/server/collective_ops_impl.h" +#include "ps/server/distributed_count_service.h" +#include "ps/server/local_meta_store.h" +#include "ps/server/kernel/aggregation_kernel.h" +#include "ps/server/kernel/aggregation_kernel_factory.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +// The implementation for the federated average. We do weighted average for the weights. The uploaded weights from +// FL-clients is already multiplied by its data size so only sum and division are done in this kernel. + +// Pay attention that this kernel is the distributed version of federated average, which means each server node in the +// cluster in invalved in the aggragation process. So the DistributedCountService and CollectiveOpsImpl are called. 
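Before the kernel class itself, a minimal single-process sketch of the weighted average being computed (illustrative names only; the distributed pieces below add the cross-server AllReduce and the per-client counting):

#include <cstddef>
#include <iostream>
#include <vector>

struct ClientUpload {
  std::vector<float> weighted_params;  // already multiplied by data_size on the client
  size_t data_size;
};

std::vector<float> FedAvg(const std::vector<ClientUpload> &uploads) {
  if (uploads.empty()) return {};
  std::vector<float> acc(uploads.front().weighted_params.size(), 0.0f);
  size_t total_data_size = 0;
  for (const auto &u : uploads) {
    for (size_t i = 0; i < acc.size(); ++i) acc[i] += u.weighted_params[i];
    total_data_size += u.data_size;
  }
  for (float &v : acc) v /= static_cast<float>(total_data_size);
  return acc;
}

int main() {
  // Client 1 trained on 10 samples with weights (1.0, 2.0); client 2 on 30 samples
  // with weights (2.0, 3.0). Both upload weight * data_size.
  std::vector<ClientUpload> uploads = {{{10.0f, 20.0f}, 10}, {{60.0f, 90.0f}, 30}};
  for (float v : FedAvg(uploads)) std::cout << v << " ";  // prints 1.75 2.75
  std::cout << "\n";
  return 0;
}

The real kernel keeps these running sums in the kernel node's input buffers and performs the division only once, in the last-count handler, after AllReduce has combined the partial sums from every server.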
+template +class FedAvgKernel : public AggregationKernel { + public: + FedAvgKernel() : participated_(false) {} + ~FedAvgKernel() override = default; + void InitKernel(const CNodePtr &kernel_node) override { + MS_EXCEPTION_IF_NULL(kernel_node); + std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); + if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 || + kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) { + MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name; + return; + } + cnode_weight_idx_ = kNameToIdxMap.at(cnode_name).at("inputs").at("weight"); + std::vector weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_weight_idx_); + size_t weight_size = + std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(T), std::multiplies()); + size_t new_weight_size = weight_size; + + input_size_list_.push_back(weight_size); + input_size_list_.push_back(sizeof(size_t)); + input_size_list_.push_back(new_weight_size); + input_size_list_.push_back(sizeof(size_t)); + + auto weight_node = + AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first; + MS_EXCEPTION_IF_NULL(weight_node); + name_ = cnode_name + "." + weight_node->fullname_with_scope(); + MS_LOG(INFO) << "Register counter for " << name_; + auto first_cnt_handler = [&](std::shared_ptr) { + std::unique_lock lock(weight_mutex_); + if (!participated_) { + ClearWeightAndDataSize(); + } + }; + auto last_cnt_handler = [&](std::shared_ptr) { + T *weight_addr = reinterpret_cast(weight_addr_->addr); + size_t weight_size = weight_addr_->size; + S *data_size_addr = reinterpret_cast(data_size_addr_->addr); + if (!CollectiveOpsImpl::GetInstance().AllReduce(weight_addr, weight_addr, weight_size / sizeof(T))) { + MS_LOG(ERROR) << "Federated average allreduce failed."; + return; + } + if (!CollectiveOpsImpl::GetInstance().AllReduce(data_size_addr, data_size_addr, 1)) { + MS_LOG(ERROR) << "Federated average allreduce failed."; + return; + } + LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size_addr[0]); + for (size_t i = 0; i < weight_size / sizeof(T); i++) { + weight_addr[i] /= data_size_addr[0]; + } + done_ = true; + DistributedCountService::GetInstance().ResetCounter(name_); + return; + }; + DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler}); + return; + } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + std::unique_lock lock(weight_mutex_); + // The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication + // again. 
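+    // Running sums: inputs[0]/inputs[1] accumulate sum_k(weight_k * data_size_k) and
+    // sum_k(data_size_k); the division by the total data size happens once in the
+    // last-count handler registered in InitKernel.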
+ T *weight_addr = reinterpret_cast(inputs[0]->addr); + S *data_size_addr = reinterpret_cast(inputs[1]->addr); + T *new_weight_addr = reinterpret_cast(inputs[2]->addr); + S *new_data_size_addr = reinterpret_cast(inputs[3]->addr); + if (accum_count_ == 0) { + ClearWeightAndDataSize(); + } + + MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for " + << name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is " + << data_size_addr[0]; + for (size_t i = 0; i < inputs[2]->size / sizeof(T); i++) { + weight_addr[i] += new_weight_addr[i]; + } + data_size_addr[0] += new_data_size_addr[0]; + lock.unlock(); + + accum_count_++; + participated_ = true; + DistributedCountService::GetInstance().Count( + name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_)); + return true; + } + void Reset() { + accum_count_ = 0; + done_ = false; + participated_ = false; + DistributedCountService::GetInstance().ResetCounter(name_); + return; + } + bool IsAggregationDone() { return done_; } + + private: + void GenerateReuseKernelNodeInfo() override { + // Only the trainable parameter is reused for federated average. + reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_)); + return; + } + + // In some cases, the Launch method is not called and the weights involved in AllReduce should be set to 0. + void ClearWeightAndDataSize() { + int ret = memset_s(weight_addr_->addr, weight_addr_->size, 0x00, weight_addr_->size); + if (ret != 0) { + MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")"; + return; + } + ret = memset_s(data_size_addr_->addr, data_size_addr_->size, 0x00, data_size_addr_->size); + if (ret != 0) { + MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")"; + return; + } + return; + } + + // The trainable parameter index of the kernel node which is parsed from the frontend func_graph. + size_t cnode_weight_idx_; + + // The address pointer of the inputs. + AddressPtr weight_addr_; + AddressPtr data_size_addr_; + AddressPtr new_weight_addr_; + AddressPtr new_data_size_addr_; + + // Whether the kernel's Launch method is called. + bool participated_; + + // The kernel could be called concurrently so we need lock to ensure threadsafe. + std::mutex weight_mutex_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_ diff --git a/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.cc new file mode 100644 index 00000000000..4c565a273cf --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ps/server/kernel/round/get_model_kernel.h" +#include +#include +#include +#include +#include "ps/server/model_store.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void GetModelKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } +} + +bool GetModelKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + MS_LOG(INFO) << "Launching GetModelKernel kernel."; + void *req_data = inputs[0]->addr; + std::shared_ptr fbb = std::make_shared(); + if (fbb == nullptr || req_data == nullptr) { + MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; + return false; + } + + const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot(req_data); + GetModel(get_model_req, fbb); + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; +} + +bool GetModelKernel::Reset() { + MS_LOG(INFO) << "Get model kernel reset!"; + StopTimer(); + return true; +} + +void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr &fbb) { + std::map feature_maps; + size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); + size_t get_model_iter = static_cast(get_model_req->iteration()); + const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model(); + size_t latest_iter_num = iter_to_model.rbegin()->first; + + if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) { + std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter); + BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(WARNING) << reason; + return; + } + + if (iter_to_model.count(get_model_iter) == 0) { + std::string reason = "The iteration of GetModel request" + std::to_string(get_model_iter) + + " is invalid. 
Current iteration is " + std::to_string(current_iter); + BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(ERROR) << reason; + return; + } + + feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter); + BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, + "Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter, + feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + return; +} + +void GetModelKernel::BuildGetModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const std::string &reason, const size_t iter, + const std::map &feature_maps, + const std::string ×tamp) { + auto fbs_reason = fbb->CreateString(reason); + auto fbs_timestamp = fbb->CreateString(timestamp); + std::vector> fbs_feature_maps; + for (const auto &feature_map : feature_maps) { + auto fbs_weight_fullname = fbb->CreateString(feature_map.first); + auto fbs_weight_data = + fbb->CreateVector(reinterpret_cast(feature_map.second->addr), feature_map.second->size / sizeof(float)); + auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data); + fbs_feature_maps.push_back(fbs_feature_map); + } + auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); + + schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get())); + rsp_get_model_builder.add_retcode(retcode); + rsp_get_model_builder.add_reason(fbs_reason); + rsp_get_model_builder.add_iteration(static_cast(iter)); + rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector); + rsp_get_model_builder.add_timestamp(fbs_timestamp); + auto rsp_get_model = rsp_get_model_builder.Finish(); + fbb->Finish(rsp_get_model); + return; +} + +REG_ROUND_KERNEL(getModel, GetModelKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.h new file mode 100644 index 00000000000..10d0cc15ff3 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/get_model_kernel.h @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_ +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_ + +#include +#include +#include +#include +#include "ps/server/common.h" +#include "ps/server/executor.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class GetModelKernel : public RoundKernel { + public: + GetModelKernel() = default; + ~GetModelKernel() override = default; + + void InitKernel(size_t) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + bool Reset() override; + + private: + void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr &fbb); + void BuildGetModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const std::string &reason, const size_t iter, + const std::map &feature_maps, const std::string ×tamp); + + // The executor is for getting model for getModel request. + Executor *executor_; + + // The time window of one iteration. + size_t iteration_time_window_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ diff --git a/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc index 3c277bdf45d..276f055ac10 100644 --- a/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc @@ -49,7 +49,7 @@ bool StartFLJobKernel::Launch(const std::vector &inputs, const std:: return false; } void *req_data = inputs[0]->addr; - const std::shared_ptr &fbb = std::make_shared(); + std::shared_ptr fbb = std::make_shared(); if (fbb == nullptr || req_data == nullptr) { MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; return false; diff --git a/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc new file mode 100644 index 00000000000..766d8afe5f2 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc @@ -0,0 +1,203 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include "ps/server/kernel/round/update_model_kernel.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void UpdateModelKernel::InitKernel(size_t threshold_count) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } + + PBMetadata client_list; + DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list); + LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count); +} + +bool UpdateModelKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + if (inputs.size() != 1 || outputs.size() != 1) { + MS_LOG(ERROR) << "inputs or outputs size is invalid."; + return false; + } + void *req_data = inputs[0]->addr; + std::shared_ptr fbb = std::make_shared(); + if (fbb == nullptr || req_data == nullptr) { + MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; + return false; + } + + MS_LOG(INFO) << "Launching UpdateModelKernel kernel."; + if (!ReachThresholdForUpdateModel(fbb)) { + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return false; + } + + const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot(req_data); + if (!UpdateModel(update_model_req, fbb)) { + MS_LOG(ERROR) << "Updating model failed."; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return false; + } + + if (!CountForUpdateModel(fbb, update_model_req)) { + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return false; + } + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; +} + +bool UpdateModelKernel::Reset() { + MS_LOG(INFO) << "Update model kernel reset!"; + StopTimer(); + DistributedCountService::GetInstance().ResetCounter(name_); + executor_->ResetAggregationStatus(); + DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList); + size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value(kCtxFedAvgTotalDataSize); + total_data_size = 0; + return true; +} + +void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr &message) { + if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) { + while (!executor_->IsAllWeightAggregationDone()) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + + size_t total_data_size = LocalMetaStore::GetInstance().value(kCtxFedAvgTotalDataSize); + MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " + << total_data_size; + + FinishIterCb(); + } +} + +bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr &fbb) { + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + std::string reason = "Current amount for updateModel is enough."; + BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(ERROR) << reason; + return false; + } + return true; +} + +bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, + const std::shared_ptr &fbb) { + size_t iteration = static_cast(update_model_req->iteration()); + if 
(iteration != LocalMetaStore::GetInstance().curr_iter_num()) { + std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + + ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()); + BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(ERROR) << reason; + return false; + } + + PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas); + FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas(); + std::string update_model_fl_id = update_model_req->fl_id()->str(); + if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) { + std::string reason = "devices_meta for " + update_model_fl_id + " is not set."; + BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(ERROR) << reason; + return false; + } + + size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size(); + auto feature_map = ParseFeatureMap(update_model_req); + for (auto weight : feature_map) { + weight.second[kNewDataSize].addr = &data_size; + weight.second[kNewDataSize].size = sizeof(size_t); + executor_->HandleModelUpdate(weight.first, weight.second); + } + + FLId fl_id; + fl_id.set_fl_id(update_model_fl_id); + PBMetadata comm_value; + *comm_value.mutable_fl_id() = fl_id; + DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value); + + BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready", + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + return true; +} + +std::map UpdateModelKernel::ParseFeatureMap( + const schema::RequestUpdateModel *update_model_req) { + RETURN_IF_NULL(update_model_req, {}); + std::map feature_map; + auto fbs_feature_map = update_model_req->feature_map(); + for (size_t i = 0; i < fbs_feature_map->size(); i++) { + std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str(); + float *weight_data = const_cast(fbs_feature_map->Get(i)->data()->data()); + size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float); + UploadData upload_data; + upload_data[kNewWeight].addr = weight_data; + upload_data[kNewWeight].size = weight_size; + feature_map[weight_full_name] = upload_data; + } + return feature_map; +} + +bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr &fbb, + const schema::RequestUpdateModel *update_model_req) { + if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) { + std::string reason = "UpdateModel counting failed."; + BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(ERROR) << reason; + return false; + } + return true; +} + +void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const std::string &reason, const std::string &next_req_time) { + auto fbs_reason = fbb->CreateString(reason); + auto fbs_next_req_time = fbb->CreateString(next_req_time); + + schema::ResponseUpdateModelBuilder rsp_update_model_builder(*(fbb.get())); + rsp_update_model_builder.add_retcode(retcode); + rsp_update_model_builder.add_reason(fbs_reason); + rsp_update_model_builder.add_next_req_time(fbs_next_req_time); + auto rsp_update_model = rsp_update_model_builder.Finish(); + 
fbb->Finish(rsp_update_model); + return; +} + +REG_ROUND_KERNEL(updateModel, UpdateModelKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.h new file mode 100644 index 00000000000..a48db5459e8 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.h @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ + +#include +#include +#include +#include +#include "ps/server/common.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" +#include "ps/server/executor.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class UpdateModelKernel : public RoundKernel { + public: + UpdateModelKernel() = default; + ~UpdateModelKernel() override = default; + + void InitKernel(size_t threshold_count) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + bool Reset() override; + + // In some cases, the last updateModel message means this server iteration is finished. + void OnLastCountEvent(const std::shared_ptr &message) override; + + private: + bool ReachThresholdForUpdateModel(const std::shared_ptr &fbb); + bool UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr &fbb); + std::map ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); + bool CountForUpdateModel(const std::shared_ptr &fbb, const schema::RequestUpdateModel *update_model_req); + void BuildUpdateModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const std::string &reason, const std::string &next_req_time); + + // The executor is for updating the model for updateModel request. + Executor *executor_; + + // The time window of one iteration. 
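+  // Read from kCtxTotalTimeoutDuration in InitKernel and added to the current
+  // timestamp to build the next_req_time field of responses.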
+ size_t iteration_time_window_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ diff --git a/mindspore/ccsrc/ps/server/model_store.cc b/mindspore/ccsrc/ps/server/model_store.cc index 62c23f50ad4..eb784952634 100644 --- a/mindspore/ccsrc/ps/server/model_store.cc +++ b/mindspore/ccsrc/ps/server/model_store.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace ps { namespace server { -void ModelStore::Init(uint32_t max_count) { +void ModelStore::Initialize(uint32_t max_count) { if (!Executor::GetInstance().initialized()) { MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage."; return; diff --git a/mindspore/ccsrc/ps/server/model_store.h b/mindspore/ccsrc/ps/server/model_store.h index bbaf7ba295b..459cbafb25f 100644 --- a/mindspore/ccsrc/ps/server/model_store.h +++ b/mindspore/ccsrc/ps/server/model_store.h @@ -40,7 +40,7 @@ class ModelStore { } // Initialize ModelStore with max count of models need to be stored. - void Init(uint32_t max_count = 3); + void Initialize(uint32_t max_count = 3); // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already // max_model_count_, the earliest model will be replaced. diff --git a/mindspore/ccsrc/ps/server/parameter_aggregator.cc b/mindspore/ccsrc/ps/server/parameter_aggregator.cc index a7ffcfdc0fd..a683d68f8a4 100644 --- a/mindspore/ccsrc/ps/server/parameter_aggregator.cc +++ b/mindspore/ccsrc/ps/server/parameter_aggregator.cc @@ -302,7 +302,7 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { - std::vector aggregation_algorithm = {}; + std::vector aggregation_algorithm = {kFedAvg}; MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; return aggregation_algorithm; } diff --git a/mindspore/ccsrc/ps/server/server.cc b/mindspore/ccsrc/ps/server/server.cc new file mode 100644 index 00000000000..de2ffb2db53 --- /dev/null +++ b/mindspore/ccsrc/ps/server/server.cc @@ -0,0 +1,251 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/server.h" +#include +#include +#include +#include "ps/server/round.h" +#include "ps/server/model_store.h" +#include "ps/server/iteration.h" +#include "ps/server/collective_ops_impl.h" +#include "ps/server/distributed_metadata_store.h" +#include "ps/server/distributed_count_service.h" +#include "ps/server/kernel/round/round_kernel_factory.h" + +namespace mindspore { +namespace ps { +namespace server { +static std::vector> global_worker_server_comms = {}; +// This function is for the exit of server process when an interrupt signal is captured. 
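+// It stops every communicator registered in global_worker_server_comms, i.e. the worker-facing
+// communicators that InitCluster() records for this purpose.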
+void SignalHandler(int signal) { + MS_LOG(INFO) << "Interrupt signal captured: " << signal; + std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(), + [](const std::shared_ptr &communicator) { communicator->Stop(); }); + return; +} + +void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector &rounds_config, + const FuncGraphPtr &func_graph, size_t executor_threshold) { + MS_EXCEPTION_IF_NULL(func_graph); + func_graph_ = func_graph; + + if (rounds_config.empty()) { + MS_LOG(EXCEPTION) << "Rounds are empty."; + return; + } + rounds_config_ = rounds_config; + + use_tcp_ = use_tcp; + use_http_ = use_http; + http_port_ = http_port; + executor_threshold_ = executor_threshold; + return; +} + +// Each step of the server pipeline may have dependency on other steps, which includes: + +// InitServerContext must be the first step to set contexts for later steps. + +// Server Running relies on URL or Message Type Register: +// StartCommunicator---->InitIteration + +// Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion: +// RegisterRoundKernel---->StartCommunicator + +// Kernel Initialization relies on Executor Initialization: +// RegisterRoundKernel---->InitExecutor + +// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization: +// InitCipher---->InitExecutor +void Server::Run() { + signal(SIGINT, SignalHandler); + InitServerContext(); + InitCluster(); + InitIteration(); + StartCommunicator(); + InitExecutor(); + RegisterRoundKernel(); + MS_LOG(INFO) << "Server started successfully."; + + // Wait communicators to stop so the main thread is blocked. + std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), + [](const std::shared_ptr &communicator) { communicator->Join(); }); + communicator_with_server_->Join(); + MsException::Instance().CheckException(); + return; +} + +void Server::InitServerContext() { + PSContext::instance()->GenerateResetterRound(); + scheduler_ip_ = PSContext::instance()->scheduler_host(); + scheduler_port_ = PSContext::instance()->scheduler_port(); + worker_num_ = PSContext::instance()->initial_worker_num(); + server_num_ = PSContext::instance()->initial_server_num(); + return; +} + +void Server::InitCluster() { + server_node_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(server_node_); + task_executor_ = std::make_shared(32); + MS_EXCEPTION_IF_NULL(task_executor_); + if (!InitCommunicatorWithServer()) { + MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed."; + return; + } + if (!InitCommunicatorWithWorker()) { + MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed."; + return; + } + global_worker_server_comms = communicators_with_worker_; + return; +} + +bool Server::InitCommunicatorWithServer() { + MS_EXCEPTION_IF_NULL(task_executor_); + MS_EXCEPTION_IF_NULL(server_node_); + + communicator_with_server_ = + server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_); + MS_EXCEPTION_IF_NULL(communicator_with_server_); + + // Set exception event callbacks for server. + auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); + MS_EXCEPTION_IF_NULL(tcp_comm); + + tcp_comm->RegisterEventCallback(core::CLUSTER_TIMEOUT, [&]() { + MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. 
This is because some nodes (Scheduler/Server/Worker) are not "
+                     "started during the network building phase.";
+    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
+                  [](const std::shared_ptr &communicator) { communicator->Stop(); });
+    communicator_with_server_->Stop();
+  });
+
+  tcp_comm->RegisterEventCallback(core::SCHEDULER_TIMEOUT, [&]() {
+    MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because the scheduler node is finalized or crashed.";
+    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
+                  [](const std::shared_ptr &communicator) { communicator->Stop(); });
+    communicator_with_server_->Stop();
+  });
+
+  tcp_comm->RegisterEventCallback(core::NODE_TIMEOUT, [&]() {
+    MS_LOG(ERROR)
+      << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
+         "network building phase.";
+    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
+                  [](const std::shared_ptr &communicator) { communicator->Stop(); });
+    communicator_with_server_->Stop();
+  });
+  return true;
+}
+
+bool Server::InitCommunicatorWithWorker() {
+  MS_EXCEPTION_IF_NULL(server_node_);
+  MS_EXCEPTION_IF_NULL(task_executor_);
+  if (!use_tcp_ && !use_http_) {
+    MS_LOG(EXCEPTION) << "At least one type of protocol should be set.";
+    return false;
+  }
+  if (use_tcp_) {
+    auto tcp_comm = communicator_with_server_;
+    MS_EXCEPTION_IF_NULL(tcp_comm);
+    communicators_with_worker_.push_back(tcp_comm);
+  }
+  if (use_http_) {
+    auto http_comm = server_node_->GetOrCreateHttpComm("0.0.0.0", http_port_, task_executor_);
+    MS_EXCEPTION_IF_NULL(http_comm);
+    communicators_with_worker_.push_back(http_comm);
+  }
+  return true;
+}
+
+void Server::InitIteration() {
+  iteration_ = std::make_shared();
+  MS_EXCEPTION_IF_NULL(iteration_);
+
+  // 1. Add rounds to the iteration according to the server mode.
+  for (const RoundConfig &config : rounds_config_) {
+    std::shared_ptr round = std::make_shared(config.name, config.check_timeout, config.time_window,
+                                              config.check_count, config.threshold_count);
+    MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count
+                 << ", threshold: " << config.threshold_count;
+    iteration_->AddRound(round);
+  }
+
+  // 2. Initialize all the rounds.
+  TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
+  FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
+  iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
+  return;
+}
+
+void Server::InitExecutor() {
+  if (executor_threshold_ == 0) {
+    MS_LOG(EXCEPTION) << "The executor's threshold should be greater than 0.";
+    return;
+  }
+  // The train engine instance is used in both push-type and pull-type kernels,
+  // so the required_cnt of these kernels must be the same as update_model_threshold_.
+ MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_; + Executor::GetInstance().Initialize(func_graph_, executor_threshold_); + ModelStore::GetInstance().Initialize(); + return; +} + +void Server::RegisterRoundKernel() { + MS_EXCEPTION_IF_NULL(iteration_); + auto &rounds = iteration_->rounds(); + if (rounds.empty()) { + MS_LOG(EXCEPTION) << "Server has no round registered."; + return; + } + + for (auto &round : rounds) { + const std::string &name = round->name(); + std::shared_ptr round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name); + if (round_kernel == nullptr) { + MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered."; + return; + } + + // For some round kernels, the threshold count should be set. + round_kernel->InitKernel(round->threshold_count()); + round->BindRoundKernel(round_kernel); + } + return; +} + +void Server::StartCommunicator() { + MS_EXCEPTION_IF_NULL(communicator_with_server_); + if (communicators_with_worker_.empty()) { + MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty."; + return; + } + + MS_LOG(INFO) << "Start communicator with server."; + communicator_with_server_->Start(); + DistributedMetadataStore::GetInstance().Initialize(server_node_); + CollectiveOpsImpl::GetInstance().Initialize(server_node_); + DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank); + + MS_LOG(INFO) << "Start communicator with worker."; + std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), + [](const std::shared_ptr &communicator) { communicator->Start(); }); +} +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/server.h b/mindspore/ccsrc/ps/server/server.h new file mode 100644 index 00000000000..c57f4418c26 --- /dev/null +++ b/mindspore/ccsrc/ps/server/server.h @@ -0,0 +1,131 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ +#define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ + +#include +#include +#include +#include "ps/core/communicator/communicator_base.h" +#include "ps/core/communicator/tcp_communicator.h" +#include "ps/core/communicator/task_executor.h" +#include "ps/server/common.h" +#include "ps/server/executor.h" +#include "ps/server/iteration.h" + +namespace mindspore { +namespace ps { +namespace server { +// Class Server is the entrance of MindSpore's parameter server training mode and federated learning. +class Server { + public: + static Server &GetInstance() { + static Server instance; + return instance; + } + + void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector &rounds_config, + const FuncGraphPtr &func_graph, size_t executor_threshold); + + // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be + // blocked until the server is finalized. 
+  // func_graph is the frontend graph which will be parsed in the server's executor and aggregator.
+  void Run();
+
+ private:
+  Server()
+      : server_node_(nullptr),
+        task_executor_(nullptr),
+        use_tcp_(false),
+        use_http_(false),
+        http_port_(0),
+        func_graph_(nullptr),
+        executor_threshold_(0),
+        communicator_with_server_(nullptr),
+        communicators_with_worker_({}),
+        iteration_(nullptr),
+        scheduler_ip_(""),
+        scheduler_port_(0),
+        server_num_(0),
+        worker_num_(0) {}
+  ~Server() = default;
+  Server(const Server &) = delete;
+  Server &operator=(const Server &) = delete;
+
+  // Load variables which are set by ps_context.
+  void InitServerContext();
+
+  // Initialize the server cluster, server node and communicators.
+  void InitCluster();
+  bool InitCommunicatorWithServer();
+  bool InitCommunicatorWithWorker();
+
+  // Initialize iteration with rounds. Which rounds to use can be set by ps_context as well.
+  void InitIteration();
+
+  // Initialize the executor according to the server mode.
+  void InitExecutor();
+
+  // Create round kernels and bind these kernels with the corresponding Round.
+  void RegisterRoundKernel();
+
+  // The communicators should be started after all initializations are completed.
+  void StartCommunicator();
+
+  // The server node is initialized in Server.
+  std::shared_ptr server_node_;
+
+  // The task executor of the communicators. This helps the server handle network messages concurrently. The tasks
+  // submitted to this task executor are asynchronous.
+  std::shared_ptr task_executor_;
+
+  // Which protocols the communicators use.
+  bool use_tcp_;
+  bool use_http_;
+  uint64_t http_port_;
+
+  // The configuration of all rounds.
+  std::vector rounds_config_;
+
+  // The graph passed by the frontend without backend optimization.
+  FuncGraphPtr func_graph_;
+
+  // The threshold count for the executor to do aggregation or optimization.
+  size_t executor_threshold_;
+
+  // The server needs a TCP communicator to communicate with other servers for counting, metadata storing, collective
+  // operations, etc.
+  std::shared_ptr communicator_with_server_;
+
+  // The communication with workers (including mobile devices) supports multiple protocol types: HTTP and TCP.
+  // In some cases, both types should be supported in one distributed training job. So here we may have multiple
+  // communicators.
+  std::vector> communicators_with_worker_;
+
+  // Iteration consists of multiple kinds of rounds.
+  std::shared_ptr iteration_;
+
+  // Variables set by ps_context.
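+  // They are filled in InitServerContext() from PSContext: the scheduler host and port, and the initial
+  // worker/server numbers used for network building.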
+ std::string scheduler_ip_; + uint16_t scheduler_port_; + uint32_t server_num_; + uint32_t worker_num_; +}; +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index 834b992f97b..df4f11beb79 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py @@ -33,7 +33,21 @@ def ps_context(): return _ps_context _set_ps_context_func_map = { - "enable_ps": ps_context().set_ps_enable + "server_mode": ps_context().set_server_mode, + "ms_role": ps_context().set_ms_role, + "enable_ps": ps_context().set_ps_enable, + "worker_num": ps_context().set_worker_num, + "server_num": ps_context().set_server_num, + "scheduler_ip": ps_context().set_scheduler_ip, + "scheduler_port": ps_context().set_scheduler_port, + "fl_server_port": ps_context().set_fl_server_port, + "enable_fl_client": ps_context().set_fl_client_enable, + "start_fl_job_threshold": ps_context().set_start_fl_job_threshold, + "fl_name": ps_context().set_fl_name, + "fl_iteration_num": ps_context().set_fl_iteration_num, + "client_epoch_num": ps_context().set_client_epoch_num, + "client_batch_size": ps_context().set_client_batch_size, + "secure_aggregation": ps_context().set_secure_aggregation } _get_ps_context_func_map = { diff --git a/tests/st/fl/mobile/finish_mobile.py b/tests/st/fl/mobile/finish_mobile.py new file mode 100644 index 00000000000..69cdcafc20c --- /dev/null +++ b/tests/st/fl/mobile/finish_mobile.py @@ -0,0 +1,30 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import subprocess + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Finish test_mobile_lenet.py case") + parser.add_argument("--scheduler_port", type=int, default=8113) + + args, _ = parser.parse_known_args() + scheduler_port = args.scheduler_port + + cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" " + cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && " + cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done" + + subprocess.call(['bash', '-c', cmd]) diff --git a/tests/st/fl/mobile/run_mobile_sched.py b/tests/st/fl/mobile/run_mobile_sched.py new file mode 100644 index 00000000000..4bd68ab75e5 --- /dev/null +++ b/tests/st/fl/mobile/run_mobile_sched.py @@ -0,0 +1,52 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import subprocess + +parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case") +parser.add_argument("--device_target", type=str, default="CPU") +parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") +parser.add_argument("--worker_num", type=int, default=0) +parser.add_argument("--server_num", type=int, default=2) +parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") +parser.add_argument("--scheduler_port", type=int, default=8113) +parser.add_argument("--fl_server_port", type=int, default=6666) + +if __name__ == "__main__": + args, _ = parser.parse_known_args() + device_target = args.device_target + server_mode = args.server_mode + worker_num = args.worker_num + server_num = args.server_num + scheduler_ip = args.scheduler_ip + scheduler_port = args.scheduler_port + fl_server_port = args.fl_server_port + + cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&" + cmd_sched += "mkdir ${execute_path}/scheduler/ &&" + cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&" + cmd_sched += "python ${self_path}/../test_mobile_lenet.py" + cmd_sched += " --device_target=" + device_target + cmd_sched += " --server_mode=" + server_mode + cmd_sched += " --ms_role=MS_SCHED" + cmd_sched += " --worker_num=" + str(worker_num) + cmd_sched += " --server_num=" + str(server_num) + cmd_sched += " --scheduler_ip=" + scheduler_ip + cmd_sched += " --scheduler_port=" + str(scheduler_port) + cmd_sched += " --fl_server_port=" + str(fl_server_port) + cmd_sched += " > scheduler.log 2>&1 &" + + subprocess.call(['bash', '-c', cmd_sched]) diff --git a/tests/st/fl/mobile/run_mobile_server.py b/tests/st/fl/mobile/run_mobile_server.py new file mode 100644 index 00000000000..e3a5be281b3 --- /dev/null +++ b/tests/st/fl/mobile/run_mobile_server.py @@ -0,0 +1,82 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import ast +import argparse +import subprocess + +parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case") +parser.add_argument("--device_target", type=str, default="CPU") +parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") +parser.add_argument("--worker_num", type=int, default=0) +parser.add_argument("--server_num", type=int, default=2) +parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") +parser.add_argument("--scheduler_port", type=int, default=8113) +parser.add_argument("--fl_server_port", type=int, default=6666) +parser.add_argument("--start_fl_job_threshold", type=int, default=1) +parser.add_argument("--fl_name", type=str, default="Lenet") +parser.add_argument("--fl_iteration_num", type=int, default=25) +parser.add_argument("--client_epoch_num", type=int, default=20) +parser.add_argument("--client_batch_size", type=int, default=32) +parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False) +parser.add_argument("--local_server_num", type=int, default=-1) + +if __name__ == "__main__": + args, _ = parser.parse_known_args() + device_target = args.device_target + server_mode = args.server_mode + worker_num = args.worker_num + server_num = args.server_num + scheduler_ip = args.scheduler_ip + scheduler_port = args.scheduler_port + fl_server_port = args.fl_server_port + start_fl_job_threshold = args.start_fl_job_threshold + fl_name = args.fl_name + fl_iteration_num = args.fl_iteration_num + client_epoch_num = args.client_epoch_num + client_batch_size = args.client_batch_size + secure_aggregation = args.secure_aggregation + local_server_num = args.local_server_num + + if local_server_num == -1: + local_server_num = server_num + + assert local_server_num <= server_num, "The local server number should not be bigger than total server number." 
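+    # Each local server instance runs in its own working directory (server_<i>/), listens on
+    # fl_server_port + i and writes its output to server_<i>/server.log; the loop pauses
+    # briefly (0.3 s) before launching each one.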
+ + for i in range(local_server_num): + cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && " + cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&" + cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&" + cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&" + cmd_server += "python ${self_path}/../test_mobile_lenet.py" + cmd_server += " --device_target=" + device_target + cmd_server += " --server_mode=" + server_mode + cmd_server += " --ms_role=MS_SERVER" + cmd_server += " --worker_num=" + str(worker_num) + cmd_server += " --server_num=" + str(server_num) + cmd_server += " --scheduler_ip=" + scheduler_ip + cmd_server += " --scheduler_port=" + str(scheduler_port) + cmd_server += " --fl_server_port=" + str(fl_server_port + i) + cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold) + cmd_server += " --fl_name=" + fl_name + cmd_server += " --fl_iteration_num=" + str(fl_iteration_num) + cmd_server += " --client_epoch_num=" + str(client_epoch_num) + cmd_server += " --client_batch_size=" + str(client_batch_size) + cmd_server += " --secure_aggregation=" + str(secure_aggregation) + cmd_server += " > server.log 2>&1 &" + + import time + time.sleep(0.3) + subprocess.call(['bash', '-c', cmd_server]) diff --git a/tests/st/fl/mobile/run_smlt.sh b/tests/st/fl/mobile/run_smlt.sh new file mode 100644 index 00000000000..6f304a6be25 --- /dev/null +++ b/tests/st/fl/mobile/run_smlt.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export PYTHONPATH=../../../../:$PYTHONPATH +server_num=$1 +worker_num=$2 +ip=$3 +port=$4 + +for((i=0;i simulator_$i.log 2>&1 & +done diff --git a/tests/st/fl/mobile/simulator.py b/tests/st/fl/mobile/simulator.py new file mode 100644 index 00000000000..45dc0612b68 --- /dev/null +++ b/tests/st/fl/mobile/simulator.py @@ -0,0 +1,191 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import argparse +import time +import random +import sys +import requests +import flatbuffers +import numpy as np +from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode, + RequestUpdateModel, FeatureMap, RequestGetModel, ResponseGetModel) + +parser = argparse.ArgumentParser() +parser.add_argument("--pid", type=int, default=0) +parser.add_argument("--http_ip", type=str, default="10.113.216.106") +parser.add_argument("--http_port", type=int, default=6666) +parser.add_argument("--use_elb", type=bool, default=False) +parser.add_argument("--server_num", type=int, default=1) + +args, _ = parser.parse_known_args() +pid = args.pid +http_ip = args.http_ip +http_port = args.http_port +use_elb = args.use_elb +server_num = args.server_num + +str_fl_id = 'fl_lenet_' + str(pid) + +def generate_port(): + if not use_elb: + return http_port + port = random.randint(0, 100000) % server_num + http_port + return port + + +def build_start_fl_job(iteration): + start_fl_job_builder = flatbuffers.Builder(1024) + + fl_name = start_fl_job_builder.CreateString('fl_test_job') + fl_id = start_fl_job_builder.CreateString(str_fl_id) + data_size = 32 + timestamp = start_fl_job_builder.CreateString('2020/11/16/19/18') + + RequestFLJob.RequestFLJobStart(start_fl_job_builder) + RequestFLJob.RequestFLJobAddFlName(start_fl_job_builder, fl_name) + RequestFLJob.RequestFLJobAddFlId(start_fl_job_builder, fl_id) + RequestFLJob.RequestFLJobAddIteration(start_fl_job_builder, iteration) + RequestFLJob.RequestFLJobAddDataSize(start_fl_job_builder, data_size) + RequestFLJob.RequestFLJobAddTimestamp(start_fl_job_builder, timestamp) + fl_job_req = RequestFLJob.RequestFLJobEnd(start_fl_job_builder) + + start_fl_job_builder.Finish(fl_job_req) + buf = start_fl_job_builder.Output() + return buf + +def build_feature_map(builder, names, lengths): + if len(names) != len(lengths): + return None + feature_maps = [] + np_data = [] + for j, _ in enumerate(names): + name = names[j] + length = lengths[j] + weight_full_name = builder.CreateString(name) + FeatureMap.FeatureMapStartDataVector(builder, length) + weight = np.random.rand(length) * 32 + np_data.append(weight) + for idx in range(length - 1, -1, -1): + builder.PrependFloat32(weight[idx]) + data = builder.EndVector(length) + FeatureMap.FeatureMapStart(builder) + FeatureMap.FeatureMapAddData(builder, data) + FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name) + feature_map = FeatureMap.FeatureMapEnd(builder) + feature_maps.append(feature_map) + return feature_maps, np_data + +def build_update_model(iteration): + builder_update_model = flatbuffers.Builder(1) + fl_name = builder_update_model.CreateString('fl_test_job') + fl_id = builder_update_model.CreateString(str_fl_id) + timestamp = builder_update_model.CreateString('2020/11/16/19/18') + + feature_maps, np_data = build_feature_map(builder_update_model, + ["conv1.weight", "conv2.weight", "fc1.weight", + "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"], + [450, 2400, 48000, 10080, 5208, 120, 84, 62]) + + RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, 1) + for single_feature_map in feature_maps: + builder_update_model.PrependUOffsetTRelative(single_feature_map) + feature_map = builder_update_model.EndVector(len(feature_maps)) + + RequestUpdateModel.RequestUpdateModelStart(builder_update_model) + RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name) + 
RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id) + RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration) + RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map) + RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp) + req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model) + builder_update_model.Finish(req_update_model) + buf = builder_update_model.Output() + return buf, np_data + +def build_get_model(iteration): + builder_get_model = flatbuffers.Builder(1) + fl_name = builder_get_model.CreateString('fl_test_job') + timestamp = builder_get_model.CreateString('2020/12/16/19/18') + + RequestGetModel.RequestGetModelStart(builder_get_model) + RequestGetModel.RequestGetModelAddFlName(builder_get_model, fl_name) + RequestGetModel.RequestGetModelAddIteration(builder_get_model, iteration) + RequestGetModel.RequestGetModelAddTimestamp(builder_get_model, timestamp) + req_get_model = RequestGetModel.RequestGetModelEnd(builder_get_model) + builder_get_model.Finish(req_get_model) + buf = builder_get_model.Output() + return buf + +weight_name_to_idx = { + "conv1.weight": 0, + "conv2.weight": 1, + "fc1.weight": 2, + "fc2.weight": 3, + "fc3.weight": 4, + "fc1.bias": 5, + "fc2.bias": 6, + "fc3.bias": 7 +} + +session = requests.Session() +current_iteration = 1 +url = "http://" + http_ip + ":" + str(generate_port()) +np.random.seed(0) +while True: + url1 = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob' + print("start url is ", url1) + x = requests.post(url1, data=build_start_fl_job(current_iteration)) + rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0) + print("start fl job iteration:", current_iteration, ", id:", args.pid) + while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED: + x = requests.post(url1, data=build_start_fl_job(current_iteration)) + rsp_fl_job = rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0) + print("epoch is", rsp_fl_job.FlPlanConfig().Epochs()) + sys.stdout.flush() + + url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel' + print("req update model iteration:", current_iteration, ", id:", args.pid) + update_model_buf, update_model_np_data = build_update_model(current_iteration) + x = session.post(url2, data=update_model_buf) + print("rsp update model iteration:", current_iteration, ", id:", args.pid) + sys.stdout.flush() + + url3 = "http://" + http_ip + ":" + str(generate_port()) + '/getModel' + print("req get model iteration:", current_iteration, ", id:", args.pid) + x = session.post(url3, data=build_get_model(current_iteration)) + rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0) + print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode()) + sys.stdout.flush() + + repeat_time = 0 + while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady: + time.sleep(0.1) + x = session.post(url3, data=build_get_model(current_iteration)) + rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0) + repeat_time += 1 + if repeat_time > 1000: + print("GetModel try timeout ", args.pid) + sys.exit(0) + + for i in range(0, 1): + print(rsp_get_model.FeatureMap(i).WeightFullname()) + origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]] + after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 
32 + print("Before update model", args.pid, origin[0:10]) + print("After get model", args.pid, after[0:10]) + sys.stdout.flush() + assert np.allclose(origin, after, rtol=1e-05, atol=1e-05) + current_iteration += 1 diff --git a/tests/st/fl/mobile/src/adam.py b/tests/st/fl/mobile/src/adam.py new file mode 100644 index 00000000000..27adc4d3e29 --- /dev/null +++ b/tests/st/fl/mobile/src/adam.py @@ -0,0 +1,423 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag.""" +import numpy as np + +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.common.tensor import Tensor +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel +from mindspore.nn.optim.optimizer import Optimizer + +_adam_opt = C.MultitypeFuncGraph("adam_opt") +_scaler_one = Tensor(1, mstype.int32) +_scaler_ten = Tensor(10, mstype.float32) + +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", + "Tensor", "Bool", "Bool") +def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter): + """ + Update parameters by AdamWeightDecay op. + """ + if optim_filter: + adam = P.AdamWeightDecay() + if decay_flags: + next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient) + else: + next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient) + return next_param + return gradient + +@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", + "Tensor", "Bool", "Bool") +def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter): + """ + Update parameters. + + Args: + beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). + beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). + eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. + lr (Tensor): Learning rate. + overflow (Tensor): Whether overflow occurs. + weight_decay (Number): Weight decay. Should be equal to or greater than 0. + param (Tensor): Parameters. + m (Tensor): m value of parameters. + v (Tensor): v value of parameters. + gradient (Tensor): Gradient of parameters. + decay_flag (bool): Applies weight decay or not. + optim_filter (bool): Applies parameter update or not. + + Returns: + Tensor, the new value of v after updating. 
+ """ + if optim_filter: + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + op_select = P.Select() + + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + + cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_) + next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\ + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)) + + next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\ + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32))) + + update = next_m / (eps + op_sqrt(next_v)) + if decay_flag: + update = op_mul(weight_decay, param_fp32) + update + + update_with_lr = op_mul(lr, update) + zeros = F.fill(mstype.float32, op_shape(param_fp32), 0) + next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32))) + + next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) + next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) + next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) + + return op_cast(next_param, F.dtype(param)) + return gradient + + +@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") +def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, + beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable): + """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" + success = True + indices = gradient.indices + values = gradient.values + if ps_parameter and not cache_enable: + op_shape = P.Shape() + shapes = (op_shape(param), op_shape(m), op_shape(v), + op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), + op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, + eps, values, indices), shapes), param)) + return success + + if not target: + success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2, + eps, values, indices)) + else: + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + scatter_add = P.ScatterAdd(use_locking) + + assign_m = F.assign(m, op_mul(beta1, m)) + assign_v = F.assign(v, op_mul(beta2, v)) + + grad_indices = gradient.indices + grad_value = gradient.values + + next_m = scatter_add(m, + grad_indices, + op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) + + next_v = scatter_add(v, + grad_indices, + op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value))) + + if use_nesterov: + m_temp = next_m * _scaler_ten + assign_m_nesterov = F.assign(m, op_mul(beta1, next_m)) + div_value = scatter_add(m, + op_mul(grad_indices, _scaler_one), + op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) + param_update = div_value / (op_sqrt(next_v) + eps) + + m_recover = F.assign(m, m_temp / _scaler_ten) + + F.control_depend(m_temp, assign_m_nesterov) + F.control_depend(assign_m_nesterov, div_value) + F.control_depend(param_update, m_recover) + else: + param_update = next_m / (op_sqrt(next_v) + eps) + + lr_t = lr * 
op_sqrt(1 - beta2_power) / (1 - beta1_power) + + next_param = param - lr_t * param_update + + F.control_depend(assign_m, next_m) + F.control_depend(assign_v, next_v) + + success = F.depend(success, F.assign(param, next_param)) + success = F.depend(success, F.assign(m, next_m)) + success = F.depend(success, F.assign(v, next_v)) + + return success + + +@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") +def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, + beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, + moment1, moment2, ps_parameter, cache_enable): + """Apply adam optimizer to the weight parameter using Tensor.""" + success = True + if ps_parameter and not cache_enable: + op_shape = P.Shape() + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), + (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) + else: + success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, + eps, gradient)) + return success + + +@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor") +def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2): + """Apply AdamOffload optimizer to the weight parameter using Tensor.""" + success = True + delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient) + success = F.depend(success, F.assign_add(param, delat_param)) + return success + + +def _check_param_value(beta1, beta2, eps, prim_name): + """Check the type of inputs.""" + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) + validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) + validator.check_positive_float(eps, "eps", prim_name) + +class AdamWeightDecayForBert(Optimizer): + """ + Implements the Adam algorithm to fix the weight decay. + + Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. + + Args: + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value must be a list of `Parameter`. + + - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. 
If "order_params" is in the keys, the value must be the order of parameters and + the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters + which in the 'order_params' must be in one of group parameters. + + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. + When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero + dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. + beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. + Should be in range (0.0, 1.0). + beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. + Should be in range (0.0, 1.0). + eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. + Should be greater than 0. + weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. + + Inputs: + - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. + - **overflow** (tuple[Tensor]) - The overflow flag in dynamiclossscale. + + Outputs: + tuple[bool], all elements are True. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = AdamWeightDecay(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + ... {'params': no_conv_params, 'lr': 0.01}, + ... {'order_params': net.trainable_params()}] + >>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. 
+ >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) + """ + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) + self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) + self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) + self.eps = Tensor(np.array([eps]).astype(np.float32)) + self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') + self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') + self.hyper_map = C.HyperMap() + self.op_select = P.Select() + self.op_cast = P.Cast() + self.op_reshape = P.Reshape() + self.op_shape = P.Shape() + + def construct(self, gradients, overflow): + """AdamWeightDecayForBert""" + lr = self.get_lr() + cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\ + self.op_reshape(overflow, (())), mstype.bool_) + beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1) + beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2) + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), + lr, self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow), + self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), + self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + if self.use_parallel: + self.broadcast_params(optim_result) + return optim_result + +class AdamWeightDecayOp(Optimizer): + """ + Implements the Adam algorithm to fix the weight decay. It is a complete operator, not a combination of other ops. + + Note: + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied + on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. + + To improve parameter groups performance, the customized order of parameters can be supported. + + Args: + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr", "weight_decay" and "order_params" are the keys can be parsed. + + - params: Required. The value must be a list of `Parameter`. + + - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + + - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and + the order will be followed in the optimizer. 
There are no other keys in the `dict` and the parameters + which in the 'order_params' must be in one of group parameters. + + learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. + When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then + the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, + use dynamic learning rate, the i-th learning rate will be calculated during the process of training + according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero + dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be + equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. + Default: 1e-3. + beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. + Should be in range (0.0, 1.0). + beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. + Should be in range (0.0, 1.0). + eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. + Should be greater than 0. + weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. + + Inputs: + - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. + + Outputs: + tuple[bool], all elements are True. + + Supported Platforms: + ``GPU`` + + Examples: + >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = AdamWeightDecayOp(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + ... {'params': no_conv_params, 'lr': 0.01}, + ... {'order_params': net.trainable_params()}] + >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. 
+ >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) + """ + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) + self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) + self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) + self.eps = Tensor(np.array([eps]).astype(np.float32)) + self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') + self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') + self.hyper_map = C.HyperMap() + + def construct(self, gradients): + """AdamWeightDecayOp""" + lr = self.get_lr() + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), + lr, self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), + self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), + self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + if self.use_parallel: + self.broadcast_params(optim_result) + return optim_result diff --git a/tests/st/fl/mobile/src/model.py b/tests/st/fl/mobile/src/model.py new file mode 100644 index 00000000000..1a631a79941 --- /dev/null +++ b/tests/st/fl/mobile/src/model.py @@ -0,0 +1,72 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import mindspore.nn as nn +from mindspore.common.initializer import TruncatedNormal + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + weight_init=weight, + has_bias=False, + pad_mode="valid", + ) + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10, channel=3): + super(LeNet5, self).__init__() + self.num_class = num_class + self.conv1 = conv(channel, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x diff --git a/tests/st/fl/mobile/test_mobile_lenet.py b/tests/st/fl/mobile/test_mobile_lenet.py new file mode 100644 index 00000000000..44529a8ee97 --- /dev/null +++ b/tests/st/fl/mobile/test_mobile_lenet.py @@ -0,0 +1,96 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import ast +import argparse +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell +from src.model import LeNet5 +from src.adam import AdamWeightDecayOp + +parser = argparse.ArgumentParser(description="test_fl_lenet") +parser.add_argument("--device_target", type=str, default="CPU") +parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") +parser.add_argument("--ms_role", type=str, default="MS_WORKER") +parser.add_argument("--worker_num", type=int, default=0) +parser.add_argument("--server_num", type=int, default=1) +parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") +parser.add_argument("--scheduler_port", type=int, default=8113) +parser.add_argument("--fl_server_port", type=int, default=6666) +parser.add_argument("--start_fl_job_threshold", type=int, default=1) +parser.add_argument("--fl_name", type=str, default="Lenet") +parser.add_argument("--fl_iteration_num", type=int, default=25) +parser.add_argument("--client_epoch_num", type=int, default=20) +parser.add_argument("--client_batch_size", type=int, default=32) +parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False) + +args, _ = parser.parse_known_args() +device_target = args.device_target +server_mode = args.server_mode +ms_role = args.ms_role +worker_num = args.worker_num +server_num = args.server_num +scheduler_ip = args.scheduler_ip +scheduler_port = args.scheduler_port +fl_server_port = args.fl_server_port +start_fl_job_threshold = args.start_fl_job_threshold +fl_name = args.fl_name +fl_iteration_num = args.fl_iteration_num +client_epoch_num = args.client_epoch_num +client_batch_size = args.client_batch_size +secure_aggregation = args.secure_aggregation + +ctx = { + "enable_ps": False, + "server_mode": server_mode, + "ms_role": ms_role, + "worker_num": worker_num, + "server_num": server_num, + "scheduler_ip": scheduler_ip, + "scheduler_port": scheduler_port, + "fl_server_port": fl_server_port, + "start_fl_job_threshold": start_fl_job_threshold, + "fl_name": fl_name, + "fl_iteration_num": fl_iteration_num, + "client_epoch_num": client_epoch_num, + "client_batch_size": client_batch_size, + "secure_aggregation": secure_aggregation +} + +context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False) +context.set_ps_context(**ctx) + +if __name__ == "__main__": + epoch = 5 + np.random.seed(0) + network = LeNet5(62) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + net_adam_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1) + net_with_criterion = WithLossCell(network, criterion) + train_network = TrainOneStepCell(net_with_criterion, net_opt) + train_network.set_train() + losses = [] + + for _ in range(epoch): + data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) + label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32)) + loss = train_network(data, label).asnumpy() + losses.append(loss) + print(losses)
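
Taken together, the test harness above is driven in a fixed order: run_mobile_sched.py starts the scheduler, run_mobile_server.py starts the federated-learning server processes, run_smlt.sh launches the simulated mobile clients against the servers' HTTP ports, and finish_mobile.py kills the scheduler/server processes by matching their scheduler_port. The following minimal driver sketch illustrates that order; it is not part of the patch, assumes it is executed from tests/st/fl/mobile, and the port numbers simply reuse the scripts' defaults.

import subprocess
import time

SCHEDULER_PORT = 8113  # example values; these are the scripts' default ports
FL_SERVER_PORT = 6666  # server i listens on FL_SERVER_PORT + i
SERVER_NUM = 2

# 1. Start the scheduler and the servers. Both scripts fork background processes and
#    write their logs into scheduler/ and server_<i>/ respectively.
subprocess.call(["python", "run_mobile_sched.py",
                 "--server_num=" + str(SERVER_NUM),
                 "--scheduler_port=" + str(SCHEDULER_PORT)])
subprocess.call(["python", "run_mobile_server.py",
                 "--server_num=" + str(SERVER_NUM),
                 "--scheduler_port=" + str(SCHEDULER_PORT),
                 "--fl_server_port=" + str(FL_SERVER_PORT),
                 "--start_fl_job_threshold=1"])

# 2. Give the cluster some time for network building, then start one simulated client
#    (run_smlt.sh takes the server number, the number of simulators, the server ip and the http port).
time.sleep(10)
subprocess.call(["bash", "run_smlt.sh", str(SERVER_NUM), "1", "127.0.0.1", str(FL_SERVER_PORT)])

# 3. Tear the cluster down once the run is finished.
input("Press Enter to stop the scheduler/server processes...")
subprocess.call(["python", "finish_mobile.py", "--scheduler_port=" + str(SCHEDULER_PORT)])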