forked from mindspore-Ecosystem/mindspore
!16035 Add Server part 3
From: @zpac Reviewed-by: @cristoval Signed-off-by:
commit dcec57955c

@@ -47,6 +47,7 @@
#include "ps/parameter_server.h"
#include "ps/scheduler.h"
#include "ps/worker.h"
#include "ps/server/server.h"
#endif

namespace mindspore {

@@ -619,6 +620,47 @@ bool StartPSServerAction(const ResourcePtr &res) {
  return true;
}

bool StartServerAction(const ResourcePtr &res) {
  FuncGraphPtr func_graph = res->func_graph();
  const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
  size_t worker_num = ps::PSContext::instance()->initial_worker_num();
  size_t server_num = ps::PSContext::instance()->initial_server_num();
  uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();

  // The update model threshold is a certain ratio of the start_fl_job threshold:
  // update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_.
  size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
  float percent_for_update_model = 1;
  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));
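  // (Editorial worked example, values not from this commit: with start_fl_job_threshold = 100
  // and percent_for_update_model = 0.8, the server would wait for ceil(100 * 0.8) = 80
  // updateModel requests; with the ratio fixed at 1 above, both thresholds are equal.)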

  std::vector<ps::server::RoundConfig> rounds_config = {
    {"startFLJob", false, 3000, false, start_fl_job_threshold},
    {"updateModel", false, 3000, false, update_model_threshold},
    {"getModel", false, 3000},
    {"asyncUpdateModel"},
    {"asyncGetModel"},
    {"push", false, 3000, true, worker_num},
    {"pull", false, 3000, true, worker_num},
    {"getWeightsByKey", false, 3000, true, 1},
    {"overwriteWeightsByKey", false, 3000, true, server_num},
  };

  size_t executor_threshold = 0;
  if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
    executor_threshold = update_model_threshold;
    ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, func_graph,
                                                 executor_threshold);
  } else if (server_mode_ == ps::kServerModePS) {
    executor_threshold = worker_num;
    ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, func_graph, executor_threshold);
  } else {
    MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
    return false;
  }
  ps::server::Server::GetInstance().Run();
  return true;
}

bool StartPSSchedulerAction(const ResourcePtr &res) {
  ps::Scheduler::GetInstance().Run();
  return true;

@@ -797,6 +839,14 @@ std::vector<ActionItem> VmPipeline() {
}

#if (ENABLE_CPU && !_WIN32)
std::vector<ActionItem> ServerPipeline() {
  auto actions = CommonPipeline();
  actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
  actions.emplace_back(std::make_pair("validate", ValidateAction));
  actions.emplace_back(std::make_pair("server", StartServerAction));
  return actions;
}

std::vector<ActionItem> PServerPipeline() {
  auto actions = CommonPipeline();
  actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));

@@ -43,10 +43,14 @@ bool ExecuteAction(const ResourcePtr &res);
bool StartPSWorkerAction(const ResourcePtr &res);
bool StartPSServerAction(const ResourcePtr &res);
bool StartPSSchedulerAction(const ResourcePtr &res);
// This action is for federated learning only. In a later version, parameter server mode and federated learning will
// use the same action.
bool StartServerAction(const ResourcePtr &res);

std::vector<ActionItem> GePipeline();
std::vector<ActionItem> VmPipeline();
std::vector<ActionItem> PServerPipeline();
std::vector<ActionItem> ServerPipeline();
std::vector<ActionItem> PSchedulerPipeline();
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                                         const abstract::AbstractBasePtrList &args_spec, bool clear = false);

@@ -326,7 +326,24 @@ PYBIND11_MODULE(_c_expression, m) {
    .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
    .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
    .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
-   .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.");
+   .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.")
+   .def("set_server_mode", &PSContext::set_server_mode, "Set server mode.")
+   .def("server_mode", &PSContext::server_mode, "Get server mode.")
+   .def("set_ms_role", &PSContext::set_ms_role, "Set role for this process.")
+   .def("ms_role", &PSContext::ms_role, "Get role for this process.")
+   .def("set_worker_num", &PSContext::set_worker_num, "Set worker number.")
+   .def("set_server_num", &PSContext::set_server_num, "Set server number.")
+   .def("set_scheduler_ip", &PSContext::set_scheduler_ip, "Set scheduler ip.")
+   .def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
+   .def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
+   .def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
+   .def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.")
+   .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
+   .def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
+   .def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
+   .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
+   .def("set_secure_aggregation", &PSContext::set_secure_aggregation,
+        "Set federated learning client using secure aggregation.");

  (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
    .def(py::init())

@@ -55,6 +55,7 @@
#include "ps/worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/server/server.h"
#endif

#if (ENABLE_GE || ENABLE_D)

@@ -529,6 +530,11 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
  std::string backend = MsContext::GetInstance()->backend_policy();

#if (ENABLE_CPU && !_WIN32)
  const std::string &server_mode = ps::PSContext::instance()->server_mode();
  if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
      ps::PSContext::instance()->is_server()) {
    return ServerPipeline();
  }
  if (ps::PSContext::instance()->is_server()) {
    resource->results()[kBackend] = compile::CreateBackend();
    return PServerPipeline();

@@ -50,10 +50,13 @@ if(NOT ENABLE_CPU OR WIN32)
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/fed_avg_kernel.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/update_model_kernel.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/get_model_kernel.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")

@@ -67,6 +70,7 @@ if(NOT ENABLE_CPU OR WIN32)
    list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "server/server.cc")
endif()

list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")

@@ -61,7 +61,12 @@ void PSContext::SetPSEnable(bool enabled) {
  }
}

-bool PSContext::is_ps_mode() const { return ps_enabled_; }
+bool PSContext::is_ps_mode() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return true;
+  }
+  return ps_enabled_;
+}

void PSContext::Reset() {
  ps_enabled_ = false;

@@ -77,6 +82,9 @@ void PSContext::Reset() {
}

std::string PSContext::ms_role() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_;
+  }
  if (is_worker_) {
    return kEnvRoleOfWorker;
  } else if (is_pserver_) {

@@ -88,11 +96,26 @@ std::string PSContext::ms_role() const {
  }
}

-bool PSContext::is_worker() const { return is_worker_; }
+bool PSContext::is_worker() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kRoleOfWorker;
+  }
+  return is_worker_;
+}

-bool PSContext::is_server() const { return is_pserver_; }
+bool PSContext::is_server() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kEnvRoleOfServer;
+  }
+  return is_pserver_;
+}

-bool PSContext::is_scheduler() const { return is_sched_; }
+bool PSContext::is_scheduler() const {
+  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
+    return role_ == kEnvRoleOfScheduler;
+  }
+  return is_sched_;
+}
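// (Editorial note, not part of the commit: in FL and hybrid modes the role_ string set through
// set_ms_role() decides is_worker()/is_server()/is_scheduler(), while the legacy
// is_worker_/is_pserver_/is_sched_ flags only apply in parameter server mode.)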

uint32_t PSContext::initial_worker_num() { return worker_num_; }

@@ -150,6 +173,94 @@ void PSContext::set_rank_id(int rank_id) const {
#endif
}

void PSContext::set_server_mode(const std::string &server_mode) {
  if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
    MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
                      << " or " << kServerModeHybrid;
    return;
  }
  server_mode_ = server_mode;
}

const std::string &PSContext::server_mode() const { return server_mode_; }

void PSContext::set_ms_role(const std::string &role) {
  if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
    MS_LOG(EXCEPTION) << "Only federated learning supports setting the role through the PS context.";
    return;
  }
  if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
    MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
    return;
  }
  role_ = role;
}

void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
uint32_t PSContext::worker_num() const { return worker_num_; }

void PSContext::set_server_num(uint32_t server_num) {
  if (server_num == 0) {
    MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
    return;
  }
  server_num_ = server_num;
}
uint32_t PSContext::server_num() const { return server_num_; }

void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }

std::string PSContext::scheduler_ip() const { return scheduler_host_; }

void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }

uint16_t PSContext::scheduler_port() const { return scheduler_port_; }

void PSContext::GenerateResetterRound() {
  uint32_t binary_server_context = 0;
  bool is_parameter_server_mode = false;
  bool is_federated_learning_mode = false;
  bool is_mixed_training_mode = false;

  if (server_mode_ == kServerModePS) {
    is_parameter_server_mode = true;
  } else if (server_mode_ == kServerModeFL) {
    is_federated_learning_mode = true;
  } else if (server_mode_ == kServerModeHybrid) {
    is_mixed_training_mode = true;
  } else {
    MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
                      << " or " << kServerModeHybrid;
    return;
  }

  binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
                          (is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4);
  if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
    resetter_round_ = ResetterRound::kNoNeedToReset;
  } else {
    resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
  }
  MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
  return;
}

ResetterRound PSContext::resetter_round() const { return resetter_round_; }

void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }

uint16_t PSContext::fl_server_port() const { return fl_server_port_; }

void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }

bool PSContext::fl_client_enable() { return fl_client_enable_; }

void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) {
  start_fl_job_threshold_ = start_fl_job_threshold;
}

size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }

void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }

const std::string &PSContext::fl_name() const { return fl_name_; }

@@ -165,5 +276,15 @@ uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }

uint64_t PSContext::client_batch_size() const { return client_batch_size_; }

void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) {
  worker_overwrite_weights_ = worker_overwrite_weights;
}

uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; }

void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }

bool PSContext::secure_aggregation() const { return secure_aggregation_; }
}  // namespace ps
}  // namespace mindspore

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_
#define MINDSPORE_CCSRC_PS_CONTEXT_H_

#include <map>
#include <string>
#include <memory>
#include "ps/constants.h"

@@ -24,12 +25,32 @@

namespace mindspore {
namespace ps {
constexpr char kServerModePS[] = "PARAMETER_SERVER";
constexpr char kServerModeFL[] = "FEDERATED_LEARNING";
constexpr char kServerModeHybrid[] = "HYBRID_TRAINING";
constexpr char kEnvRole[] = "MS_ROLE";
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfServer[] = "MS_SERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";

// Use binary data to represent the federated learning server's context so that we can judge which round resets the
// iteration. From right to left, each bit stands for:
// 0: Server is in parameter server mode.
// 1: Server is in federated learning mode.
// 2: Server is in mixed training mode.
// 3: Server enables secure aggregation.
// 4: Server needs workers to overwrite weights.
// For example, 0b01010 means the server is in federated learning mode with secure aggregation enabled.
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSecrets, kWorkerOverwriteWeights };
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {
  {0b00010, ResetterRound::kUpdateModel},
  {0b01010, ResetterRound::kReconstructSecrets},
  {0b11100, ResetterRound::kWorkerOverwriteWeights},
  {0b10100, ResetterRound::kWorkerOverwriteWeights},
  {0b00100, ResetterRound::kUpdateModel}};

class PSContext {
 public:
  ~PSContext() = default;

@@ -60,19 +81,64 @@ class PSContext {
  void set_cache_enable(bool cache_enable) const;
  void set_rank_id(int rank_id) const;

  // Setters and getters for federated learning.
  // In the new server framework, the process role, worker number, server number, scheduler ip and scheduler port
  // should be set through ps_context.
  void set_server_mode(const std::string &server_mode);
  const std::string &server_mode() const;

  void set_ms_role(const std::string &role);

  void set_worker_num(uint32_t worker_num);
  uint32_t worker_num() const;

  void set_server_num(uint32_t server_num);
  uint32_t server_num() const;

  void set_scheduler_ip(const std::string &sched_ip);
  std::string scheduler_ip() const;

  void set_scheduler_port(uint16_t sched_port);
  uint16_t scheduler_port() const;

  // Methods for federated learning.

  // Determine which round should reset the iteration.
  void GenerateResetterRound();
  ResetterRound resetter_round() const;

  void set_fl_server_port(uint16_t fl_server_port);
  uint16_t fl_server_port() const;

  // Set true if this process is a federated learning worker in the cross-silo scenario.
  void set_fl_client_enable(bool enabled);
  bool fl_client_enable();

  void set_start_fl_job_threshold(size_t start_fl_job_threshold);
  size_t start_fl_job_threshold() const;

  void set_fl_name(const std::string &fl_name);
  const std::string &fl_name() const;

  // Set the iteration number of the federated learning.
  void set_fl_iteration_num(uint64_t fl_iteration_num);
  uint64_t fl_iteration_num() const;

  // Set the training epoch number of the client.
  void set_client_epoch_num(uint64_t client_epoch_num);
  uint64_t client_epoch_num() const;

  // Set the data batch size of the client.
  void set_client_batch_size(uint64_t client_batch_size);
  uint64_t client_batch_size() const;

  // Set true if workers will overwrite weights on the server. Used in hybrid training.
  void set_worker_overwrite_weights(uint64_t worker_overwrite_weights);
  uint64_t worker_overwrite_weights() const;

  // Set true if secure aggregation is used for federated learning.
  void set_secure_aggregation(bool secure_aggregation);
  bool secure_aggregation() const;

 private:
  PSContext()
      : ps_enabled_(false),

@@ -94,11 +160,22 @@ class PSContext {
  std::string scheduler_host_;
  uint16_t scheduler_port_;

  std::string role_;

  // Members for federated learning.
  std::string server_mode_;
  ResetterRound resetter_round_;
  uint16_t fl_server_port_;
  bool fl_client_enable_;
  std::string fl_name_;
  size_t start_fl_job_threshold_;
  uint64_t fl_iteration_num_;
  uint64_t client_epoch_num_;
  uint64_t client_batch_size_;
  bool worker_overwrite_weights_;

  // Federated learning security.
  bool secure_aggregation_;
};
}  // namespace ps
}  // namespace mindspore

@@ -20,6 +20,9 @@ namespace mindspore {
namespace ps {
void Scheduler::Run() {
  MS_LOG(INFO) << "Start scheduler.";
  core::ClusterMetadata::instance()->Init(
    PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
    PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
  scheduler_node_.Start();
  scheduler_node_.Finish();
  scheduler_node_.Stop();

@@ -44,6 +44,14 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
enum CommType { HTTP = 0, TCP };
enum AggregationType { FedAvg = 0, FedAdam, FedAdagrad, FedMeta, qffl, DenseGradAccum, SparseGradAccum };

struct RoundConfig {
  std::string name;
  bool check_timeout = false;
  size_t time_window = 3000;
  bool check_count = false;
  size_t threshold_count = 0;
};
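// (Editorial example: StartServerAction aggregate-initializes these in declared field order,
// e.g. {"push", false, 3000, true, worker_num} names the round "push", keeps timeout checking
// off with a 3000 ms time window, and enables counting with worker_num as the threshold;
// trailing fields may be omitted to keep the defaults, as in {"getModel", false, 3000}.)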

using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;

@@ -73,6 +81,7 @@ using ReuseKernelNodeInfo = std::map<std::string, size_t>;
using UploadData = std::map<std::string, Address>;

constexpr auto kWeight = "weight";
constexpr auto kNewWeight = "new_weight";
constexpr auto kAccumulation = "accum";
constexpr auto kLearningRate = "lr";
constexpr auto kGradient = "grad";

@@ -87,6 +96,8 @@ constexpr auto kAdamBeta1 = "beta1";
constexpr auto kAdamBeta2 = "beta2";
constexpr auto kAdamEps = "eps";
constexpr auto kFtrlLinear = "linear";
constexpr auto kDataSize = "data_size";
constexpr auto kNewDataSize = "new_data_size";

// OptimParamNameToIndex represents every input/workspace/output parameter's offset when an optimizer kernel is
// launched.

@@ -137,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32;
constexpr int kHttpSuccess = 200;
constexpr auto kPBProtocol = "PB";
constexpr auto kFBSProtocol = "FBS";
constexpr auto kFedAvg = "FedAvg";
constexpr auto kAggregationKernelType = "Aggregation";
constexpr auto kOptimizerKernelType = "Optimizer";
constexpr auto kCtxFuncGraph = "FuncGraph";

@@ -145,6 +157,8 @@ constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";

// This macro expands to the current timestamp in milliseconds.
#define CURRENT_TIME_MILLI \

@@ -112,19 +112,19 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
    std::unique_lock<std::mutex> lock(mutex_[name]);
    return global_current_count_[name].size() == global_threshold_count_[name];
  } else {
-   CountReachThresholdRequest count_reach_threashold_req;
-   count_reach_threashold_req.set_name(name);
+   CountReachThresholdRequest count_reach_threshold_req;
+   count_reach_threshold_req.set_name(name);

    std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
-   if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_,
+   if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
                                      core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
      MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
      return false;
    }

-   CountReachThresholdResponse count_reach_threashold_rsp;
-   count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
-   return count_reach_threashold_rsp.is_enough();
+   CountReachThresholdResponse count_reach_threshold_rsp;
+   count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
+   return count_reach_threshold_rsp.is_enough();
  }
}

@@ -200,9 +200,9 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
    return;
  }

- CountReachThresholdRequest count_reach_threashold_req;
- count_reach_threashold_req.ParseFromArray(message->data(), message->len());
- const std::string &name = count_reach_threashold_req.name();
+ CountReachThresholdRequest count_reach_threshold_req;
+ count_reach_threshold_req.ParseFromArray(message->data(), message->len());
+ const std::string &name = count_reach_threshold_req.name();

  std::unique_lock<std::mutex> lock(mutex_[name]);
  if (global_threshold_count_.count(name) == 0) {

@@ -210,10 +210,10 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
    return;
  }

- CountReachThresholdResponse count_reach_threashold_rsp;
- count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
- communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(),
-                             count_reach_threashold_rsp.SerializeAsString().size(), message);
+ CountReachThresholdResponse count_reach_threshold_rsp;
+ count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
+ communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(),
+                             count_reach_threshold_rsp.SerializeAsString().size(), message);
  return;
}

@@ -193,7 +193,29 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<co

bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
  std::unique_lock<std::mutex> lock(mutex_[name]);
  metadata_[name] = meta;
  if (meta.has_device_meta()) {
    auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
    auto &fl_id = meta.device_meta().fl_id();
    auto &device_meta = meta.device_meta();
    fl_id_to_meta_map[fl_id] = device_meta;
  } else if (meta.has_fl_id()) {
    auto client_list = metadata_[name].mutable_client_list();
    auto &fl_id = meta.fl_id().fl_id();
    // Check whether the new item already exists.
    bool add_flag = true;
    for (int i = 0; i < client_list->fl_id_size(); i++) {
      if (fl_id == client_list->fl_id(i)) {
        add_flag = false;
        break;
      }
    }
    if (add_flag) {
      client_list->add_fl_id(fl_id);
    }
  } else if (meta.has_update_model_threshold()) {
    auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
    *update_model_threshold = meta.update_model_threshold();
  }
  return true;
}
}  // namespace server

@@ -23,7 +23,7 @@
namespace mindspore {
namespace ps {
namespace server {
-void Executor::Init(const FuncGraphPtr &func_graph, size_t aggregation_count) {
+void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (aggregation_count == 0) {
    MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";

@@ -43,7 +43,7 @@ class Executor {
  // be used for aggregators.
  // As noted in header file parameter_aggregator.h, we create aggregators from trainable parameters, which are the
  // optimizer cnode's inputs. So we need to initialize the server executor using func_graph.
- void Init(const FuncGraphPtr &func_graph, size_t aggregation_count);
+ void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);

  // Called in parameter server training mode to do the Push operation.
  // For the same trainable parameter, the HandlePush method must be called aggregation_count_ times before it's considered

@@ -0,0 +1,33 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/server/kernel/fed_avg_kernel.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
REG_AGGREGATION_KERNEL_TWO(FedAvg,
                           ParamsInfo()
                             .AddInputNameType(kWeight, kNumberTypeFloat32)
                             .AddInputNameType(kDataSize, kNumberTypeUInt64)
                             .AddInputNameType(kNewWeight, kNumberTypeFloat32)
                             .AddInputNameType(kNewDataSize, kNumberTypeUInt64),
                           FedAvgKernel, float, size_t)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

@@ -0,0 +1,179 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <functional>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "ps/server/common.h"
#include "ps/server/collective_ops_impl.h"
#include "ps/server/distributed_count_service.h"
#include "ps/server/local_meta_store.h"
#include "ps/server/kernel/aggregation_kernel.h"
#include "ps/server/kernel/aggregation_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
// The implementation of federated average. We do a weighted average over the weights. The weights uploaded by
// FL clients are already multiplied by their data sizes, so only summation and division are done in this kernel.

// Note that this kernel is the distributed version of federated average: each server node in the cluster is
// involved in the aggregation process, which is why DistributedCountService and CollectiveOpsImpl are called.
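// (Editorial numeric illustration, values not from this commit: client A holds 60 samples and
// uploads 60 * w_A, client B holds 40 and uploads 40 * w_B. After AllReduce sums the uploads
// and the data sizes across servers, the kernel divides by 100, yielding the weighted average
// (60 * w_A + 40 * w_B) / 100.)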
template <typename T, typename S>
class FedAvgKernel : public AggregationKernel {
 public:
  FedAvgKernel() : participated_(false) {}
  ~FedAvgKernel() override = default;
  void InitKernel(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
    if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
        kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
      MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
      return;
    }
    cnode_weight_idx_ = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
    std::vector<size_t> weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_weight_idx_);
    size_t weight_size =
      std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(T), std::multiplies<size_t>());
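    // (Editorial note: std::accumulate seeded with sizeof(T) computes
    // sizeof(T) * product(weight_shape), i.e. the tensor size in bytes.)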
    size_t new_weight_size = weight_size;

    input_size_list_.push_back(weight_size);
    input_size_list_.push_back(sizeof(size_t));
    input_size_list_.push_back(new_weight_size);
    input_size_list_.push_back(sizeof(size_t));

    auto weight_node =
      AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
    MS_EXCEPTION_IF_NULL(weight_node);
    name_ = cnode_name + "." + weight_node->fullname_with_scope();
    MS_LOG(INFO) << "Register counter for " << name_;
    auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
      std::unique_lock<std::mutex> lock(weight_mutex_);
      if (!participated_) {
        ClearWeightAndDataSize();
      }
    };
    auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
      T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
      size_t weight_size = weight_addr_->size;
      S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
      if (!CollectiveOpsImpl::GetInstance().AllReduce<T>(weight_addr, weight_addr, weight_size / sizeof(T))) {
        MS_LOG(ERROR) << "Federated average allreduce failed.";
        return;
      }
      if (!CollectiveOpsImpl::GetInstance().AllReduce<S>(data_size_addr, data_size_addr, 1)) {
        MS_LOG(ERROR) << "Federated average allreduce failed.";
        return;
      }
      LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size_addr[0]);
      for (size_t i = 0; i < weight_size / sizeof(T); i++) {
        weight_addr[i] /= data_size_addr[0];
      }
      done_ = true;
      DistributedCountService::GetInstance().ResetCounter(name_);
      return;
    };
    DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
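    // (Editorial note: first_cnt_handler zero-fills this server's buffers when it has received
    // no client update, so the AllReduce in last_cnt_handler still sums correctly across all
    // server nodes; last_cnt_handler then completes the weighted average once the count
    // threshold is reached.)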
    return;
  }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    std::unique_lock<std::mutex> lock(weight_mutex_);
    // The weight and new_weight values have already been multiplied by the clients, so we don't need to do the
    // multiplication again.
    T *weight_addr = reinterpret_cast<T *>(inputs[0]->addr);
    S *data_size_addr = reinterpret_cast<S *>(inputs[1]->addr);
    T *new_weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
    S *new_data_size_addr = reinterpret_cast<S *>(inputs[3]->addr);
    if (accum_count_ == 0) {
      ClearWeightAndDataSize();
    }

    MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for "
                  << name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is "
                  << data_size_addr[0];
    for (size_t i = 0; i < inputs[2]->size / sizeof(T); i++) {
      weight_addr[i] += new_weight_addr[i];
    }
    data_size_addr[0] += new_data_size_addr[0];
    lock.unlock();

    accum_count_++;
    participated_ = true;
    DistributedCountService::GetInstance().Count(
      name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
    return true;
  }
  void Reset() {
    accum_count_ = 0;
    done_ = false;
    participated_ = false;
    DistributedCountService::GetInstance().ResetCounter(name_);
    return;
  }
  bool IsAggregationDone() { return done_; }

 private:
  void GenerateReuseKernelNodeInfo() override {
    // Only the trainable parameter is reused for federated average.
    reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
    return;
  }

  // In some cases, the Launch method is not called and the weights involved in AllReduce should be set to 0.
  void ClearWeightAndDataSize() {
    int ret = memset_s(weight_addr_->addr, weight_addr_->size, 0x00, weight_addr_->size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memset_s error, errno(" << ret << ")";
      return;
    }
    ret = memset_s(data_size_addr_->addr, data_size_addr_->size, 0x00, data_size_addr_->size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memset_s error, errno(" << ret << ")";
      return;
    }
    return;
  }

  // The trainable parameter index of the kernel node, parsed from the frontend func_graph.
  size_t cnode_weight_idx_;

  // The address pointers of the inputs.
  AddressPtr weight_addr_;
  AddressPtr data_size_addr_;
  AddressPtr new_weight_addr_;
  AddressPtr new_data_size_addr_;

  // Whether the kernel's Launch method has been called.
  bool participated_;

  // The kernel can be called concurrently, so we need a lock to ensure thread safety.
  std::mutex weight_mutex_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_

@@ -0,0 +1,125 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/server/kernel/round/get_model_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/model_store.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void GetModelKernel::InitKernel(size_t) {
  if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
    iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
  }

  executor_ = &Executor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor_);
  if (!executor_->initialized()) {
    MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
    return;
  }
}

bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                            const std::vector<AddressPtr> &outputs) {
  MS_LOG(INFO) << "Launching GetModelKernel kernel.";
  void *req_data = inputs[0]->addr;
  std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
  if (fbb == nullptr || req_data == nullptr) {
    MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
    return false;
  }

  const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
  GetModel(get_model_req, fbb);
  GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
  return true;
}

bool GetModelKernel::Reset() {
  MS_LOG(INFO) << "Get model kernel reset!";
  StopTimer();
  return true;
}

void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) {
  std::map<std::string, AddressPtr> feature_maps;
  size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
  size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
  const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
  size_t latest_iter_num = iter_to_model.rbegin()->first;

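  // (Editorial note: the model for an iteration is stored only after that iteration finishes,
  // so a request for the still-running iteration or for the next one gets
  // ResponseCode_SucNotReady with a suggested retry time instead of an error.)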
  if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
    std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
    BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(WARNING) << reason;
    return;
  }

  if (iter_to_model.count(get_model_iter) == 0) {
    std::string reason = "The iteration of GetModel request " + std::to_string(get_model_iter) +
                         " is invalid. Current iteration is " + std::to_string(current_iter);
    BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps,
                     std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return;
  }

  feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
  BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED,
                   "Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter,
                   feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
  return;
}

void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                                      const std::string &reason, const size_t iter,
                                      const std::map<std::string, AddressPtr> &feature_maps,
                                      const std::string &timestamp) {
  auto fbs_reason = fbb->CreateString(reason);
  auto fbs_timestamp = fbb->CreateString(timestamp);
  std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
  for (const auto &feature_map : feature_maps) {
    auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
    auto fbs_weight_data =
      fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
    auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
    fbs_feature_maps.push_back(fbs_feature_map);
  }
  auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);

  schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
  rsp_get_model_builder.add_retcode(retcode);
  rsp_get_model_builder.add_reason(fbs_reason);
  rsp_get_model_builder.add_iteration(static_cast<int>(iter));
  rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
  rsp_get_model_builder.add_timestamp(fbs_timestamp);
  auto rsp_get_model = rsp_get_model_builder.Finish();
  fbb->Finish(rsp_get_model);
  return;
}

REG_ROUND_KERNEL(getModel, GetModelKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

@@ -0,0 +1,59 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class GetModelKernel : public RoundKernel {
 public:
  GetModelKernel() = default;
  ~GetModelKernel() override = default;

  void InitKernel(size_t) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs);
  bool Reset() override;

 private:
  void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb);
  void BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                        const std::string &reason, const size_t iter,
                        const std::map<std::string, AddressPtr> &feature_maps, const std::string &timestamp);

  // The executor is used to get the model for getModel requests.
  Executor *executor_;

  // The time window of one iteration.
  size_t iteration_time_window_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_

@@ -49,7 +49,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
    return false;
  }
  void *req_data = inputs[0]->addr;
- const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>();
+ std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
  if (fbb == nullptr || req_data == nullptr) {
    MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
    return false;

@@ -0,0 +1,203 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/kernel/round/update_model_kernel.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void UpdateModelKernel::InitKernel(size_t threshold_count) {
  if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
    iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
  }

  executor_ = &Executor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor_);
  if (!executor_->initialized()) {
    MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
    return;
  }

  PBMetadata client_list;
  DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
  LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
}

bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                               const std::vector<AddressPtr> &outputs) {
  if (inputs.size() != 1 || outputs.size() != 1) {
    MS_LOG(ERROR) << "inputs or outputs size is invalid.";
    return false;
  }
  void *req_data = inputs[0]->addr;
  std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
  if (fbb == nullptr || req_data == nullptr) {
    MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
    return false;
  }

  MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
  if (!ReachThresholdForUpdateModel(fbb)) {
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return false;
  }

  const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
  if (!UpdateModel(update_model_req, fbb)) {
    MS_LOG(ERROR) << "Updating model failed.";
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return false;
  }

  if (!CountForUpdateModel(fbb, update_model_req)) {
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return false;
  }
  GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
  return true;
}

bool UpdateModelKernel::Reset() {
  MS_LOG(INFO) << "Update model kernel reset!";
  StopTimer();
  DistributedCountService::GetInstance().ResetCounter(name_);
  executor_->ResetAggregationStatus();
  DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList);
  size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value<size_t>(kCtxFedAvgTotalDataSize);
  total_data_size = 0;
  return true;
}

void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
  if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) {
    while (!executor_->IsAllWeightAggregationDone()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(5));
    }

    size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
    MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
                 << total_data_size;

    FinishIterCb();
  }
}

bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
  if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
    std::string reason = "Current amount for updateModel is enough.";
    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return false;
  }
  return true;
}

bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
                                    const std::shared_ptr<FBBuilder> &fbb) {
  size_t iteration = static_cast<size_t>(update_model_req->iteration());
  if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
    std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
                         ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return false;
  }

  PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
  FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
  std::string update_model_fl_id = update_model_req->fl_id()->str();
  if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
    std::string reason = "devices_meta for " + update_model_fl_id + " is not set.";
    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return false;
  }

  size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
  auto feature_map = ParseFeatureMap(update_model_req);
  for (auto weight : feature_map) {
    weight.second[kNewDataSize].addr = &data_size;
    weight.second[kNewDataSize].size = sizeof(size_t);
    executor_->HandleModelUpdate(weight.first, weight.second);
  }

  FLId fl_id;
  fl_id.set_fl_id(update_model_fl_id);
  PBMetadata comm_value;
  *comm_value.mutable_fl_id() = fl_id;
  DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value);

  BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready",
                      std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
  return true;
}

std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
  const schema::RequestUpdateModel *update_model_req) {
  RETURN_IF_NULL(update_model_req, {});
  std::map<std::string, UploadData> feature_map;
  auto fbs_feature_map = update_model_req->feature_map();
  for (size_t i = 0; i < fbs_feature_map->size(); i++) {
    std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
    float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
    size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
    UploadData upload_data;
    upload_data[kNewWeight].addr = weight_data;
    upload_data[kNewWeight].size = weight_size;
    feature_map[weight_full_name] = upload_data;
  }
  return feature_map;
}

bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
                                            const schema::RequestUpdateModel *update_model_req) {
  if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
    std::string reason = "UpdateModel counting failed.";
    BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
                        std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
    MS_LOG(ERROR) << reason;
    return false;
  }
  return true;
}

void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                                            const std::string &reason, const std::string &next_req_time) {
  auto fbs_reason = fbb->CreateString(reason);
  auto fbs_next_req_time = fbb->CreateString(next_req_time);

  schema::ResponseUpdateModelBuilder rsp_update_model_builder(*(fbb.get()));
  rsp_update_model_builder.add_retcode(retcode);
  rsp_update_model_builder.add_reason(fbs_reason);
  rsp_update_model_builder.add_next_req_time(fbs_next_req_time);
  auto rsp_update_model = rsp_update_model_builder.Finish();
  fbb->Finish(rsp_update_model);
  return;
}

REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

@@ -0,0 +1,64 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "ps/server/executor.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class UpdateModelKernel : public RoundKernel {
 public:
  UpdateModelKernel() = default;
  ~UpdateModelKernel() override = default;

  void InitKernel(size_t threshold_count) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs);
  bool Reset() override;

  // In some cases, the last updateModel message means this server iteration is finished.
  void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;

 private:
  bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
  bool UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb);
  std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
  bool CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestUpdateModel *update_model_req);
  void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                           const std::string &reason, const std::string &next_req_time);

  // The executor is used to update the model for updateModel requests.
  Executor *executor_;

  // The time window of one iteration.
  size_t iteration_time_window_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
@ -23,7 +23,7 @@
namespace mindspore {
namespace ps {
namespace server {
void ModelStore::Init(uint32_t max_count) {
void ModelStore::Initialize(uint32_t max_count) {
  if (!Executor::GetInstance().initialized()) {
    MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage.";
    return;

@ -40,7 +40,7 @@ class ModelStore {
  }

  // Initialize ModelStore with the max count of models that need to be stored.
  void Init(uint32_t max_count = 3);
  void Initialize(uint32_t max_count = 3);

  // Store the model of the given iteration. The model is acquired from the Executor. If the current model count has
  // already reached max_model_count_, the earliest model will be replaced.
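The eviction rule described above is a simple bounded cache keyed by iteration number. A minimal Python sketch of that policy (illustrative only; the names are not the MindSpore API):

from collections import OrderedDict

class BoundedModelStore:
    """Keeps at most max_count models, keyed by iteration number."""
    def __init__(self, max_count=3):
        self.max_count = max_count
        self.models = OrderedDict()  # iteration -> model weights

    def store(self, iteration, model):
        if len(self.models) >= self.max_count:
            self.models.popitem(last=False)  # evict the earliest iteration
        self.models[iteration] = model

    def get(self, iteration):
        return self.models.get(iteration)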
@ -302,7 +302,7 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
}

std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) {
  std::vector<std::string> aggregation_algorithm = {};
  std::vector<std::string> aggregation_algorithm = {kFedAvg};
  MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
  return aggregation_algorithm;
}
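kFedAvg denotes data_size-weighted federated averaging: each uploaded weight contributes in proportion to the client's data_size, the same value the updateModel kernel attaches under kNewDataSize. An illustrative numpy sketch of the aggregation rule (not the server's C++ kernel):

import numpy as np

def fed_avg(client_updates):
    """client_updates: list of (feature_map: dict[str, np.ndarray], data_size: int).
    Returns the data_size-weighted average of each named weight."""
    total = sum(data_size for _, data_size in client_updates)
    aggregated = {}
    for feature_map, data_size in client_updates:
        for name, weight in feature_map.items():
            contrib = weight * (data_size / total)
            aggregated[name] = aggregated.get(name, 0.0) + contrib
    return aggregated

# Two clients with unequal data sizes:
w1 = {"fc.weight": np.ones(4, np.float32)}
w2 = {"fc.weight": np.zeros(4, np.float32)}
print(fed_avg([(w1, 30), (w2, 10)])["fc.weight"])  # -> [0.75 0.75 0.75 0.75]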
@ -0,0 +1,251 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/server/server.h"
#include <memory>
#include <string>
#include <csignal>
#include "ps/server/round.h"
#include "ps/server/model_store.h"
#include "ps/server/iteration.h"
#include "ps/server/collective_ops_impl.h"
#include "ps/server/distributed_metadata_store.h"
#include "ps/server/distributed_count_service.h"
#include "ps/server/kernel/round/round_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace server {
static std::vector<std::shared_ptr<core::CommunicatorBase>> global_worker_server_comms = {};
// This function handles the exit of the server process when an interrupt signal is captured.
void SignalHandler(int signal) {
  MS_LOG(INFO) << "Interrupt signal captured: " << signal;
  std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(),
                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
  return;
}

void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
                        const FuncGraphPtr &func_graph, size_t executor_threshold) {
  MS_EXCEPTION_IF_NULL(func_graph);
  func_graph_ = func_graph;

  if (rounds_config.empty()) {
    MS_LOG(EXCEPTION) << "Rounds are empty.";
    return;
  }
  rounds_config_ = rounds_config;

  use_tcp_ = use_tcp;
  use_http_ = use_http;
  http_port_ = http_port;
  executor_threshold_ = executor_threshold;
  return;
}

// Each step of the server pipeline may depend on other steps:

// InitServerContext must be the first step so that contexts are set for later steps.

// Server running relies on URL or message type registration:
// StartCommunicator---->InitIteration

// Metadata registration relies on the hash ring of servers, which relies on network building completion:
// RegisterRoundKernel---->StartCommunicator

// Kernel initialization relies on executor initialization:
// RegisterRoundKernel---->InitExecutor

// Getting the model size relies on ModelStore initialization, which relies on executor initialization:
// InitCipher---->InitExecutor
void Server::Run() {
  signal(SIGINT, SignalHandler);
  InitServerContext();
  InitCluster();
  InitIteration();
  StartCommunicator();
  InitExecutor();
  RegisterRoundKernel();
  MS_LOG(INFO) << "Server started successfully.";

  // Wait for the communicators to stop so that the main thread is blocked.
  std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); });
  communicator_with_server_->Join();
  MsException::Instance().CheckException();
  return;
}

void Server::InitServerContext() {
  PSContext::instance()->GenerateResetterRound();
  scheduler_ip_ = PSContext::instance()->scheduler_host();
  scheduler_port_ = PSContext::instance()->scheduler_port();
  worker_num_ = PSContext::instance()->initial_worker_num();
  server_num_ = PSContext::instance()->initial_server_num();
  return;
}

void Server::InitCluster() {
  server_node_ = std::make_shared<core::ServerNode>();
  MS_EXCEPTION_IF_NULL(server_node_);
  task_executor_ = std::make_shared<core::TaskExecutor>(32);
  MS_EXCEPTION_IF_NULL(task_executor_);
  if (!InitCommunicatorWithServer()) {
    MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
    return;
  }
  if (!InitCommunicatorWithWorker()) {
    MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
    return;
  }
  global_worker_server_comms = communicators_with_worker_;
  return;
}

bool Server::InitCommunicatorWithServer() {
  MS_EXCEPTION_IF_NULL(task_executor_);
  MS_EXCEPTION_IF_NULL(server_node_);

  communicator_with_server_ =
    server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_);
  MS_EXCEPTION_IF_NULL(communicator_with_server_);

  // Set exception event callbacks for the server.
  auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
  MS_EXCEPTION_IF_NULL(tcp_comm);

  tcp_comm->RegisterEventCallback(core::CLUSTER_TIMEOUT, [&]() {
    MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes (Scheduler/Server/Worker) were "
                     "not started during the network building phase.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });

  tcp_comm->RegisterEventCallback(core::SCHEDULER_TIMEOUT, [&]() {
    MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because the scheduler node is finalized or crashed.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });

  tcp_comm->RegisterEventCallback(core::NODE_TIMEOUT, [&]() {
    MS_LOG(ERROR)
      << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
         "network building phase.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });
  return true;
}

bool Server::InitCommunicatorWithWorker() {
  MS_EXCEPTION_IF_NULL(server_node_);
  MS_EXCEPTION_IF_NULL(task_executor_);
  if (!use_tcp_ && !use_http_) {
    MS_LOG(EXCEPTION) << "At least one type of protocol should be set.";
    return false;
  }
  if (use_tcp_) {
    auto tcp_comm = communicator_with_server_;
    MS_EXCEPTION_IF_NULL(tcp_comm);
    communicators_with_worker_.push_back(tcp_comm);
  }
  if (use_http_) {
    auto http_comm = server_node_->GetOrCreateHttpComm("0.0.0.0", http_port_, task_executor_);
    MS_EXCEPTION_IF_NULL(http_comm);
    communicators_with_worker_.push_back(http_comm);
  }
  return true;
}

void Server::InitIteration() {
  iteration_ = std::make_shared<Iteration>();
  MS_EXCEPTION_IF_NULL(iteration_);

  // 1. Add rounds to the iteration according to the server mode.
  for (const RoundConfig &config : rounds_config_) {
    std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window,
                                                           config.check_count, config.threshold_count);
    MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count
                 << ", threshold: " << config.threshold_count;
    iteration_->AddRound(round);
  }

  // 2. Initialize all the rounds.
  TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
  FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
  iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
  return;
}

void Server::InitExecutor() {
  if (executor_threshold_ == 0) {
    MS_LOG(EXCEPTION) << "The executor's threshold should be greater than 0.";
    return;
  }
  // The train engine instance is used in both push-type and pull-type kernels,
  // so the required_cnt of these kernels must be the same as update_model_threshold_.
  MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
  Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
  ModelStore::GetInstance().Initialize();
  return;
}

void Server::RegisterRoundKernel() {
  MS_EXCEPTION_IF_NULL(iteration_);
  auto &rounds = iteration_->rounds();
  if (rounds.empty()) {
    MS_LOG(EXCEPTION) << "Server has no round registered.";
    return;
  }

  for (auto &round : rounds) {
    const std::string &name = round->name();
    std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
    if (round_kernel == nullptr) {
      MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered.";
      return;
    }

    // For some round kernels, the threshold count should be set.
    round_kernel->InitKernel(round->threshold_count());
    round->BindRoundKernel(round_kernel);
  }
  return;
}

void Server::StartCommunicator() {
  MS_EXCEPTION_IF_NULL(communicator_with_server_);
  if (communicators_with_worker_.empty()) {
    MS_LOG(EXCEPTION) << "Communicators for communication with workers are empty.";
    return;
  }

  MS_LOG(INFO) << "Start communicator with server.";
  communicator_with_server_->Start();
  DistributedMetadataStore::GetInstance().Initialize(server_node_);
  CollectiveOpsImpl::GetInstance().Initialize(server_node_);
  DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);

  MS_LOG(INFO) << "Start communicator with worker.";
  std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); });
}
}  // namespace server
}  // namespace ps
}  // namespace mindspore
@ -0,0 +1,131 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
#define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_

#include <memory>
#include <string>
#include <vector>
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/iteration.h"

namespace mindspore {
namespace ps {
namespace server {
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
class Server {
 public:
  static Server &GetInstance() {
    static Server instance;
    return instance;
  }

  void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
                  const FuncGraphPtr &func_graph, size_t executor_threshold);

  // According to the current MindSpore framework, method Run is a step of the server pipeline. This method blocks
  // until the server is finalized.
  // func_graph is the frontend graph which will be parsed in the server's executor and aggregator.
  void Run();

 private:
  Server()
      : server_node_(nullptr),
        task_executor_(nullptr),
        use_tcp_(false),
        use_http_(false),
        http_port_(0),
        func_graph_(nullptr),
        executor_threshold_(0),
        communicator_with_server_(nullptr),
        communicators_with_worker_({}),
        iteration_(nullptr),
        scheduler_ip_(""),
        scheduler_port_(0),
        server_num_(0),
        worker_num_(0) {}
  ~Server() = default;
  Server(const Server &) = delete;
  Server &operator=(const Server &) = delete;

  // Load variables which are set by ps_context.
  void InitServerContext();

  // Initialize the server cluster, server node and communicators.
  void InitCluster();
  bool InitCommunicatorWithServer();
  bool InitCommunicatorWithWorker();

  // Initialize the iteration with rounds. Which rounds to use can also be set by ps_context.
  void InitIteration();

  // Initialize the executor according to the server mode.
  void InitExecutor();

  // Create round kernels and bind these kernels to the corresponding Rounds.
  void RegisterRoundKernel();

  // The communicators should be started after all initializations are completed.
  void StartCommunicator();

  // The server node is initialized in Server.
  std::shared_ptr<core::ServerNode> server_node_;

  // The task executor of the communicators. This helps the server handle network messages concurrently. The tasks
  // submitted to this task executor are asynchronous.
  std::shared_ptr<core::TaskExecutor> task_executor_;

  // Which protocols the communicators use.
  bool use_tcp_;
  bool use_http_;
  uint64_t http_port_;

  // The configuration of all rounds.
  std::vector<RoundConfig> rounds_config_;

  // The graph passed by the frontend without backend optimizing.
  FuncGraphPtr func_graph_;

  // The threshold count for the executor to do aggregation or optimizing.
  size_t executor_threshold_;

  // The server needs a TCP communicator to communicate with other servers for counting, metadata storing, collective
  // operations, etc.
  std::shared_ptr<core::CommunicatorBase> communicator_with_server_;

  // The communication with workers (including mobile devices) has multiple protocol types: HTTP and TCP.
  // In some cases, both types should be supported in one distributed training job. So here we may have multiple
  // communicators.
  std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;

  // An iteration consists of multiple kinds of rounds.
  std::shared_ptr<Iteration> iteration_;

  // Variables set by ps context.
  std::string scheduler_ip_;
  uint16_t scheduler_port_;
  uint32_t server_num_;
  uint32_t worker_num_;
};
}  // namespace server
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
@ -33,7 +33,21 @@ def ps_context():
    return _ps_context

_set_ps_context_func_map = {
    "enable_ps": ps_context().set_ps_enable
    "server_mode": ps_context().set_server_mode,
    "ms_role": ps_context().set_ms_role,
    "enable_ps": ps_context().set_ps_enable,
    "worker_num": ps_context().set_worker_num,
    "server_num": ps_context().set_server_num,
    "scheduler_ip": ps_context().set_scheduler_ip,
    "scheduler_port": ps_context().set_scheduler_port,
    "fl_server_port": ps_context().set_fl_server_port,
    "enable_fl_client": ps_context().set_fl_client_enable,
    "start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
    "fl_name": ps_context().set_fl_name,
    "fl_iteration_num": ps_context().set_fl_iteration_num,
    "client_epoch_num": ps_context().set_client_epoch_num,
    "client_batch_size": ps_context().set_client_batch_size,
    "secure_aggregation": ps_context().set_secure_aggregation
}

_get_ps_context_func_map = {
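With the expanded map, the federated-learning options become reachable from Python through set_ps_context. A minimal usage sketch, assuming this change's new keys are in place (the values mirror the run scripts below):

from mindspore import context

# Configure a federated-learning server process; each key maps 1:1 onto a setter above.
context.set_ps_context(enable_ps=True,
                       server_mode="FEDERATED_LEARNING",
                       ms_role="MS_SERVER",
                       worker_num=0,
                       server_num=2,
                       scheduler_ip="127.0.0.1",
                       scheduler_port=8113,
                       fl_server_port=6666,
                       start_fl_job_threshold=1,
                       fl_name="Lenet",
                       fl_iteration_num=25,
                       client_epoch_num=20,
                       client_batch_size=32,
                       secure_aggregation=False)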
@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import subprocess

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Finish test_mobile_lenet.py case")
    parser.add_argument("--scheduler_port", type=int, default=8113)

    args, _ = parser.parse_known_args()
    scheduler_port = args.scheduler_port

    cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" "
    cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && "
    cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done"

    subprocess.call(['bash', '-c', cmd])
@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import subprocess

parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)

if __name__ == "__main__":
    args, _ = parser.parse_known_args()
    device_target = args.device_target
    server_mode = args.server_mode
    worker_num = args.worker_num
    server_num = args.server_num
    scheduler_ip = args.scheduler_ip
    scheduler_port = args.scheduler_port
    fl_server_port = args.fl_server_port

    cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
    cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
    cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
    cmd_sched += "python ${self_path}/../test_mobile_lenet.py"
    cmd_sched += " --device_target=" + device_target
    cmd_sched += " --server_mode=" + server_mode
    cmd_sched += " --ms_role=MS_SCHED"
    cmd_sched += " --worker_num=" + str(worker_num)
    cmd_sched += " --server_num=" + str(server_num)
    cmd_sched += " --scheduler_ip=" + scheduler_ip
    cmd_sched += " --scheduler_port=" + str(scheduler_port)
    cmd_sched += " --fl_server_port=" + str(fl_server_port)
    cmd_sched += " > scheduler.log 2>&1 &"

    subprocess.call(['bash', '-c', cmd_sched])
@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import ast
import argparse
import subprocess
import time

parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
parser.add_argument("--local_server_num", type=int, default=-1)

if __name__ == "__main__":
    args, _ = parser.parse_known_args()
    device_target = args.device_target
    server_mode = args.server_mode
    worker_num = args.worker_num
    server_num = args.server_num
    scheduler_ip = args.scheduler_ip
    scheduler_port = args.scheduler_port
    fl_server_port = args.fl_server_port
    start_fl_job_threshold = args.start_fl_job_threshold
    fl_name = args.fl_name
    fl_iteration_num = args.fl_iteration_num
    client_epoch_num = args.client_epoch_num
    client_batch_size = args.client_batch_size
    secure_aggregation = args.secure_aggregation
    local_server_num = args.local_server_num

    if local_server_num == -1:
        local_server_num = server_num

    assert local_server_num <= server_num, "The local server number should not be bigger than the total server number."

    for i in range(local_server_num):
        cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
        cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&"
        cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&"
        cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&"
        cmd_server += "python ${self_path}/../test_mobile_lenet.py"
        cmd_server += " --device_target=" + device_target
        cmd_server += " --server_mode=" + server_mode
        cmd_server += " --ms_role=MS_SERVER"
        cmd_server += " --worker_num=" + str(worker_num)
        cmd_server += " --server_num=" + str(server_num)
        cmd_server += " --scheduler_ip=" + scheduler_ip
        cmd_server += " --scheduler_port=" + str(scheduler_port)
        cmd_server += " --fl_server_port=" + str(fl_server_port + i)
        cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
        cmd_server += " --fl_name=" + fl_name
        cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
        cmd_server += " --client_epoch_num=" + str(client_epoch_num)
        cmd_server += " --client_batch_size=" + str(client_batch_size)
        cmd_server += " --secure_aggregation=" + str(secure_aggregation)
        cmd_server += " > server.log 2>&1 &"

        # Stagger server start-up slightly so each process gets its own directory and port.
        time.sleep(0.3)
        subprocess.call(['bash', '-c', cmd_server])
@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export PYTHONPATH=../../../../:$PYTHONPATH
server_num=$1
worker_num=$2
ip=$3
port=$4

for((i=0;i<worker_num;i++));
do
    ofs=`expr $i % $server_num`
    real_port=`expr $port + $ofs`
    echo $real_port
    python simulator.py --pid=$i --http_ip=$ip --http_port=$port --use_elb=True --server_num=$1 > simulator_$i.log 2>&1 &
done
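A typical invocation of this launcher, assuming it is saved as run_simulators.sh (the actual filename is not visible in this diff), starts 4 simulated clients spread across 2 servers: bash run_simulators.sh 2 4 127.0.0.1 6666. Each simulator writes its output to simulator_<pid>.log in the current directory; with --use_elb=True the simulator itself spreads requests across the server ports starting at the given base port.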
@ -0,0 +1,191 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import ast
import argparse
import time
import random
import sys
import requests
import flatbuffers
import numpy as np
from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
                              RequestUpdateModel, FeatureMap, RequestGetModel, ResponseGetModel)

parser = argparse.ArgumentParser()
parser.add_argument("--pid", type=int, default=0)
parser.add_argument("--http_ip", type=str, default="10.113.216.106")
parser.add_argument("--http_port", type=int, default=6666)
# ast.literal_eval instead of bool: argparse's type=bool treats any non-empty string,
# including "False", as True.
parser.add_argument("--use_elb", type=ast.literal_eval, default=False)
parser.add_argument("--server_num", type=int, default=1)

args, _ = parser.parse_known_args()
pid = args.pid
http_ip = args.http_ip
http_port = args.http_port
use_elb = args.use_elb
server_num = args.server_num

str_fl_id = 'fl_lenet_' + str(pid)


def generate_port():
    if not use_elb:
        return http_port
    # Simulate an elastic load balancer by picking a random server port.
    port = random.randint(0, 100000) % server_num + http_port
    return port


def build_start_fl_job(iteration):
    start_fl_job_builder = flatbuffers.Builder(1024)

    fl_name = start_fl_job_builder.CreateString('fl_test_job')
    fl_id = start_fl_job_builder.CreateString(str_fl_id)
    data_size = 32
    timestamp = start_fl_job_builder.CreateString('2020/11/16/19/18')

    RequestFLJob.RequestFLJobStart(start_fl_job_builder)
    RequestFLJob.RequestFLJobAddFlName(start_fl_job_builder, fl_name)
    RequestFLJob.RequestFLJobAddFlId(start_fl_job_builder, fl_id)
    RequestFLJob.RequestFLJobAddIteration(start_fl_job_builder, iteration)
    RequestFLJob.RequestFLJobAddDataSize(start_fl_job_builder, data_size)
    RequestFLJob.RequestFLJobAddTimestamp(start_fl_job_builder, timestamp)
    fl_job_req = RequestFLJob.RequestFLJobEnd(start_fl_job_builder)

    start_fl_job_builder.Finish(fl_job_req)
    buf = start_fl_job_builder.Output()
    return buf


def build_feature_map(builder, names, lengths):
    if len(names) != len(lengths):
        return None
    feature_maps = []
    np_data = []
    for j, _ in enumerate(names):
        name = names[j]
        length = lengths[j]
        weight_full_name = builder.CreateString(name)
        FeatureMap.FeatureMapStartDataVector(builder, length)
        weight = np.random.rand(length) * 32
        np_data.append(weight)
        for idx in range(length - 1, -1, -1):
            builder.PrependFloat32(weight[idx])
        data = builder.EndVector(length)
        FeatureMap.FeatureMapStart(builder)
        FeatureMap.FeatureMapAddData(builder, data)
        FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name)
        feature_map = FeatureMap.FeatureMapEnd(builder)
        feature_maps.append(feature_map)
    return feature_maps, np_data


def build_update_model(iteration):
    builder_update_model = flatbuffers.Builder(1)
    fl_name = builder_update_model.CreateString('fl_test_job')
    fl_id = builder_update_model.CreateString(str_fl_id)
    timestamp = builder_update_model.CreateString('2020/11/16/19/18')

    feature_maps, np_data = build_feature_map(builder_update_model,
                                              ["conv1.weight", "conv2.weight", "fc1.weight",
                                               "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
                                              [450, 2400, 48000, 10080, 5208, 120, 84, 62])

    RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, 1)
    for single_feature_map in feature_maps:
        builder_update_model.PrependUOffsetTRelative(single_feature_map)
    feature_map = builder_update_model.EndVector(len(feature_maps))

    RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
    RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
    RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
    RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
    RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
    RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
    req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
    builder_update_model.Finish(req_update_model)
    buf = builder_update_model.Output()
    return buf, np_data


def build_get_model(iteration):
    builder_get_model = flatbuffers.Builder(1)
    fl_name = builder_get_model.CreateString('fl_test_job')
    timestamp = builder_get_model.CreateString('2020/12/16/19/18')

    RequestGetModel.RequestGetModelStart(builder_get_model)
    RequestGetModel.RequestGetModelAddFlName(builder_get_model, fl_name)
    RequestGetModel.RequestGetModelAddIteration(builder_get_model, iteration)
    RequestGetModel.RequestGetModelAddTimestamp(builder_get_model, timestamp)
    req_get_model = RequestGetModel.RequestGetModelEnd(builder_get_model)
    builder_get_model.Finish(req_get_model)
    buf = builder_get_model.Output()
    return buf


weight_name_to_idx = {
    "conv1.weight": 0,
    "conv2.weight": 1,
    "fc1.weight": 2,
    "fc2.weight": 3,
    "fc3.weight": 4,
    "fc1.bias": 5,
    "fc2.bias": 6,
    "fc3.bias": 7
}

session = requests.Session()
current_iteration = 1
url = "http://" + http_ip + ":" + str(generate_port())
np.random.seed(0)
while True:
    url1 = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
    print("start url is ", url1)
    x = requests.post(url1, data=build_start_fl_job(current_iteration))
    rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    print("start fl job iteration:", current_iteration, ", id:", args.pid)
    while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
        x = requests.post(url1, data=build_start_fl_job(current_iteration))
        rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
    sys.stdout.flush()

    url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
    print("req update model iteration:", current_iteration, ", id:", args.pid)
    update_model_buf, update_model_np_data = build_update_model(current_iteration)
    x = session.post(url2, data=update_model_buf)
    print("rsp update model iteration:", current_iteration, ", id:", args.pid)
    sys.stdout.flush()

    url3 = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
    print("req get model iteration:", current_iteration, ", id:", args.pid)
    x = session.post(url3, data=build_get_model(current_iteration))
    rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
    print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
    sys.stdout.flush()

    repeat_time = 0
    while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
        time.sleep(0.1)
        x = session.post(url3, data=build_get_model(current_iteration))
        rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
        repeat_time += 1
        if repeat_time > 1000:
            print("GetModel try timeout ", args.pid)
            sys.exit(0)

    for i in range(0, 1):
        print(rsp_get_model.FeatureMap(i).WeightFullname())
        origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
        after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
        print("Before update model", args.pid, origin[0:10])
        print("After get model", args.pid, after[0:10])
        sys.stdout.flush()
        assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
    current_iteration += 1
@ -0,0 +1,423 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecayForBert, a customized Adam for BERT. Inputs: gradient, overflow flag."""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """
    Update parameters by the AdamWeightDecay op.
    """
    if optim_filter:
        adam = P.AdamWeightDecay()
        if decay_flags:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
        else:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
        return next_param
    return gradient


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the new value of the parameter after updating.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # When overflow occurs, cond is all True and the selects below mask out the gradient contribution.
        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
        next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\
                op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))

        next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\
                op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        # On overflow the update is replaced by zeros, so the parameter stays unchanged.
        next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
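To see what the select-based overflow guard in _update_run_op amounts to, here is a small numpy mirror of the arithmetic (an illustrative sketch, not MindSpore code; cond is simplified to a single bool, and weight decay is applied unconditionally for brevity):

import numpy as np

def adam_w_step(param, m, v, grad, lr=1e-3, beta1=0.9, beta2=0.999,
                eps=1e-6, weight_decay=0.01, overflow=False):
    # On overflow, the selects keep the old moment contribution and zero the update,
    # so the parameter is left unchanged for that step.
    next_m = beta1 * m + (m if overflow else (1 - beta1) * grad)
    next_v = beta2 * v + (v if overflow else (1 - beta2) * grad ** 2)
    update = next_m / (eps + np.sqrt(next_v)) + weight_decay * param
    next_param = param - (0.0 if overflow else lr * update)
    return next_param, next_m, next_v

p, g = np.ones(3, np.float32), np.full(3, 0.5, np.float32)
m = v = np.zeros(3, np.float32)
print(adam_w_step(p, m, v, g, overflow=True)[0])   # [1. 1. 1.]  (update skipped)
print(adam_w_step(p, m, v, g, overflow=False)[0])  # updated parameters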
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
||||
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
|
||||
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
indices = gradient.indices
|
||||
values = gradient.values
|
||||
if ps_parameter and not cache_enable:
|
||||
op_shape = P.Shape()
|
||||
shapes = (op_shape(param), op_shape(m), op_shape(v),
|
||||
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
|
||||
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
|
||||
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, values, indices), shapes), param))
|
||||
return success
|
||||
|
||||
if not target:
|
||||
success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, values, indices))
|
||||
else:
|
||||
op_mul = P.Mul()
|
||||
op_square = P.Square()
|
||||
op_sqrt = P.Sqrt()
|
||||
scatter_add = P.ScatterAdd(use_locking)
|
||||
|
||||
assign_m = F.assign(m, op_mul(beta1, m))
|
||||
assign_v = F.assign(v, op_mul(beta2, v))
|
||||
|
||||
grad_indices = gradient.indices
|
||||
grad_value = gradient.values
|
||||
|
||||
next_m = scatter_add(m,
|
||||
grad_indices,
|
||||
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
|
||||
|
||||
next_v = scatter_add(v,
|
||||
grad_indices,
|
||||
op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))
|
||||
|
||||
if use_nesterov:
|
||||
m_temp = next_m * _scaler_ten
|
||||
assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
|
||||
div_value = scatter_add(m,
|
||||
op_mul(grad_indices, _scaler_one),
|
||||
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
|
||||
param_update = div_value / (op_sqrt(next_v) + eps)
|
||||
|
||||
m_recover = F.assign(m, m_temp / _scaler_ten)
|
||||
|
||||
F.control_depend(m_temp, assign_m_nesterov)
|
||||
F.control_depend(assign_m_nesterov, div_value)
|
||||
F.control_depend(param_update, m_recover)
|
||||
else:
|
||||
param_update = next_m / (op_sqrt(next_v) + eps)
|
||||
|
||||
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
|
||||
|
||||
next_param = param - lr_t * param_update
|
||||
|
||||
F.control_depend(assign_m, next_m)
|
||||
F.control_depend(assign_v, next_v)
|
||||
|
||||
success = F.depend(success, F.assign(param, next_param))
|
||||
success = F.depend(success, F.assign(m, next_m))
|
||||
success = F.depend(success, F.assign(v, next_v))
|
||||
|
||||
return success
|
||||
|
||||
|
||||
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
|
||||
beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
|
||||
moment1, moment2, ps_parameter, cache_enable):
|
||||
"""Apply adam optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
if ps_parameter and not cache_enable:
|
||||
op_shape = P.Shape()
|
||||
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
|
||||
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
|
||||
else:
|
||||
success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||
eps, gradient))
|
||||
return success
|
||||
|
||||
|
||||
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor")
|
||||
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
|
||||
"""Apply AdamOffload optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
|
||||
success = F.depend(success, F.assign_add(param, delat_param))
|
||||
return success
|
||||
|
||||
|
||||
def _check_param_value(beta1, beta2, eps, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
|
||||
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
|
||||
validator.check_positive_float(eps, "eps", prim_name)
|
||||
|
||||
class AdamWeightDecayForBert(Optimizer):
|
||||
"""
|
||||
Implements the Adam algorithm to fix the weight decay.
|
||||
|
||||
Note:
|
||||
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
|
||||
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
|
||||
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
|
||||
|
||||
To improve parameter groups performance, the customized order of parameters can be supported.
|
||||
|
||||
Args:
|
||||
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
|
||||
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
|
||||
"lr", "weight_decay" and "order_params" are the keys can be parsed.
|
||||
|
||||
- params: Required. The value must be a list of `Parameter`.
|
||||
|
||||
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
|
||||
If not, the `learning_rate` in the API will be used.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the API will be used.
|
||||
|
||||
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
|
||||
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
|
||||
which in the 'order_params' must be in one of group parameters.
|
||||
|
||||
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
|
||||
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
|
||||
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
|
||||
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
|
||||
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
|
||||
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
|
||||
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
|
||||
Default: 1e-3.
|
||||
beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
|
||||
Should be in range (0.0, 1.0).
|
||||
beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
|
||||
Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
- **overflow** (tuple[Tensor]) - The overflow flag in dynamiclossscale.
|
||||
|
||||
Outputs:
|
||||
tuple[bool], all elements are True.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> #1) All parameters use the same learning rate and weight decay
|
||||
>>> optim = AdamWeightDecay(params=net.trainable_params())
|
||||
>>>
|
||||
>>> #2) Use parameter groups and set different values
|
||||
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
||||
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
||||
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
|
||||
... {'params': no_conv_params, 'lr': 0.01},
|
||||
... {'order_params': net.trainable_params()}]
|
||||
>>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
|
||||
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
|
||||
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
|
||||
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
|
||||
>>>
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
||||
"""
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
|
||||
super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
|
||||
_check_param_value(beta1, beta2, eps, self.cls_name)
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.op_select = P.Select()
|
||||
self.op_cast = P.Cast()
|
||||
self.op_reshape = P.Reshape()
|
||||
self.op_shape = P.Shape()
|
||||
|
||||
def construct(self, gradients, overflow):
|
||||
"""AdamWeightDecayForBert"""
|
||||
lr = self.get_lr()
|
||||
cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\
|
||||
self.op_reshape(overflow, (())), mstype.bool_)
|
||||
beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
|
||||
beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
|
||||
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
|
||||
self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
|
||||
self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
if self.use_parallel:
|
||||
self.broadcast_params(optim_result)
|
||||
return optim_result
|
||||
|
||||
class AdamWeightDecayOp(Optimizer):
    """
    Implements the Adam algorithm with fixed weight decay. It is a single complete operator,
    not a combination of other ops.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters,
              and that order will be followed in the optimizer. There should be no other keys in this `dict`, and the
              parameters listed in 'order_params' must be in one of the group parameter lists.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used, and the i-th step
            will take the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a dynamic
            learning rate is also used: the i-th learning rate is calculated during training according to the formula
            of the LearningRateSchedule. When the learning_rate is a float or a 0-D Tensor, a fixed learning rate is
            used. Other cases are not supported. A float learning rate must be equal to or greater than 0. If the
            type of `learning_rate` is int, it will be converted to float.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, with the same shape as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The parameters in conv_params will use the default learning rate of 0.1 and the weight decay of 0.01.
        >>> # The parameters in no_conv_params will use the learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        """AdamWeightDecayOp"""
        lr = self.get_lr()
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
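
# Illustrative NumPy reference (hypothetical helper, not invoked by the classes
# above): the decoupled weight-decay update an AdamWeightDecay step is expected
# to compute; the fused GPU kernel's exact numerics may differ, e.g. in eps placement.
def _adam_weight_decay_reference(param, m, v, grad, lr, beta1, beta2, eps,
                                 weight_decay, decay_flag):
    m = beta1 * m + (1.0 - beta1) * grad          # first moment estimate
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second moment estimate
    update = m / (np.sqrt(v) + eps)               # adaptive step
    if decay_flag:                                # decoupled weight decay (AdamW)
        update = update + weight_decay * param
    return param - lr * update, m, v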
@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a conv layer with TruncatedNormal weight initialization."""
    weight = weight_variable()
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        weight_init=weight,
        has_bias=False,
        pad_mode="valid",
    )


def fc_with_initialize(input_channels, out_channels):
    """Build a fully connected layer with TruncatedNormal weight and bias initialization."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """TruncatedNormal weight initializer."""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    def __init__(self, num_class=10, channel=3):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
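
# Illustrative shape check (hypothetical helper, assuming the 32x32 RGB inputs
# used by the test script): each 5x5 "valid" convolution shrinks H/W by 4 and
# each 2x2 max-pool halves them, 32 -> 28 -> 14 -> 10 -> 5, which is where
# fc1's 16 * 5 * 5 input size comes from.
def _lenet5_shape_check():
    import numpy as np
    from mindspore import Tensor
    net = LeNet5(num_class=62, channel=3)
    out = net(Tensor(np.random.rand(1, 3, 32, 32).astype(np.float32)))
    return out.shape  # expected: (1, 62)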
@ -0,0 +1,96 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import ast
import argparse
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from src.model import LeNet5
from src.adam import AdamWeightDecayOp

parser = argparse.ArgumentParser(description="test_fl_lenet")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=1)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)

args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
ms_role = args.ms_role
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
secure_aggregation = args.secure_aggregation

ctx = {
    "enable_ps": False,
    "server_mode": server_mode,
    "ms_role": ms_role,
    "worker_num": worker_num,
    "server_num": server_num,
    "scheduler_ip": scheduler_ip,
    "scheduler_port": scheduler_port,
    "fl_server_port": fl_server_port,
    "start_fl_job_threshold": start_fl_job_threshold,
    "fl_name": fl_name,
    "fl_iteration_num": fl_iteration_num,
    "client_epoch_num": client_epoch_num,
    "client_batch_size": client_batch_size,
    "secure_aggregation": secure_aggregation
}

context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_ps_context(**ctx)

if __name__ == "__main__":
    epoch = 5
    np.random.seed(0)
    network = LeNet5(62)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    # Constructed but not wired into the training step below; exercises
    # AdamWeightDecayOp construction.
    net_adam_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1)
    net_with_criterion = WithLossCell(network, criterion)
    train_network = TrainOneStepCell(net_with_criterion, net_opt)
    train_network.set_train()
    losses = []

    for _ in range(epoch):
        data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
        label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
        loss = train_network(data, label).asnumpy()
        losses.append(loss)
    print(losses)
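
# A possible launch sketch (the script name and flag values are illustrative;
# the flags themselves come from the argparse block above). Each process picks
# its role via --ms_role, so a local single-machine run could start:
#   python test_fl_lenet.py --ms_role=MS_SCHED --scheduler_port=8113
#   python test_fl_lenet.py --ms_role=MS_SERVER --fl_server_port=6666
#   python test_fl_lenet.py --ms_role=MS_WORKER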